feat: puts ndarray and nalgebra bindings behind feature flags
This commit is contained in:
@@ -15,10 +15,10 @@ jobs:
|
|||||||
command: cargo fmt -- --check
|
command: cargo fmt -- --check
|
||||||
- run:
|
- run:
|
||||||
name: Stable Build
|
name: Stable Build
|
||||||
command: cargo build
|
command: cargo build --features "nalgebra-bindings ndarray-bindings"
|
||||||
- run:
|
- run:
|
||||||
name: Test
|
name: Test
|
||||||
command: cargo test
|
command: cargo test --features "nalgebra-bindings ndarray-bindings"
|
||||||
- save_cache:
|
- save_cache:
|
||||||
key: project-cache
|
key: project-cache
|
||||||
paths:
|
paths:
|
||||||
|
|||||||
+7
-2
@@ -4,9 +4,14 @@ version = "0.1.0"
|
|||||||
authors = ["SmartCore Developers"]
|
authors = ["SmartCore Developers"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = []
|
||||||
|
ndarray-bindings = ["ndarray"]
|
||||||
|
nalgebra-bindings = ["nalgebra"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
ndarray = "0.13"
|
ndarray = { version = "0.13", optional = true }
|
||||||
nalgebra = "0.22.0"
|
nalgebra = { version = "0.22.0", optional = true }
|
||||||
num-traits = "0.2.12"
|
num-traits = "0.2.12"
|
||||||
num = "0.3.0"
|
num = "0.3.0"
|
||||||
rand = "0.7.3"
|
rand = "0.7.3"
|
||||||
|
|||||||
@@ -124,7 +124,6 @@ impl<T: FloatExt> RandomForestRegressor<T> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
use ndarray::{arr1, arr2};
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_longley() {
|
fn fit_longley() {
|
||||||
@@ -173,53 +172,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn my_fit_longley_ndarray() {
|
|
||||||
let x = arr2(&[
|
|
||||||
[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
|
||||||
[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
|
||||||
[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
|
||||||
[284.599, 335.1, 165., 110.929, 1950., 61.187],
|
|
||||||
[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
|
||||||
[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
|
|
||||||
[365.385, 187., 354.7, 115.094, 1953., 64.989],
|
|
||||||
[363.112, 357.8, 335., 116.219, 1954., 63.761],
|
|
||||||
[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
|
||||||
[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
|
|
||||||
[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
|
||||||
[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
|
|
||||||
[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
|
||||||
[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
|
||||||
[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
|
||||||
[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
|
||||||
]);
|
|
||||||
let y = arr1(&[
|
|
||||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
|
||||||
114.2, 115.7, 116.9,
|
|
||||||
]);
|
|
||||||
|
|
||||||
let expected_y: Vec<f64> = vec![
|
|
||||||
85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.,
|
|
||||||
];
|
|
||||||
|
|
||||||
let y_hat = RandomForestRegressor::fit(
|
|
||||||
&x,
|
|
||||||
&y,
|
|
||||||
RandomForestRegressorParameters {
|
|
||||||
max_depth: None,
|
|
||||||
min_samples_leaf: 1,
|
|
||||||
min_samples_split: 2,
|
|
||||||
n_trees: 1000,
|
|
||||||
mtry: Option::None,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.predict(&x);
|
|
||||||
|
|
||||||
for i in 0..y_hat.len() {
|
|
||||||
assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
pub mod evd;
|
pub mod evd;
|
||||||
pub mod lu;
|
pub mod lu;
|
||||||
pub mod naive;
|
pub mod naive;
|
||||||
|
#[cfg(feature = "nalgebra-bindings")]
|
||||||
pub mod nalgebra_bindings;
|
pub mod nalgebra_bindings;
|
||||||
|
#[cfg(feature = "ndarray-bindings")]
|
||||||
pub mod ndarray_bindings;
|
pub mod ndarray_bindings;
|
||||||
pub mod qr;
|
pub mod qr;
|
||||||
pub mod svd;
|
pub mod svd;
|
||||||
|
|||||||
@@ -368,6 +368,7 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::linear::linear_regression::*;
|
||||||
use nalgebra::{DMatrix, Matrix2x3, RowDVector};
|
use nalgebra::{DMatrix, Matrix2x3, RowDVector};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -674,4 +675,49 @@ mod tests {
|
|||||||
assert_eq!(res.len(), 7);
|
assert_eq!(res.len(), 7);
|
||||||
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
|
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ols_fit_predict() {
|
||||||
|
let x = DMatrix::from_row_slice(
|
||||||
|
16,
|
||||||
|
6,
|
||||||
|
&[
|
||||||
|
234.289, 235.6, 159.0, 107.608, 1947., 60.323, 259.426, 232.5, 145.6, 108.632,
|
||||||
|
1948., 61.122, 258.054, 368.2, 161.6, 109.773, 1949., 60.171, 284.599, 335.1,
|
||||||
|
165.0, 110.929, 1950., 61.187, 328.975, 209.9, 309.9, 112.075, 1951., 63.221,
|
||||||
|
346.999, 193.2, 359.4, 113.270, 1952., 63.639, 365.385, 187.0, 354.7, 115.094,
|
||||||
|
1953., 64.989, 363.112, 357.8, 335.0, 116.219, 1954., 63.761, 397.469, 290.4,
|
||||||
|
304.8, 117.388, 1955., 66.019, 419.180, 282.2, 285.7, 118.734, 1956., 67.857,
|
||||||
|
442.769, 293.6, 279.8, 120.445, 1957., 68.169, 444.546, 468.1, 263.7, 121.950,
|
||||||
|
1958., 66.513, 482.704, 381.3, 255.2, 123.366, 1959., 68.655, 502.601, 393.1,
|
||||||
|
251.4, 125.368, 1960., 69.564, 518.173, 480.6, 257.2, 127.852, 1961., 69.331,
|
||||||
|
554.894, 400.7, 282.7, 130.081, 1962., 70.551,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
|
let y: RowDVector<f64> = RowDVector::from_vec(vec![
|
||||||
|
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||||
|
114.2, 115.7, 116.9,
|
||||||
|
]);
|
||||||
|
|
||||||
|
let y_hat_qr = LinearRegression::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
LinearRegressionParameters {
|
||||||
|
solver: LinearRegressionSolverName::QR,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.predict(&x);
|
||||||
|
|
||||||
|
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x);
|
||||||
|
|
||||||
|
assert!(y
|
||||||
|
.iter()
|
||||||
|
.zip(y_hat_qr.iter())
|
||||||
|
.all(|(&a, &b)| (a - b).abs() <= 5.0));
|
||||||
|
assert!(y
|
||||||
|
.iter()
|
||||||
|
.zip(y_hat_svd.iter())
|
||||||
|
.all(|(&a, &b)| (a - b).abs() <= 5.0));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -336,7 +336,9 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use ndarray::{arr1, arr2, Array2};
|
use crate::ensemble::random_forest_regressor::*;
|
||||||
|
use crate::linear::logistic_regression::*;
|
||||||
|
use ndarray::{arr1, arr2, Array1, Array2};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_get_set() {
|
fn vec_get_set() {
|
||||||
@@ -654,4 +656,92 @@ mod tests {
|
|||||||
a.abs_mut();
|
a.abs_mut();
|
||||||
assert_eq!(a, expected);
|
assert_eq!(a, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn lr_fit_predict_iris() {
|
||||||
|
let x = arr2(&[
|
||||||
|
[5.1, 3.5, 1.4, 0.2],
|
||||||
|
[4.9, 3.0, 1.4, 0.2],
|
||||||
|
[4.7, 3.2, 1.3, 0.2],
|
||||||
|
[4.6, 3.1, 1.5, 0.2],
|
||||||
|
[5.0, 3.6, 1.4, 0.2],
|
||||||
|
[5.4, 3.9, 1.7, 0.4],
|
||||||
|
[4.6, 3.4, 1.4, 0.3],
|
||||||
|
[5.0, 3.4, 1.5, 0.2],
|
||||||
|
[4.4, 2.9, 1.4, 0.2],
|
||||||
|
[4.9, 3.1, 1.5, 0.1],
|
||||||
|
[7.0, 3.2, 4.7, 1.4],
|
||||||
|
[6.4, 3.2, 4.5, 1.5],
|
||||||
|
[6.9, 3.1, 4.9, 1.5],
|
||||||
|
[5.5, 2.3, 4.0, 1.3],
|
||||||
|
[6.5, 2.8, 4.6, 1.5],
|
||||||
|
[5.7, 2.8, 4.5, 1.3],
|
||||||
|
[6.3, 3.3, 4.7, 1.6],
|
||||||
|
[4.9, 2.4, 3.3, 1.0],
|
||||||
|
[6.6, 2.9, 4.6, 1.3],
|
||||||
|
[5.2, 2.7, 3.9, 1.4],
|
||||||
|
]);
|
||||||
|
let y: Array1<f64> = arr1(&[
|
||||||
|
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
|
]);
|
||||||
|
|
||||||
|
let lr = LogisticRegression::fit(&x, &y);
|
||||||
|
|
||||||
|
let y_hat = lr.predict(&x);
|
||||||
|
|
||||||
|
let error: f64 = y
|
||||||
|
.into_iter()
|
||||||
|
.zip(y_hat.into_iter())
|
||||||
|
.map(|(&a, &b)| (a - b).abs())
|
||||||
|
.sum();
|
||||||
|
|
||||||
|
assert!(error <= 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn my_fit_longley_ndarray() {
|
||||||
|
let x = arr2(&[
|
||||||
|
[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||||
|
[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
|
[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||||
|
[284.599, 335.1, 165., 110.929, 1950., 61.187],
|
||||||
|
[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||||
|
[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
|
||||||
|
[365.385, 187., 354.7, 115.094, 1953., 64.989],
|
||||||
|
[363.112, 357.8, 335., 116.219, 1954., 63.761],
|
||||||
|
[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||||
|
[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||||
|
[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||||
|
[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
|
||||||
|
[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||||
|
[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||||
|
[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||||
|
[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||||
|
]);
|
||||||
|
let y = arr1(&[
|
||||||
|
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||||
|
114.2, 115.7, 116.9,
|
||||||
|
]);
|
||||||
|
|
||||||
|
let expected_y: Vec<f64> = vec![
|
||||||
|
85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.,
|
||||||
|
];
|
||||||
|
|
||||||
|
let y_hat = RandomForestRegressor::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
RandomForestRegressorParameters {
|
||||||
|
max_depth: None,
|
||||||
|
min_samples_leaf: 1,
|
||||||
|
min_samples_split: 2,
|
||||||
|
n_trees: 1000,
|
||||||
|
mtry: Option::None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.predict(&x);
|
||||||
|
|
||||||
|
for i in 0..y_hat.len() {
|
||||||
|
assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ pub enum LinearRegressionSolverName {
|
|||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct LinearRegressionParameters {
|
pub struct LinearRegressionParameters {
|
||||||
solver: LinearRegressionSolverName,
|
pub solver: LinearRegressionSolverName,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
@@ -81,55 +81,9 @@ impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
use nalgebra::{DMatrix, RowDVector};
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn ols_fit_predict() {
|
fn ols_fit_predict() {
|
||||||
let x = DMatrix::from_row_slice(
|
|
||||||
16,
|
|
||||||
6,
|
|
||||||
&[
|
|
||||||
234.289, 235.6, 159.0, 107.608, 1947., 60.323, 259.426, 232.5, 145.6, 108.632,
|
|
||||||
1948., 61.122, 258.054, 368.2, 161.6, 109.773, 1949., 60.171, 284.599, 335.1,
|
|
||||||
165.0, 110.929, 1950., 61.187, 328.975, 209.9, 309.9, 112.075, 1951., 63.221,
|
|
||||||
346.999, 193.2, 359.4, 113.270, 1952., 63.639, 365.385, 187.0, 354.7, 115.094,
|
|
||||||
1953., 64.989, 363.112, 357.8, 335.0, 116.219, 1954., 63.761, 397.469, 290.4,
|
|
||||||
304.8, 117.388, 1955., 66.019, 419.180, 282.2, 285.7, 118.734, 1956., 67.857,
|
|
||||||
442.769, 293.6, 279.8, 120.445, 1957., 68.169, 444.546, 468.1, 263.7, 121.950,
|
|
||||||
1958., 66.513, 482.704, 381.3, 255.2, 123.366, 1959., 68.655, 502.601, 393.1,
|
|
||||||
251.4, 125.368, 1960., 69.564, 518.173, 480.6, 257.2, 127.852, 1961., 69.331,
|
|
||||||
554.894, 400.7, 282.7, 130.081, 1962., 70.551,
|
|
||||||
],
|
|
||||||
);
|
|
||||||
|
|
||||||
let y: RowDVector<f64> = RowDVector::from_vec(vec![
|
|
||||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
|
||||||
114.2, 115.7, 116.9,
|
|
||||||
]);
|
|
||||||
|
|
||||||
let y_hat_qr = LinearRegression::fit(
|
|
||||||
&x,
|
|
||||||
&y,
|
|
||||||
LinearRegressionParameters {
|
|
||||||
solver: LinearRegressionSolverName::QR,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.predict(&x);
|
|
||||||
|
|
||||||
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x);
|
|
||||||
|
|
||||||
assert!(y
|
|
||||||
.iter()
|
|
||||||
.zip(y_hat_qr.iter())
|
|
||||||
.all(|(&a, &b)| (a - b).abs() <= 5.0));
|
|
||||||
assert!(y
|
|
||||||
.iter()
|
|
||||||
.zip(y_hat_svd.iter())
|
|
||||||
.all(|(&a, &b)| (a - b).abs() <= 5.0));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn ols_fit_predict_nalgebra() {
|
|
||||||
let x = DenseMatrix::from_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
|
|||||||
@@ -265,7 +265,6 @@ impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
use ndarray::{arr1, arr2, Array1};
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn multiclass_objective_f() {
|
fn multiclass_objective_f() {
|
||||||
@@ -426,31 +425,31 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn lr_fit_predict_iris() {
|
fn lr_fit_predict_iris() {
|
||||||
let x = arr2(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
[4.9, 3.0, 1.4, 0.2],
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
[4.7, 3.2, 1.3, 0.2],
|
&[4.7, 3.2, 1.3, 0.2],
|
||||||
[4.6, 3.1, 1.5, 0.2],
|
&[4.6, 3.1, 1.5, 0.2],
|
||||||
[5.0, 3.6, 1.4, 0.2],
|
&[5.0, 3.6, 1.4, 0.2],
|
||||||
[5.4, 3.9, 1.7, 0.4],
|
&[5.4, 3.9, 1.7, 0.4],
|
||||||
[4.6, 3.4, 1.4, 0.3],
|
&[4.6, 3.4, 1.4, 0.3],
|
||||||
[5.0, 3.4, 1.5, 0.2],
|
&[5.0, 3.4, 1.5, 0.2],
|
||||||
[4.4, 2.9, 1.4, 0.2],
|
&[4.4, 2.9, 1.4, 0.2],
|
||||||
[4.9, 3.1, 1.5, 0.1],
|
&[4.9, 3.1, 1.5, 0.1],
|
||||||
[7.0, 3.2, 4.7, 1.4],
|
&[7.0, 3.2, 4.7, 1.4],
|
||||||
[6.4, 3.2, 4.5, 1.5],
|
&[6.4, 3.2, 4.5, 1.5],
|
||||||
[6.9, 3.1, 4.9, 1.5],
|
&[6.9, 3.1, 4.9, 1.5],
|
||||||
[5.5, 2.3, 4.0, 1.3],
|
&[5.5, 2.3, 4.0, 1.3],
|
||||||
[6.5, 2.8, 4.6, 1.5],
|
&[6.5, 2.8, 4.6, 1.5],
|
||||||
[5.7, 2.8, 4.5, 1.3],
|
&[5.7, 2.8, 4.5, 1.3],
|
||||||
[6.3, 3.3, 4.7, 1.6],
|
&[6.3, 3.3, 4.7, 1.6],
|
||||||
[4.9, 2.4, 3.3, 1.0],
|
&[4.9, 2.4, 3.3, 1.0],
|
||||||
[6.6, 2.9, 4.6, 1.3],
|
&[6.6, 2.9, 4.6, 1.3],
|
||||||
[5.2, 2.7, 3.9, 1.4],
|
&[5.2, 2.7, 3.9, 1.4],
|
||||||
]);
|
]);
|
||||||
let y: Array1<f64> = arr1(&[
|
let y: Vec<f64> = vec![
|
||||||
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
]);
|
];
|
||||||
|
|
||||||
let lr = LogisticRegression::fit(&x, &y);
|
let lr = LogisticRegression::fit(&x, &y);
|
||||||
|
|
||||||
@@ -459,7 +458,7 @@ mod tests {
|
|||||||
let error: f64 = y
|
let error: f64 = y
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.zip(y_hat.into_iter())
|
.zip(y_hat.into_iter())
|
||||||
.map(|(&a, &b)| (a - b).abs())
|
.map(|(a, b)| (a - b).abs())
|
||||||
.sum();
|
.sum();
|
||||||
|
|
||||||
assert!(error <= 1.0);
|
assert!(error <= 1.0);
|
||||||
|
|||||||
Reference in New Issue
Block a user