feat: documents matrix methods
This commit is contained in:
@@ -23,7 +23,7 @@
|
||||
//! use smartcore::linear::linear_regression::*;
|
||||
//!
|
||||
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
|
||||
//! let x = DenseMatrix::from_array(&[
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
//! &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
//! &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
@@ -125,7 +125,7 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
|
||||
panic!("Number of rows of X doesn't match number of rows of Y");
|
||||
}
|
||||
|
||||
let a = x.v_stack(&M::ones(x_nrows, 1));
|
||||
let a = x.h_stack(&M::ones(x_nrows, 1));
|
||||
|
||||
let w = match parameters.solver {
|
||||
LinearRegressionSolverName::QR => a.qr_solve_mut(b),
|
||||
@@ -145,7 +145,7 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &M) -> M::RowVector {
|
||||
let (nrows, _) = x.shape();
|
||||
let mut y_hat = x.dot(&self.coefficients);
|
||||
let mut y_hat = x.matmul(&self.coefficients);
|
||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||
y_hat.transpose().to_row_vector()
|
||||
}
|
||||
@@ -168,7 +168,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn ols_fit_predict() {
|
||||
let x = DenseMatrix::from_array(&[
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
@@ -215,7 +215,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_array(&[
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
|
||||
Reference in New Issue
Block a user