feat: refactoring, adds Result to most public API

This commit is contained in:
Volodymyr Orlov
2020-09-18 15:20:32 -07:00
parent 4921ae76f5
commit a9db970195
24 changed files with 389 additions and 298 deletions
+16 -12
View File
@@ -16,7 +16,7 @@
//! //!
//! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points //! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points
//! //!
//! let mut tree = CoverTree::new(data, SimpleDistance {}); //! let mut tree = CoverTree::new(data, SimpleDistance {}).unwrap();
//! //!
//! tree.find(&5, 3); // find 3 knn points from 5 //! tree.find(&5, 3); // find 3 knn points from 5
//! //!
@@ -26,6 +26,7 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::algorithm::sort::heap_select::HeapSelection; use crate::algorithm::sort::heap_select::HeapSelection;
use crate::error::{Failed, FailedError};
use crate::math::distance::Distance; use crate::math::distance::Distance;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -73,7 +74,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
/// Construct a cover tree. /// Construct a cover tree.
/// * `data` - vector of data points to search for. /// * `data` - vector of data points to search for.
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface. /// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
pub fn new(data: Vec<T>, distance: D) -> CoverTree<T, F, D> { pub fn new(data: Vec<T>, distance: D) -> Result<CoverTree<T, F, D>, Failed> {
let base = F::from_f64(1.3).unwrap(); let base = F::from_f64(1.3).unwrap();
let root = Node { let root = Node {
idx: 0, idx: 0,
@@ -93,19 +94,22 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
tree.build_cover_tree(); tree.build_cover_tree();
tree Ok(tree)
} }
/// Find k nearest neighbors of `p` /// Find k nearest neighbors of `p`
/// * `p` - look for k nearest points to `p` /// * `p` - look for k nearest points to `p`
/// * `k` - the number of nearest neighbors to return /// * `k` - the number of nearest neighbors to return
pub fn find(&self, p: &T, k: usize) -> Vec<(usize, F)> { pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
if k <= 0 { if k <= 0 {
panic!("k should be > 0"); return Err(Failed::because(FailedError::FindFailed, "k should be > 0"));
} }
if k > self.data.len() { if k > self.data.len() {
panic!("k is > than the dataset size"); return Err(Failed::because(
FailedError::FindFailed,
"k is > than the dataset size",
));
} }
let e = self.get_data_value(self.root.idx); let e = self.get_data_value(self.root.idx);
@@ -171,7 +175,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
} }
} }
neighbors.into_iter().take(k).collect() Ok(neighbors.into_iter().take(k).collect())
} }
fn new_leaf(&self, idx: usize) -> Node<F> { fn new_leaf(&self, idx: usize) -> Node<F> {
@@ -407,9 +411,9 @@ mod tests {
fn cover_tree_test() { fn cover_tree_test() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
let tree = CoverTree::new(data, SimpleDistance {}); let tree = CoverTree::new(data, SimpleDistance {}).unwrap();
let mut knn = tree.find(&5, 3); let mut knn = tree.find(&5, 3).unwrap();
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect(); let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
assert_eq!(vec!(3, 4, 5), knn); assert_eq!(vec!(3, 4, 5), knn);
@@ -425,9 +429,9 @@ mod tests {
vec![9., 10.], vec![9., 10.],
]; ];
let tree = CoverTree::new(data, Distances::euclidian()); let tree = CoverTree::new(data, Distances::euclidian()).unwrap();
let mut knn = tree.find(&vec![1., 2.], 3); let mut knn = tree.find(&vec![1., 2.], 3).unwrap();
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect(); let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
@@ -438,7 +442,7 @@ mod tests {
fn serde() { fn serde() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
let tree = CoverTree::new(data, SimpleDistance {}); let tree = CoverTree::new(data, SimpleDistance {}).unwrap();
let deserialized_tree: CoverTree<i32, f64, SimpleDistance> = let deserialized_tree: CoverTree<i32, f64, SimpleDistance> =
serde_json::from_str(&serde_json::to_string(&tree).unwrap()).unwrap(); serde_json::from_str(&serde_json::to_string(&tree).unwrap()).unwrap();
+18 -10
View File
@@ -15,7 +15,7 @@
//! //!
//! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points //! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points
//! //!
//! let knn = LinearKNNSearch::new(data, SimpleDistance {}); //! let knn = LinearKNNSearch::new(data, SimpleDistance {}).unwrap();
//! //!
//! knn.find(&5, 3); // find 3 knn points from 5 //! knn.find(&5, 3); // find 3 knn points from 5
//! //!
@@ -26,6 +26,7 @@ use std::cmp::{Ordering, PartialOrd};
use std::marker::PhantomData; use std::marker::PhantomData;
use crate::algorithm::sort::heap_select::HeapSelection; use crate::algorithm::sort::heap_select::HeapSelection;
use crate::error::Failed;
use crate::math::distance::Distance; use crate::math::distance::Distance;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -41,18 +42,18 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
/// Initializes algorithm. /// Initializes algorithm.
/// * `data` - vector of data points to search for. /// * `data` - vector of data points to search for.
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface. /// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
pub fn new(data: Vec<T>, distance: D) -> LinearKNNSearch<T, F, D> { pub fn new(data: Vec<T>, distance: D) -> Result<LinearKNNSearch<T, F, D>, Failed> {
LinearKNNSearch { Ok(LinearKNNSearch {
data: data, data: data,
distance: distance, distance: distance,
f: PhantomData, f: PhantomData,
} })
} }
/// Find k nearest neighbors /// Find k nearest neighbors
/// * `from` - look for k nearest points to `from` /// * `from` - look for k nearest points to `from`
/// * `k` - the number of nearest neighbors to return /// * `k` - the number of nearest neighbors to return
pub fn find(&self, from: &T, k: usize) -> Vec<(usize, F)> { pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
if k < 1 || k > self.data.len() { if k < 1 || k > self.data.len() {
panic!("k should be >= 1 and <= length(data)"); panic!("k should be >= 1 and <= length(data)");
} }
@@ -76,10 +77,11 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
} }
} }
heap.get() Ok(heap
.get()
.into_iter() .into_iter()
.flat_map(|x| x.index.map(|i| (i, x.distance))) .flat_map(|x| x.index.map(|i| (i, x.distance)))
.collect() .collect())
} }
} }
@@ -120,9 +122,14 @@ mod tests {
fn knn_find() { fn knn_find() {
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {}); let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {}).unwrap();
let mut found_idxs1: Vec<usize> = algorithm1.find(&2, 3).iter().map(|v| v.0).collect(); let mut found_idxs1: Vec<usize> = algorithm1
.find(&2, 3)
.unwrap()
.iter()
.map(|v| v.0)
.collect();
found_idxs1.sort(); found_idxs1.sort();
assert_eq!(vec!(0, 1, 2), found_idxs1); assert_eq!(vec!(0, 1, 2), found_idxs1);
@@ -135,10 +142,11 @@ mod tests {
vec![5., 5.], vec![5., 5.],
]; ];
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian()); let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian()).unwrap();
let mut found_idxs2: Vec<usize> = algorithm2 let mut found_idxs2: Vec<usize> = algorithm2
.find(&vec![3., 3.], 3) .find(&vec![3., 3.], 3)
.unwrap()
.iter() .iter()
.map(|v| v.0) .map(|v| v.0)
.collect(); .collect();
+7 -12
View File
@@ -61,7 +61,7 @@ use std::iter::Sum;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::bbd_tree::BBDTree; use crate::algorithm::neighbour::bbd_tree::BBDTree;
use crate::error::{FitFailedError, PredictFailedError}; use crate::error::Failed;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::distance::euclidian::*; use crate::math::distance::euclidian::*;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -122,19 +122,16 @@ impl<T: RealNumber + Sum> KMeans<T> {
data: &M, data: &M,
k: usize, k: usize,
parameters: KMeansParameters, parameters: KMeansParameters,
) -> Result<KMeans<T>, FitFailedError> { ) -> Result<KMeans<T>, Failed> {
let bbd = BBDTree::new(data); let bbd = BBDTree::new(data);
if k < 2 { if k < 2 {
return Err(FitFailedError::new(&format!( return Err(Failed::fit(&format!("invalid number of clusters: {}", k)));
"Invalid number of clusters: {}",
k
)));
} }
if parameters.max_iter <= 0 { if parameters.max_iter <= 0 {
return Err(FitFailedError::new(&format!( return Err(Failed::fit(&format!(
"Invalid maximum number of iterations: {}", "invalid maximum number of iterations: {}",
parameters.max_iter parameters.max_iter
))); )));
} }
@@ -191,7 +188,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
/// Predict clusters for `x` /// Predict clusters for `x`
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features. /// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, PredictFailedError> { pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, _) = x.shape(); let (n, _) = x.shape();
let mut result = M::zeros(1, n); let mut result = M::zeros(1, n);
@@ -274,11 +271,9 @@ mod tests {
fn invalid_k() { fn invalid_k() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]); let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
println!("{:?}", KMeans::fit(&x, 0, Default::default()));
assert!(KMeans::fit(&x, 0, Default::default()).is_err()); assert!(KMeans::fit(&x, 0, Default::default()).is_err());
assert_eq!( assert_eq!(
"Invalid number of clusters: 1", "Fit failed: invalid number of clusters: 1",
KMeans::fit(&x, 1, Default::default()) KMeans::fit(&x, 1, Default::default())
.unwrap_err() .unwrap_err()
.to_string() .to_string()
+24 -18
View File
@@ -37,9 +37,9 @@
//! &[5.2, 2.7, 3.9, 1.4], //! &[5.2, 2.7, 3.9, 1.4],
//! ]); //! ]);
//! //!
//! let pca = PCA::new(&iris, 2, Default::default()); // Reduce number of features to 2 //! let pca = PCA::fit(&iris, 2, Default::default()).unwrap(); // Reduce number of features to 2
//! //!
//! let iris_reduced = pca.transform(&iris); //! let iris_reduced = pca.transform(&iris).unwrap();
//! //!
//! ``` //! ```
//! //!
@@ -49,6 +49,7 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::Failed;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -100,7 +101,11 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
/// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation. /// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `n_components` - number of components to keep. /// * `n_components` - number of components to keep.
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values. /// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
pub fn new(data: &M, n_components: usize, parameters: PCAParameters) -> PCA<T, M> { pub fn fit(
data: &M,
n_components: usize,
parameters: PCAParameters,
) -> Result<PCA<T, M>, Failed> {
let (m, n) = data.shape(); let (m, n) = data.shape();
let mu = data.column_mean(); let mu = data.column_mean();
@@ -117,7 +122,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
let mut eigenvectors; let mut eigenvectors;
if m > n && !parameters.use_correlation_matrix { if m > n && !parameters.use_correlation_matrix {
let svd = x.svd(); let svd = x.svd()?;
eigenvalues = svd.s; eigenvalues = svd.s;
for i in 0..eigenvalues.len() { for i in 0..eigenvalues.len() {
eigenvalues[i] = eigenvalues[i] * eigenvalues[i]; eigenvalues[i] = eigenvalues[i] * eigenvalues[i];
@@ -155,7 +160,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
} }
} }
let evd = cov.evd(true); let evd = cov.evd(true)?;
eigenvalues = evd.d; eigenvalues = evd.d;
@@ -167,7 +172,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
} }
} }
} else { } else {
let evd = cov.evd(true); let evd = cov.evd(true)?;
eigenvalues = evd.d; eigenvalues = evd.d;
@@ -189,26 +194,26 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
} }
} }
PCA { Ok(PCA {
eigenvectors: eigenvectors, eigenvectors: eigenvectors,
eigenvalues: eigenvalues, eigenvalues: eigenvalues,
projection: projection.transpose(), projection: projection.transpose(),
mu: mu, mu: mu,
pmu: pmu, pmu: pmu,
} })
} }
/// Run dimensionality reduction for `x` /// Run dimensionality reduction for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn transform(&self, x: &M) -> M { pub fn transform(&self, x: &M) -> Result<M, Failed> {
let (nrows, ncols) = x.shape(); let (nrows, ncols) = x.shape();
let (_, n_components) = self.projection.shape(); let (_, n_components) = self.projection.shape();
if ncols != self.mu.len() { if ncols != self.mu.len() {
panic!( return Err(Failed::transform(&format!(
"Invalid input vector size: {}, expected: {}", "Invalid input vector size: {}, expected: {}",
ncols, ncols,
self.mu.len() self.mu.len()
); )));
} }
let mut x_transformed = x.matmul(&self.projection); let mut x_transformed = x.matmul(&self.projection);
@@ -217,7 +222,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
x_transformed.sub_element_mut(r, c, self.pmu[c]); x_transformed.sub_element_mut(r, c, self.pmu[c]);
} }
} }
x_transformed Ok(x_transformed)
} }
} }
@@ -372,7 +377,7 @@ mod tests {
302.04806302399646, 302.04806302399646,
]; ];
let pca = PCA::new(&us_arrests, 4, Default::default()); let pca = PCA::fit(&us_arrests, 4, Default::default()).unwrap();
assert!(pca assert!(pca
.eigenvectors .eigenvectors
@@ -383,7 +388,7 @@ mod tests {
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8); assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
} }
let us_arrests_t = pca.transform(&us_arrests); let us_arrests_t = pca.transform(&us_arrests).unwrap();
assert!(us_arrests_t assert!(us_arrests_t
.abs() .abs()
@@ -481,13 +486,14 @@ mod tests {
0.1734300877298357, 0.1734300877298357,
]; ];
let pca = PCA::new( let pca = PCA::fit(
&us_arrests, &us_arrests,
4, 4,
PCAParameters { PCAParameters {
use_correlation_matrix: true, use_correlation_matrix: true,
}, },
); )
.unwrap();
assert!(pca assert!(pca
.eigenvectors .eigenvectors
@@ -498,7 +504,7 @@ mod tests {
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8); assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
} }
let us_arrests_t = pca.transform(&us_arrests); let us_arrests_t = pca.transform(&us_arrests).unwrap();
assert!(us_arrests_t assert!(us_arrests_t
.abs() .abs()
@@ -530,7 +536,7 @@ mod tests {
&[5.2, 2.7, 3.9, 1.4], &[5.2, 2.7, 3.9, 1.4],
]); ]);
let pca = PCA::new(&iris, 4, Default::default()); let pca = PCA::fit(&iris, 4, Default::default()).unwrap();
let deserialized_pca: PCA<f64, DenseMatrix<f64>> = let deserialized_pca: PCA<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap(); serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
+13 -11
View File
@@ -39,8 +39,8 @@
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., //! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
//! ]; //! ];
//! //!
//! let classifier = RandomForestClassifier::fit(&x, &y, Default::default()); //! let classifier = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
//! let y_hat = classifier.predict(&x); // use the same data for prediction //! let y_hat = classifier.predict(&x).unwrap(); // use the same data for prediction
//! ``` //! ```
//! //!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
@@ -53,6 +53,7 @@ use std::fmt::Debug;
use rand::Rng; use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::Failed;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use crate::tree::decision_tree_classifier::{ use crate::tree::decision_tree_classifier::{
@@ -126,7 +127,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
x: &M, x: &M,
y: &M::RowVector, y: &M::RowVector,
parameters: RandomForestClassifierParameters, parameters: RandomForestClassifierParameters,
) -> RandomForestClassifier<T> { ) -> Result<RandomForestClassifier<T>, Failed> {
let (_, num_attributes) = x.shape(); let (_, num_attributes) = x.shape();
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape(); let (_, y_ncols) = y_m.shape();
@@ -158,20 +159,20 @@ impl<T: RealNumber> RandomForestClassifier<T> {
min_samples_leaf: parameters.min_samples_leaf, min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split, min_samples_split: parameters.min_samples_split,
}; };
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params); let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
trees.push(tree); trees.push(tree);
} }
RandomForestClassifier { Ok(RandomForestClassifier {
parameters: parameters, parameters: parameters,
trees: trees, trees: trees,
classes, classes,
} })
} }
/// Predict class for `x` /// Predict class for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let mut result = M::zeros(1, x.shape().0); let mut result = M::zeros(1, x.shape().0);
let (n, _) = x.shape(); let (n, _) = x.shape();
@@ -180,7 +181,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
result.set(0, i, self.classes[self.predict_for_row(x, i)]); result.set(0, i, self.classes[self.predict_for_row(x, i)]);
} }
result.to_row_vector() Ok(result.to_row_vector())
} }
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize { fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
@@ -263,9 +264,10 @@ mod tests {
n_trees: 100, n_trees: 100,
m: Option::None, m: Option::None,
}, },
); )
.unwrap();
assert!(accuracy(&y, &classifier.predict(&x)) >= 0.95); assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
} }
#[test] #[test]
@@ -296,7 +298,7 @@ mod tests {
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
]; ];
let forest = RandomForestClassifier::fit(&x, &y, Default::default()); let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
let deserialized_forest: RandomForestClassifier<f64> = let deserialized_forest: RandomForestClassifier<f64> =
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap(); bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
+12 -10
View File
@@ -35,9 +35,9 @@
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9 //! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
//! ]; //! ];
//! //!
//! let regressor = RandomForestRegressor::fit(&x, &y, Default::default()); //! let regressor = RandomForestRegressor::fit(&x, &y, Default::default()).unwrap();
//! //!
//! let y_hat = regressor.predict(&x); // use the same data for prediction //! let y_hat = regressor.predict(&x).unwrap(); // use the same data for prediction
//! ``` //! ```
//! //!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
@@ -50,6 +50,7 @@ use std::fmt::Debug;
use rand::Rng; use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::Failed;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use crate::tree::decision_tree_regressor::{ use crate::tree::decision_tree_regressor::{
@@ -114,7 +115,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
x: &M, x: &M,
y: &M::RowVector, y: &M::RowVector,
parameters: RandomForestRegressorParameters, parameters: RandomForestRegressorParameters,
) -> RandomForestRegressor<T> { ) -> Result<RandomForestRegressor<T>, Failed> {
let (n_rows, num_attributes) = x.shape(); let (n_rows, num_attributes) = x.shape();
let mtry = parameters let mtry = parameters
@@ -130,19 +131,19 @@ impl<T: RealNumber> RandomForestRegressor<T> {
min_samples_leaf: parameters.min_samples_leaf, min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split, min_samples_split: parameters.min_samples_split,
}; };
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params); let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
trees.push(tree); trees.push(tree);
} }
RandomForestRegressor { Ok(RandomForestRegressor {
parameters: parameters, parameters: parameters,
trees: trees, trees: trees,
} })
} }
/// Predict values for `x` /// Predict values for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let mut result = M::zeros(1, x.shape().0); let mut result = M::zeros(1, x.shape().0);
let (n, _) = x.shape(); let (n, _) = x.shape();
@@ -151,7 +152,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
result.set(0, i, self.predict_for_row(x, i)); result.set(0, i, self.predict_for_row(x, i));
} }
result.to_row_vector() Ok(result.to_row_vector())
} }
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T { fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
@@ -219,7 +220,8 @@ mod tests {
m: Option::None, m: Option::None,
}, },
) )
.predict(&x); .and_then(|rf| rf.predict(&x))
.unwrap();
assert!(mean_absolute_error(&y, &y_hat) < 1.0); assert!(mean_absolute_error(&y, &y_hat) < 1.0);
} }
@@ -249,7 +251,7 @@ mod tests {
114.2, 115.7, 116.9, 114.2, 115.7, 116.9,
]; ];
let forest = RandomForestRegressor::fit(&x, &y, Default::default()); let forest = RandomForestRegressor::fit(&x, &y, Default::default()).unwrap();
let deserialized_forest: RandomForestRegressor<f64> = let deserialized_forest: RandomForestRegressor<f64> =
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap(); bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
+81 -39
View File
@@ -2,58 +2,100 @@
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
/// Error to be raised when model does not fits data. use serde::{Deserialize, Serialize};
#[derive(Debug)]
pub struct FitFailedError { /// Generic error to be raised when something goes wrong.
details: String, #[derive(Debug, Serialize, Deserialize)]
pub struct Failed {
err: FailedError,
msg: String,
} }
/// Error to be raised when model prediction cannot be calculated. /// Type of error
#[derive(Debug)] #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct PredictFailedError { pub enum FailedError {
details: String, /// Can't fit algorithm to data
FitFailed = 1,
/// Can't predict new values
PredictFailed,
/// Can't transform data
TransformFailed,
/// Can't find an item
FindFailed,
/// Can't decompose a matrix
DecompositionFailed,
} }
impl FitFailedError { impl Failed {
/// Creates new instance of `FitFailedError` ///get type of error
/// * `msg` - description of the error #[inline]
pub fn new(msg: &str) -> FitFailedError { pub fn error(&self) -> FailedError {
FitFailedError { self.err
details: msg.to_string(), }
/// Creates a new `Failed` whose error type is `FailedError::FitFailed`.
/// * `msg` - description of the error
pub fn fit(msg: &str) -> Self {
Failed {
err: FailedError::FitFailed,
msg: msg.to_string(),
}
}
/// Creates a new `Failed` whose error type is `FailedError::PredictFailed`.
/// * `msg` - description of the error
pub fn predict(msg: &str) -> Self {
Failed {
err: FailedError::PredictFailed,
msg: msg.to_string(),
}
}
/// Creates a new `Failed` whose error type is `FailedError::TransformFailed`.
/// * `msg` - description of the error
pub fn transform(msg: &str) -> Self {
Failed {
err: FailedError::TransformFailed,
msg: msg.to_string(),
}
}
/// new instance of `Failed` with the error type given by `err`
pub fn because(err: FailedError, msg: &str) -> Self {
Failed {
err: err,
msg: msg.to_string(),
} }
} }
} }
impl fmt::Display for FitFailedError { impl PartialEq for FailedError {
#[inline(always)]
fn eq(&self, rhs: &Self) -> bool {
*self as u8 == *rhs as u8
}
}
/// Two `Failed` values are equal when both their error type and their message are equal.
impl PartialEq for Failed {
#[inline(always)]
fn eq(&self, rhs: &Self) -> bool {
self.err == rhs.err && self.msg == rhs.msg
}
}
impl fmt::Display for FailedError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.details) let failed_err_str = match self {
FailedError::FitFailed => "Fit failed",
FailedError::PredictFailed => "Predict failed",
FailedError::TransformFailed => "Transform failed",
FailedError::FindFailed => "Find failed",
FailedError::DecompositionFailed => "Decomposition failed",
};
write!(f, "{}", failed_err_str)
} }
} }
impl Error for FitFailedError { impl fmt::Display for Failed {
fn description(&self) -> &str {
&self.details
}
}
impl PredictFailedError {
/// Creates new instance of `PredictFailedError`
/// * `msg` - description of the error
pub fn new(msg: &str) -> PredictFailedError {
PredictFailedError {
details: msg.to_string(),
}
}
}
impl fmt::Display for PredictFailedError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.details) write!(f, "{}: {}", self.err, self.msg)
} }
} }
impl Error for PredictFailedError { impl Error for Failed {}
fn description(&self) -> &str {
&self.details
}
}
+2 -2
View File
@@ -58,10 +58,10 @@
//! let y = vec![2., 2., 2., 3., 3.]; //! let y = vec![2., 2., 2., 3., 3.];
//! //!
//! // Train classifier //! // Train classifier
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()); //! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
//! //!
//! // Predict classes //! // Predict classes
//! let y_hat = knn.predict(&x); //! let y_hat = knn.predict(&x).unwrap();
//! ``` //! ```
/// Various algorithms and helper methods that are used elsewhere in SmartCore /// Various algorithms and helper methods that are used elsewhere in SmartCore
+8 -7
View File
@@ -21,7 +21,7 @@
//! &[0.7000, 0.3000, 0.8000], //! &[0.7000, 0.3000, 0.8000],
//! ]); //! ]);
//! //!
//! let evd = A.evd(true); //! let evd = A.evd(true).unwrap();
//! let eigenvectors: DenseMatrix<f64> = evd.V; //! let eigenvectors: DenseMatrix<f64> = evd.V;
//! let eigenvalues: Vec<f64> = evd.d; //! let eigenvalues: Vec<f64> = evd.d;
//! ``` //! ```
@@ -34,6 +34,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#![allow(non_snake_case)] #![allow(non_snake_case)]
use crate::error::Failed;
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use num::complex::Complex; use num::complex::Complex;
@@ -54,14 +55,14 @@ pub struct EVD<T: RealNumber, M: BaseMatrix<T>> {
pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> { pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Compute the eigen decomposition of a square matrix. /// Compute the eigen decomposition of a square matrix.
/// * `symmetric` - whether the matrix is symmetric /// * `symmetric` - whether the matrix is symmetric
fn evd(&self, symmetric: bool) -> EVD<T, Self> { fn evd(&self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
self.clone().evd_mut(symmetric) self.clone().evd_mut(symmetric)
} }
/// Compute the eigen decomposition of a square matrix. The input matrix /// Compute the eigen decomposition of a square matrix. The input matrix
/// will be used for factorization. /// will be used for factorization.
/// * `symmetric` - whether the matrix is symmetric /// * `symmetric` - whether the matrix is symmetric
fn evd_mut(mut self, symmetric: bool) -> EVD<T, Self> { fn evd_mut(mut self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
let (nrows, ncols) = self.shape(); let (nrows, ncols) = self.shape();
if ncols != nrows { if ncols != nrows {
panic!("Matrix is not square: {} x {}", nrows, ncols); panic!("Matrix is not square: {} x {}", nrows, ncols);
@@ -92,7 +93,7 @@ pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
sort(&mut d, &mut e, &mut V); sort(&mut d, &mut e, &mut V);
} }
EVD { V: V, d: d, e: e } Ok(EVD { V: V, d: d, e: e })
} }
} }
@@ -845,7 +846,7 @@ mod tests {
&[0.6240573, -0.44947578, -0.6391588], &[0.6240573, -0.44947578, -0.6391588],
]); ]);
let evd = A.evd(true); let evd = A.evd(true).unwrap();
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4)); assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values.len() { for i in 0..eigen_values.len() {
@@ -872,7 +873,7 @@ mod tests {
&[0.6952105, 0.43984484, -0.7036135], &[0.6952105, 0.43984484, -0.7036135],
]); ]);
let evd = A.evd(false); let evd = A.evd(false).unwrap();
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4)); assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values.len() { for i in 0..eigen_values.len() {
@@ -902,7 +903,7 @@ mod tests {
&[0.6707, 0.1059, 0.901, -0.6289], &[0.6707, 0.1059, 0.901, -0.6289],
]); ]);
let evd = A.evd(false); let evd = A.evd(false).unwrap();
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4)); assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values_d.len() { for i in 0..eigen_values_d.len() {
+13 -13
View File
@@ -20,7 +20,7 @@
//! &[5., 6., 0.] //! &[5., 6., 0.]
//! ]); //! ]);
//! //!
//! let lu = A.lu(); //! let lu = A.lu().unwrap();
//! let lower: DenseMatrix<f64> = lu.L(); //! let lower: DenseMatrix<f64> = lu.L();
//! let upper: DenseMatrix<f64> = lu.U(); //! let upper: DenseMatrix<f64> = lu.U();
//! ``` //! ```
@@ -36,6 +36,7 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
use crate::error::Failed;
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -121,7 +122,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
} }
/// Returns matrix inverse /// Returns matrix inverse
pub fn inverse(&self) -> M { pub fn inverse(&self) -> Result<M, Failed> {
let (m, n) = self.LU.shape(); let (m, n) = self.LU.shape();
if m != n { if m != n {
@@ -134,11 +135,10 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
inv.set(i, i, T::one()); inv.set(i, i, T::one());
} }
inv = self.solve(inv); self.solve(inv)
return inv;
} }
fn solve(&self, mut b: M) -> M { fn solve(&self, mut b: M) -> Result<M, Failed> {
let (m, n) = self.LU.shape(); let (m, n) = self.LU.shape();
let (b_m, b_n) = b.shape(); let (b_m, b_n) = b.shape();
@@ -187,20 +187,20 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
} }
} }
b Ok(b)
} }
} }
/// Trait that implements LU decomposition routine for any matrix. /// Trait that implements LU decomposition routine for any matrix.
pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> { pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Compute the LU decomposition of a square matrix. /// Compute the LU decomposition of a square matrix.
fn lu(&self) -> LU<T, Self> { fn lu(&self) -> Result<LU<T, Self>, Failed> {
self.clone().lu_mut() self.clone().lu_mut()
} }
/// Compute the LU decomposition of a square matrix. The input matrix /// Compute the LU decomposition of a square matrix. The input matrix
/// will be used for factorization. /// will be used for factorization.
fn lu_mut(mut self) -> LU<T, Self> { fn lu_mut(mut self) -> Result<LU<T, Self>, Failed> {
let (m, n) = self.shape(); let (m, n) = self.shape();
let mut piv = vec![0; m]; let mut piv = vec![0; m];
@@ -252,12 +252,12 @@ pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
} }
} }
LU::new(self, piv, pivsign) Ok(LU::new(self, piv, pivsign))
} }
/// Solves Ax = b /// Solves Ax = b
fn lu_solve_mut(self, b: Self) -> Self { fn lu_solve_mut(self, b: Self) -> Result<Self, Failed> {
self.lu_mut().solve(b) self.lu_mut().and_then(|lu| lu.solve(b))
} }
} }
@@ -275,7 +275,7 @@ mod tests {
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]); DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
let expected_pivot = let expected_pivot =
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]); DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
let lu = a.lu(); let lu = a.lu().unwrap();
assert!(lu.L().approximate_eq(&expected_L, 1e-4)); assert!(lu.L().approximate_eq(&expected_L, 1e-4));
assert!(lu.U().approximate_eq(&expected_U, 1e-4)); assert!(lu.U().approximate_eq(&expected_U, 1e-4));
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4)); assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
@@ -286,7 +286,7 @@ mod tests {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]); let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
let expected = let expected =
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]); DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
let a_inv = a.lu().inverse(); let a_inv = a.lu().and_then(|lu| lu.inverse()).unwrap();
println!("{}", a_inv); println!("{}", a_inv);
assert!(a_inv.approximate_eq(&expected, 1e-4)); assert!(a_inv.approximate_eq(&expected, 1e-4));
} }
+1 -1
View File
@@ -26,7 +26,7 @@
//! &[0.7000, 0.3000, 0.8000], //! &[0.7000, 0.3000, 0.8000],
//! ]); //! ]);
//! //!
//! let svd = A.svd(); //! let svd = A.svd().unwrap();
//! //!
//! let s: Vec<f64> = svd.s; //! let s: Vec<f64> = svd.s;
//! let v: DenseMatrix<f64> = svd.V; //! let v: DenseMatrix<f64> = svd.V;
+7 -4
View File
@@ -34,8 +34,8 @@
//! 116.9, //! 116.9,
//! ]); //! ]);
//! //!
//! let lr = LinearRegression::fit(&x, &y, Default::default()); //! let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();
//! let y_hat = lr.predict(&x); //! let y_hat = lr.predict(&x).unwrap();
//! ``` //! ```
use std::iter::Sum; use std::iter::Sum;
use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign}; use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign};
@@ -777,9 +777,12 @@ mod tests {
solver: LinearRegressionSolverName::QR, solver: LinearRegressionSolverName::QR,
}, },
) )
.predict(&x); .and_then(|lr| lr.predict(&x))
.unwrap();
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x); let y_hat_svd = LinearRegression::fit(&x, &y, Default::default())
.and_then(|lr| lr.predict(&x))
.unwrap();
assert!(y assert!(y
.iter() .iter()
+9 -12
View File
@@ -36,8 +36,8 @@
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1. //! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.
//! ]); //! ]);
//! //!
//! let lr = LogisticRegression::fit(&x, &y); //! let lr = LogisticRegression::fit(&x, &y).unwrap();
//! let y_hat = lr.predict(&x); //! let y_hat = lr.predict(&x).unwrap();
//! ``` //! ```
use std::iter::Sum; use std::iter::Sum;
use std::ops::AddAssign; use std::ops::AddAssign;
@@ -395,6 +395,7 @@ mod tests {
use super::*; use super::*;
use crate::ensemble::random_forest_regressor::*; use crate::ensemble::random_forest_regressor::*;
use crate::linear::logistic_regression::*; use crate::linear::logistic_regression::*;
use crate::metrics::mean_absolute_error;
use ndarray::{arr1, arr2, Array1, Array2}; use ndarray::{arr1, arr2, Array1, Array2};
#[test] #[test]
@@ -736,9 +737,9 @@ mod tests {
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
]); ]);
let lr = LogisticRegression::fit(&x, &y); let lr = LogisticRegression::fit(&x, &y).unwrap();
let y_hat = lr.predict(&x); let y_hat = lr.predict(&x).unwrap();
let error: f64 = y let error: f64 = y
.into_iter() .into_iter()
@@ -774,10 +775,6 @@ mod tests {
114.2, 115.7, 116.9, 114.2, 115.7, 116.9,
]); ]);
let expected_y: Vec<f64> = vec![
85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.,
];
let y_hat = RandomForestRegressor::fit( let y_hat = RandomForestRegressor::fit(
&x, &x,
&y, &y,
@@ -789,10 +786,10 @@ mod tests {
m: Option::None, m: Option::None,
}, },
) )
.predict(&x); .unwrap()
.predict(&x)
.unwrap();
for i in 0..y_hat.len() { assert!(mean_absolute_error(&y, &y_hat) < 1.0);
assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
}
} }
} }
+14 -14
View File
@@ -15,9 +15,9 @@
//! &[0.7, 0.3, 0.8] //! &[0.7, 0.3, 0.8]
//! ]); //! ]);
//! //!
//! let lu = A.qr(); //! let qr = A.qr().unwrap();
//! let orthogonal: DenseMatrix<f64> = lu.Q(); //! let orthogonal: DenseMatrix<f64> = qr.Q();
//! let triangular: DenseMatrix<f64> = lu.R(); //! let triangular: DenseMatrix<f64> = qr.R();
//! ``` //! ```
//! //!
//! ## References: //! ## References:
@@ -28,10 +28,10 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#![allow(non_snake_case)] #![allow(non_snake_case)]
use std::fmt::Debug; use crate::error::Failed;
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use std::fmt::Debug;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
/// Results of QR decomposition. /// Results of QR decomposition.
@@ -99,7 +99,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
return Q; return Q;
} }
fn solve(&self, mut b: M) -> M { fn solve(&self, mut b: M) -> Result<M, Failed> {
let (m, n) = self.QR.shape(); let (m, n) = self.QR.shape();
let (b_nrows, b_ncols) = b.shape(); let (b_nrows, b_ncols) = b.shape();
@@ -139,20 +139,20 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
} }
} }
b Ok(b)
} }
} }
/// Trait that implements QR decomposition routine for any matrix. /// Trait that implements QR decomposition routine for any matrix.
pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> { pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Compute the QR decomposition of a matrix. /// Compute the QR decomposition of a matrix.
fn qr(&self) -> QR<T, Self> { fn qr(&self) -> Result<QR<T, Self>, Failed> {
self.clone().qr_mut() self.clone().qr_mut()
} }
/// Compute the QR decomposition of a matrix. The input matrix /// Compute the QR decomposition of a matrix. The input matrix
/// will be used for factorization. /// will be used for factorization.
fn qr_mut(mut self) -> QR<T, Self> { fn qr_mut(mut self) -> Result<QR<T, Self>, Failed> {
let (m, n) = self.shape(); let (m, n) = self.shape();
let mut r_diagonal: Vec<T> = vec![T::zero(); n]; let mut r_diagonal: Vec<T> = vec![T::zero(); n];
@@ -186,12 +186,12 @@ pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
r_diagonal[k] = -nrm; r_diagonal[k] = -nrm;
} }
QR::new(self, r_diagonal) Ok(QR::new(self, r_diagonal))
} }
/// Solves Ax = b /// Solves Ax = b
fn qr_solve_mut(self, b: Self) -> Self { fn qr_solve_mut(self, b: Self) -> Result<Self, Failed> {
self.qr_mut().solve(b) self.qr_mut().and_then(|qr| qr.solve(b))
} }
} }
@@ -213,7 +213,7 @@ mod tests {
&[0.0, -0.3064, 0.0682], &[0.0, -0.3064, 0.0682],
&[0.0, 0.0, -0.1999], &[0.0, 0.0, -0.1999],
]); ]);
let qr = a.qr(); let qr = a.qr().unwrap();
assert!(qr.Q().abs().approximate_eq(&q.abs(), 1e-4)); assert!(qr.Q().abs().approximate_eq(&q.abs(), 1e-4));
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4)); assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
} }
@@ -227,7 +227,7 @@ mod tests {
&[0.8783784, 2.2297297], &[0.8783784, 2.2297297],
&[0.4729730, 0.6621622], &[0.4729730, 0.6621622],
]); ]);
let w = a.qr_solve_mut(b); let w = a.qr_solve_mut(b).unwrap();
assert!(w.approximate_eq(&expected_w, 1e-2)); assert!(w.approximate_eq(&expected_w, 1e-2));
} }
} }
+14 -13
View File
@@ -19,7 +19,7 @@
//! &[0.7, 0.3, 0.8] //! &[0.7, 0.3, 0.8]
//! ]); //! ]);
//! //!
//! let svd = A.svd(); //! let svd = A.svd().unwrap();
//! let u: DenseMatrix<f64> = svd.U; //! let u: DenseMatrix<f64> = svd.U;
//! let v: DenseMatrix<f64> = svd.V; //! let v: DenseMatrix<f64> = svd.V;
//! let s: Vec<f64> = svd.s; //! let s: Vec<f64> = svd.s;
@@ -33,6 +33,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#![allow(non_snake_case)] #![allow(non_snake_case)]
use crate::error::Failed;
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use std::fmt::Debug; use std::fmt::Debug;
@@ -55,23 +56,23 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
/// Trait that implements SVD decomposition routine for any matrix. /// Trait that implements SVD decomposition routine for any matrix.
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> { pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Solves Ax = b. Overrides original matrix in the process. /// Solves Ax = b. Overrides original matrix in the process.
fn svd_solve_mut(self, b: Self) -> Self { fn svd_solve_mut(self, b: Self) -> Result<Self, Failed> {
self.svd_mut().solve(b) self.svd_mut().and_then(|svd| svd.solve(b))
} }
/// Solves Ax = b /// Solves Ax = b
fn svd_solve(&self, b: Self) -> Self { fn svd_solve(&self, b: Self) -> Result<Self, Failed> {
self.svd().solve(b) self.svd().and_then(|svd| svd.solve(b))
} }
/// Compute the SVD decomposition of a matrix. /// Compute the SVD decomposition of a matrix.
fn svd(&self) -> SVD<T, Self> { fn svd(&self) -> Result<SVD<T, Self>, Failed> {
self.clone().svd_mut() self.clone().svd_mut()
} }
/// Compute the SVD decomposition of a matrix. The input matrix /// Compute the SVD decomposition of a matrix. The input matrix
/// will be used for factorization. /// will be used for factorization.
fn svd_mut(self) -> SVD<T, Self> { fn svd_mut(self) -> Result<SVD<T, Self>, Failed> {
let mut U = self; let mut U = self;
let (m, n) = U.shape(); let (m, n) = U.shape();
@@ -406,7 +407,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
} }
} }
SVD::new(U, v, w) Ok(SVD::new(U, v, w))
} }
} }
@@ -427,7 +428,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
} }
} }
pub(crate) fn solve(&self, mut b: M) -> M { pub(crate) fn solve(&self, mut b: M) -> Result<M, Failed> {
let p = b.shape().1; let p = b.shape().1;
if self.U.shape().0 != b.shape().0 { if self.U.shape().0 != b.shape().0 {
@@ -460,7 +461,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
} }
} }
b Ok(b)
} }
} }
@@ -491,7 +492,7 @@ mod tests {
&[0.6240573, -0.44947578, -0.6391588], &[0.6240573, -0.44947578, -0.6391588],
]); ]);
let svd = A.svd(); let svd = A.svd().unwrap();
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4)); assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4)); assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
@@ -692,7 +693,7 @@ mod tests {
], ],
]); ]);
let svd = A.svd(); let svd = A.svd().unwrap();
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4)); assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4)); assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
@@ -707,7 +708,7 @@ mod tests {
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]); let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
let expected_w = let expected_w =
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]); DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
let w = a.svd_solve_mut(b); let w = a.svd_solve_mut(b).unwrap();
assert!(w.approximate_eq(&expected_w, 1e-2)); assert!(w.approximate_eq(&expected_w, 1e-2));
} }
} }
+19 -13
View File
@@ -47,9 +47,9 @@
//! //!
//! let lr = LinearRegression::fit(&x, &y, LinearRegressionParameters { //! let lr = LinearRegression::fit(&x, &y, LinearRegressionParameters {
//! solver: LinearRegressionSolverName::QR, // or SVD //! solver: LinearRegressionSolverName::QR, // or SVD
//! }); //! }).unwrap();
//! //!
//! let y_hat = lr.predict(&x); //! let y_hat = lr.predict(&x).unwrap();
//! ``` //! ```
//! //!
//! ## References: //! ## References:
@@ -64,6 +64,7 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::Failed;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -115,39 +116,41 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
x: &M, x: &M,
y: &M::RowVector, y: &M::RowVector,
parameters: LinearRegressionParameters, parameters: LinearRegressionParameters,
) -> LinearRegression<T, M> { ) -> Result<LinearRegression<T, M>, Failed> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let b = y_m.transpose(); let b = y_m.transpose();
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
let (y_nrows, _) = b.shape(); let (y_nrows, _) = b.shape();
if x_nrows != y_nrows { if x_nrows != y_nrows {
panic!("Number of rows of X doesn't match number of rows of Y"); return Err(Failed::fit(&format!(
"Number of rows of X doesn't match number of rows of Y"
)));
} }
let a = x.h_stack(&M::ones(x_nrows, 1)); let a = x.h_stack(&M::ones(x_nrows, 1));
let w = match parameters.solver { let w = match parameters.solver {
LinearRegressionSolverName::QR => a.qr_solve_mut(b), LinearRegressionSolverName::QR => a.qr_solve_mut(b)?,
LinearRegressionSolverName::SVD => a.svd_solve_mut(b), LinearRegressionSolverName::SVD => a.svd_solve_mut(b)?,
}; };
let wights = w.slice(0..num_attributes, 0..1); let wights = w.slice(0..num_attributes, 0..1);
LinearRegression { Ok(LinearRegression {
intercept: w.get(num_attributes, 0), intercept: w.get(num_attributes, 0),
coefficients: wights, coefficients: wights,
solver: parameters.solver, solver: parameters.solver,
} })
} }
/// Predict target values from `x` /// Predict target values from `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &M) -> M::RowVector { pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
let (nrows, _) = x.shape(); let (nrows, _) = x.shape();
let mut y_hat = x.matmul(&self.coefficients); let mut y_hat = x.matmul(&self.coefficients);
y_hat.add_mut(&M::fill(nrows, 1, self.intercept)); y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
y_hat.transpose().to_row_vector() Ok(y_hat.transpose().to_row_vector())
} }
/// Get estimates regression coefficients /// Get estimates regression coefficients
@@ -199,9 +202,12 @@ mod tests {
solver: LinearRegressionSolverName::QR, solver: LinearRegressionSolverName::QR,
}, },
) )
.predict(&x); .and_then(|lr| lr.predict(&x))
.unwrap();
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x); let y_hat_svd = LinearRegression::fit(&x, &y, Default::default())
.and_then(|lr| lr.predict(&x))
.unwrap();
assert!(y assert!(y
.iter() .iter()
@@ -239,7 +245,7 @@ mod tests {
114.2, 115.7, 116.9, 114.2, 115.7, 116.9,
]; ];
let lr = LinearRegression::fit(&x, &y, Default::default()); let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();
let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> = let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap(); serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
+22 -16
View File
@@ -40,9 +40,9 @@
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., //! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
//! ]; //! ];
//! //!
//! let lr = LogisticRegression::fit(&x, &y); //! let lr = LogisticRegression::fit(&x, &y).unwrap();
//! //!
//! let y_hat = lr.predict(&x); //! let y_hat = lr.predict(&x).unwrap();
//! ``` //! ```
//! //!
//! ## References: //! ## References:
@@ -57,6 +57,7 @@ use std::marker::PhantomData;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::Failed;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use crate::optimization::first_order::lbfgs::LBFGS; use crate::optimization::first_order::lbfgs::LBFGS;
@@ -208,13 +209,15 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
/// Fits Logistic Regression to your data. /// Fits Logistic Regression to your data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target class values /// * `y` - target class values
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M> { pub fn fit(x: &M, y: &M::RowVector) -> Result<LogisticRegression<T, M>, Failed> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
let (_, y_nrows) = y_m.shape(); let (_, y_nrows) = y_m.shape();
if x_nrows != y_nrows { if x_nrows != y_nrows {
panic!("Number of rows of X doesn't match number of rows of Y"); return Err(Failed::fit(&format!(
"Number of rows of X doesn't match number of rows of Y"
)));
} }
let classes = y_m.unique(); let classes = y_m.unique();
@@ -229,7 +232,10 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
} }
if k < 2 { if k < 2 {
panic!("Incorrect number of classes: {}", k); Err(Failed::fit(&format!(
"incorrect number of classes: {}. Should be >= 2.",
k
)))
} else if k == 2 { } else if k == 2 {
let x0 = M::zeros(1, num_attributes + 1); let x0 = M::zeros(1, num_attributes + 1);
@@ -241,12 +247,12 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
let result = LogisticRegression::minimize(x0, objective); let result = LogisticRegression::minimize(x0, objective);
LogisticRegression { Ok(LogisticRegression {
weights: result.x, weights: result.x,
classes: classes, classes: classes,
num_attributes: num_attributes, num_attributes: num_attributes,
num_classes: k, num_classes: k,
} })
} else { } else {
let x0 = M::zeros(1, (num_attributes + 1) * k); let x0 = M::zeros(1, (num_attributes + 1) * k);
@@ -261,18 +267,18 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
let weights = result.x.reshape(k, num_attributes + 1); let weights = result.x.reshape(k, num_attributes + 1);
LogisticRegression { Ok(LogisticRegression {
weights: weights, weights: weights,
classes: classes, classes: classes,
num_attributes: num_attributes, num_attributes: num_attributes,
num_classes: k, num_classes: k,
} })
} }
} }
/// Predict class labels for samples in `x`. /// Predict class labels for samples in `x`.
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &M) -> M::RowVector { pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
let n = x.shape().0; let n = x.shape().0;
let mut result = M::zeros(1, n); let mut result = M::zeros(1, n);
if self.num_classes == 2 { if self.num_classes == 2 {
@@ -297,7 +303,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
result.set(0, i, self.classes[class_idxs[i]]); result.set(0, i, self.classes[class_idxs[i]]);
} }
} }
result.to_row_vector() Ok(result.to_row_vector())
} }
/// Get estimates regression coefficients /// Get estimates regression coefficients
@@ -444,7 +450,7 @@ mod tests {
]); ]);
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.]; let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
let lr = LogisticRegression::fit(&x, &y); let lr = LogisticRegression::fit(&x, &y).unwrap();
assert_eq!(lr.coefficients().shape(), (3, 2)); assert_eq!(lr.coefficients().shape(), (3, 2));
assert_eq!(lr.intercept().shape(), (3, 1)); assert_eq!(lr.intercept().shape(), (3, 1));
@@ -452,7 +458,7 @@ mod tests {
assert!((lr.coefficients().get(0, 0) - 0.0435).abs() < 1e-4); assert!((lr.coefficients().get(0, 0) - 0.0435).abs() < 1e-4);
assert!((lr.intercept().get(0, 0) - 0.1250).abs() < 1e-4); assert!((lr.intercept().get(0, 0) - 0.1250).abs() < 1e-4);
let y_hat = lr.predict(&x); let y_hat = lr.predict(&x).unwrap();
assert_eq!( assert_eq!(
y_hat, y_hat,
@@ -481,7 +487,7 @@ mod tests {
]); ]);
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.]; let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
let lr = LogisticRegression::fit(&x, &y); let lr = LogisticRegression::fit(&x, &y).unwrap();
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> = let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap(); serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
@@ -517,9 +523,9 @@ mod tests {
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
]; ];
let lr = LogisticRegression::fit(&x, &y); let lr = LogisticRegression::fit(&x, &y).unwrap();
let y_hat = lr.predict(&x); let y_hat = lr.predict(&x).unwrap();
let error: f64 = y let error: f64 = y
.into_iter() .into_iter()
+2 -2
View File
@@ -66,7 +66,7 @@ impl<T: RealNumber, M: Matrix<T>> Mahalanobis<T, M> {
/// * `data` - a matrix of _NxM_ where _N_ is number of observations and _M_ is number of attributes /// * `data` - a matrix of _NxM_ where _N_ is number of observations and _M_ is number of attributes
pub fn new(data: &M) -> Mahalanobis<T, M> { pub fn new(data: &M) -> Mahalanobis<T, M> {
let sigma = data.cov(); let sigma = data.cov();
let sigmaInv = sigma.lu().inverse(); let sigmaInv = sigma.lu().and_then(|lu| lu.inverse()).unwrap();
Mahalanobis { Mahalanobis {
sigma: sigma, sigma: sigma,
sigmaInv: sigmaInv, sigmaInv: sigmaInv,
@@ -78,7 +78,7 @@ impl<T: RealNumber, M: Matrix<T>> Mahalanobis<T, M> {
/// * `cov` - a covariance matrix /// * `cov` - a covariance matrix
pub fn new_from_covariance(cov: &M) -> Mahalanobis<T, M> { pub fn new_from_covariance(cov: &M) -> Mahalanobis<T, M> {
let sigma = cov.clone(); let sigma = cov.clone();
let sigmaInv = sigma.lu().inverse(); let sigmaInv = sigma.lu().and_then(|lu| lu.inverse()).unwrap();
Mahalanobis { Mahalanobis {
sigma: sigma, sigma: sigma,
sigmaInv: sigmaInv, sigmaInv: sigmaInv,
+2 -2
View File
@@ -41,9 +41,9 @@
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., //! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
//! ]; //! ];
//! //!
//! let lr = LogisticRegression::fit(&x, &y); //! let lr = LogisticRegression::fit(&x, &y).unwrap();
//! //!
//! let y_hat = lr.predict(&x); //! let y_hat = lr.predict(&x).unwrap();
//! //!
//! let acc = ClassificationMetrics::accuracy().get_score(&y, &y_hat); //! let acc = ClassificationMetrics::accuracy().get_score(&y, &y_hat);
//! // or //! // or
+31 -28
View File
@@ -25,8 +25,8 @@
//! &[9., 10.]]); //! &[9., 10.]]);
//! let y = vec![2., 2., 2., 3., 3.]; //your class labels //! let y = vec![2., 2., 2., 3., 3.]; //your class labels
//! //!
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()); //! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
//! let y_hat = knn.predict(&x); //! let y_hat = knn.predict(&x).unwrap();
//! ``` //! ```
//! //!
//! variable `y_hat` will hold a vector with estimates of class labels //! variable `y_hat` will hold a vector with estimates of class labels
@@ -34,6 +34,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::Failed;
use crate::linalg::{row_iter, Matrix}; use crate::linalg::{row_iter, Matrix};
use crate::math::distance::Distance; use crate::math::distance::Distance;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -106,7 +107,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
y: &M::RowVector, y: &M::RowVector,
distance: D, distance: D,
parameters: KNNClassifierParameters, parameters: KNNClassifierParameters,
) -> KNNClassifier<T, D> { ) -> Result<KNNClassifier<T, D>, Failed> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_n) = y_m.shape(); let (_, y_n) = y_m.shape();
@@ -122,43 +123,44 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
yi[i] = classes.iter().position(|c| yc == *c).unwrap(); yi[i] = classes.iter().position(|c| yc == *c).unwrap();
} }
assert!( if x_n != y_n {
x_n == y_n, return Err(Failed::fit(&format!(
format!(
"Size of x should equal size of y; |x|=[{}], |y|=[{}]", "Size of x should equal size of y; |x|=[{}], |y|=[{}]",
x_n, y_n x_n, y_n
) )));
); }
assert!( if parameters.k <= 1 {
parameters.k > 1, return Err(Failed::fit(&format!(
format!("k should be > 1, k=[{}]", parameters.k) "k should be > 1, k=[{}]",
); parameters.k
)));
}
KNNClassifier { Ok(KNNClassifier {
classes: classes, classes: classes,
y: yi, y: yi,
k: parameters.k, k: parameters.k,
knn_algorithm: parameters.algorithm.fit(data, distance), knn_algorithm: parameters.algorithm.fit(data, distance)?,
weight: parameters.weight, weight: parameters.weight,
} })
} }
/// Estimates the class labels for the provided data. /// Estimates the class labels for the provided data.
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features. /// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
/// Returns a vector of size N with class estimates. /// Returns a vector of size N with class estimates.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let mut result = M::zeros(1, x.shape().0); let mut result = M::zeros(1, x.shape().0);
row_iter(x) for (i, x) in row_iter(x).enumerate() {
.enumerate() result.set(0, i, self.classes[self.predict_for_row(x)?]);
.for_each(|(i, x)| result.set(0, i, self.classes[self.predict_for_row(x)])); }
result.to_row_vector() Ok(result.to_row_vector())
} }
fn predict_for_row(&self, x: Vec<T>) -> usize { fn predict_for_row(&self, x: Vec<T>) -> Result<usize, Failed> {
let search_result = self.knn_algorithm.find(&x, self.k); let search_result = self.knn_algorithm.find(&x, self.k)?;
let weights = self let weights = self
.weight .weight
@@ -176,7 +178,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
} }
} }
max_i Ok(max_i)
} }
} }
@@ -191,8 +193,8 @@ mod tests {
let x = let x =
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]); DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
let y = vec![2., 2., 2., 3., 3.]; let y = vec![2., 2., 2., 3., 3.];
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()); let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
let y_hat = knn.predict(&x); let y_hat = knn.predict(&x).unwrap();
assert_eq!(5, Vec::len(&y_hat)); assert_eq!(5, Vec::len(&y_hat));
assert_eq!(y.to_vec(), y_hat); assert_eq!(y.to_vec(), y_hat);
} }
@@ -210,8 +212,9 @@ mod tests {
algorithm: KNNAlgorithmName::LinearSearch, algorithm: KNNAlgorithmName::LinearSearch,
weight: KNNWeightFunction::Distance, weight: KNNWeightFunction::Distance,
}, },
); )
let y_hat = knn.predict(&DenseMatrix::from_2d_array(&[&[4.1]])); .unwrap();
let y_hat = knn.predict(&DenseMatrix::from_2d_array(&[&[4.1]])).unwrap();
assert_eq!(vec![3.0], y_hat); assert_eq!(vec![3.0], y_hat);
} }
@@ -221,7 +224,7 @@ mod tests {
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]); DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
let y = vec![2., 2., 2., 3., 3.]; let y = vec![2., 2., 2., 3., 3.];
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()); let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap(); let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
+31 -28
View File
@@ -27,8 +27,8 @@
//! &[5., 5.]]); //! &[5., 5.]]);
//! let y = vec![1., 2., 3., 4., 5.]; //your target values //! let y = vec![1., 2., 3., 4., 5.]; //your target values
//! //!
//! let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()); //! let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
//! let y_hat = knn.predict(&x); //! let y_hat = knn.predict(&x).unwrap();
//! ``` //! ```
//! //!
//! variable `y_hat` will hold predicted value //! variable `y_hat` will hold predicted value
@@ -36,6 +36,7 @@
//! //!
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::Failed;
use crate::linalg::{row_iter, BaseVector, Matrix}; use crate::linalg::{row_iter, BaseVector, Matrix};
use crate::math::distance::Distance; use crate::math::distance::Distance;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -99,7 +100,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
y: &M::RowVector, y: &M::RowVector,
distance: D, distance: D,
parameters: KNNRegressorParameters, parameters: KNNRegressorParameters,
) -> KNNRegressor<T, D> { ) -> Result<KNNRegressor<T, D>, Failed> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_n) = y_m.shape(); let (_, y_n) = y_m.shape();
@@ -107,42 +108,43 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
let data = row_iter(x).collect(); let data = row_iter(x).collect();
assert!( if x_n != y_n {
x_n == y_n, return Err(Failed::fit(&format!(
format!(
"Size of x should equal size of y; |x|=[{}], |y|=[{}]", "Size of x should equal size of y; |x|=[{}], |y|=[{}]",
x_n, y_n x_n, y_n
) )));
); }
assert!( if parameters.k <= 1 {
parameters.k > 1, return Err(Failed::fit(&format!(
format!("k should be > 1, k=[{}]", parameters.k) "k should be > 1, k=[{}]",
); parameters.k
)));
}
KNNRegressor { Ok(KNNRegressor {
y: y.to_vec(), y: y.to_vec(),
k: parameters.k, k: parameters.k,
knn_algorithm: parameters.algorithm.fit(data, distance), knn_algorithm: parameters.algorithm.fit(data, distance)?,
weight: parameters.weight, weight: parameters.weight,
} })
} }
/// Predict the target for the provided data. /// Predict the target for the provided data.
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features. /// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
/// Returns a vector of size N with estimates. /// Returns a vector of size N with estimates.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let mut result = M::zeros(1, x.shape().0); let mut result = M::zeros(1, x.shape().0);
row_iter(x) for (i, x) in row_iter(x).enumerate() {
.enumerate() result.set(0, i, self.predict_for_row(x)?);
.for_each(|(i, x)| result.set(0, i, self.predict_for_row(x))); }
result.to_row_vector() Ok(result.to_row_vector())
} }
fn predict_for_row(&self, x: Vec<T>) -> T { fn predict_for_row(&self, x: Vec<T>) -> Result<T, Failed> {
let search_result = self.knn_algorithm.find(&x, self.k); let search_result = self.knn_algorithm.find(&x, self.k)?;
let mut result = T::zero(); let mut result = T::zero();
let weights = self let weights = self
@@ -154,7 +156,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
result = result + self.y[r.0] * (*w / w_sum); result = result + self.y[r.0] * (*w / w_sum);
} }
result Ok(result)
} }
} }
@@ -179,8 +181,9 @@ mod tests {
algorithm: KNNAlgorithmName::LinearSearch, algorithm: KNNAlgorithmName::LinearSearch,
weight: KNNWeightFunction::Distance, weight: KNNWeightFunction::Distance,
}, },
); )
let y_hat = knn.predict(&x); .unwrap();
let y_hat = knn.predict(&x).unwrap();
assert_eq!(5, Vec::len(&y_hat)); assert_eq!(5, Vec::len(&y_hat));
for i in 0..y_hat.len() { for i in 0..y_hat.len() {
assert!((y_hat[i] - y_exp[i]).abs() < std::f64::EPSILON); assert!((y_hat[i] - y_exp[i]).abs() < std::f64::EPSILON);
@@ -193,8 +196,8 @@ mod tests {
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]); DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
let y: Vec<f64> = vec![1., 2., 3., 4., 5.]; let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
let y_exp = vec![2., 2., 3., 4., 4.]; let y_exp = vec![2., 2., 3., 4., 4.];
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()); let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
let y_hat = knn.predict(&x); let y_hat = knn.predict(&x).unwrap();
assert_eq!(5, Vec::len(&y_hat)); assert_eq!(5, Vec::len(&y_hat));
for i in 0..y_hat.len() { for i in 0..y_hat.len() {
assert!((y_hat[i] - y_exp[i]).abs() < 1e-7); assert!((y_hat[i] - y_exp[i]).abs() < 1e-7);
@@ -207,7 +210,7 @@ mod tests {
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]); DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
let y = vec![1., 2., 3., 4., 5.]; let y = vec![1., 2., 3., 4., 5.];
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()); let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap(); let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
+7 -4
View File
@@ -34,6 +34,7 @@
use crate::algorithm::neighbour::cover_tree::CoverTree; use crate::algorithm::neighbour::cover_tree::CoverTree;
use crate::algorithm::neighbour::linear_search::LinearKNNSearch; use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
use crate::error::Failed;
use crate::math::distance::Distance; use crate::math::distance::Distance;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -93,18 +94,20 @@ impl KNNAlgorithmName {
&self, &self,
data: Vec<Vec<T>>, data: Vec<Vec<T>>,
distance: D, distance: D,
) -> KNNAlgorithm<T, D> { ) -> Result<KNNAlgorithm<T, D>, Failed> {
match *self { match *self {
KNNAlgorithmName::LinearSearch => { KNNAlgorithmName::LinearSearch => {
KNNAlgorithm::LinearSearch(LinearKNNSearch::new(data, distance)) LinearKNNSearch::new(data, distance).map(|a| KNNAlgorithm::LinearSearch(a))
}
KNNAlgorithmName::CoverTree => {
CoverTree::new(data, distance).map(|a| KNNAlgorithm::CoverTree(a))
} }
KNNAlgorithmName::CoverTree => KNNAlgorithm::CoverTree(CoverTree::new(data, distance)),
} }
} }
} }
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> { impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
fn find(&self, from: &Vec<T>, k: usize) -> Vec<(usize, T)> { fn find(&self, from: &Vec<T>, k: usize) -> Result<Vec<(usize, T)>, Failed> {
match *self { match *self {
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k), KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k), KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
+20 -11
View File
@@ -50,9 +50,9 @@
//! let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0., //! let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0.,
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]; //! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
//! //!
//! let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()); //! let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
//! //!
//! let y_hat = tree.predict(&x); // use the same data for prediction //! let y_hat = tree.predict(&x).unwrap(); // use the same data for prediction
//! ``` //! ```
//! //!
//! //!
@@ -71,6 +71,7 @@ use rand::seq::SliceRandom;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::algorithm::sort::quick_sort::QuickArgSort; use crate::algorithm::sort::quick_sort::QuickArgSort;
use crate::error::Failed;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -276,7 +277,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
x: &M, x: &M,
y: &M::RowVector, y: &M::RowVector,
parameters: DecisionTreeClassifierParameters, parameters: DecisionTreeClassifierParameters,
) -> DecisionTreeClassifier<T> { ) -> Result<DecisionTreeClassifier<T>, Failed> {
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
let samples = vec![1; x_nrows]; let samples = vec![1; x_nrows];
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters) DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
@@ -288,14 +289,17 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
samples: Vec<usize>, samples: Vec<usize>,
mtry: usize, mtry: usize,
parameters: DecisionTreeClassifierParameters, parameters: DecisionTreeClassifierParameters,
) -> DecisionTreeClassifier<T> { ) -> Result<DecisionTreeClassifier<T>, Failed> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape(); let (_, y_ncols) = y_m.shape();
let (_, num_attributes) = x.shape(); let (_, num_attributes) = x.shape();
let classes = y_m.unique(); let classes = y_m.unique();
let k = classes.len(); let k = classes.len();
if k < 2 { if k < 2 {
panic!("Incorrect number of classes: {}. Should be >= 2.", k); return Err(Failed::fit(&format!(
"Incorrect number of classes: {}. Should be >= 2.",
k
)));
} }
let mut yi: Vec<usize> = vec![0; y_ncols]; let mut yi: Vec<usize> = vec![0; y_ncols];
@@ -343,12 +347,12 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
}; };
} }
tree Ok(tree)
} }
/// Predict class value for `x`. /// Predict class value for `x`.
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let mut result = M::zeros(1, x.shape().0); let mut result = M::zeros(1, x.shape().0);
let (n, _) = x.shape(); let (n, _) = x.shape();
@@ -357,7 +361,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
result.set(0, i, self.classes[self.predict_for_row(x, i)]); result.set(0, i, self.classes[self.predict_for_row(x, i)]);
} }
result.to_row_vector() Ok(result.to_row_vector())
} }
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize { pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
@@ -637,7 +641,9 @@ mod tests {
assert_eq!( assert_eq!(
y, y,
DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x) DecisionTreeClassifier::fit(&x, &y, Default::default())
.and_then(|t| t.predict(&x))
.unwrap()
); );
assert_eq!( assert_eq!(
@@ -652,6 +658,7 @@ mod tests {
min_samples_split: 2 min_samples_split: 2
} }
) )
.unwrap()
.depth .depth
); );
} }
@@ -686,7 +693,9 @@ mod tests {
assert_eq!( assert_eq!(
y, y,
DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x) DecisionTreeClassifier::fit(&x, &y, Default::default())
.and_then(|t| t.predict(&x))
.unwrap()
); );
} }
@@ -718,7 +727,7 @@ mod tests {
1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
]; ];
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()); let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
let deserialized_tree: DecisionTreeClassifier<f64> = let deserialized_tree: DecisionTreeClassifier<f64> =
bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap(); bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
+16 -16
View File
@@ -45,9 +45,9 @@
//! 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9, //! 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9,
//! ]; //! ];
//! //!
//! let tree = DecisionTreeRegressor::fit(&x, &y, Default::default()); //! let tree = DecisionTreeRegressor::fit(&x, &y, Default::default()).unwrap();
//! //!
//! let y_hat = tree.predict(&x); // use the same data for prediction //! let y_hat = tree.predict(&x).unwrap(); // use the same data for prediction
//! ``` //! ```
//! //!
//! ## References: //! ## References:
@@ -66,6 +66,7 @@ use rand::seq::SliceRandom;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::algorithm::sort::quick_sort::QuickArgSort; use crate::algorithm::sort::quick_sort::QuickArgSort;
use crate::error::Failed;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -196,7 +197,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
x: &M, x: &M,
y: &M::RowVector, y: &M::RowVector,
parameters: DecisionTreeRegressorParameters, parameters: DecisionTreeRegressorParameters,
) -> DecisionTreeRegressor<T> { ) -> Result<DecisionTreeRegressor<T>, Failed> {
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
let samples = vec![1; x_nrows]; let samples = vec![1; x_nrows];
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters) DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
@@ -208,16 +209,11 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
samples: Vec<usize>, samples: Vec<usize>,
mtry: usize, mtry: usize,
parameters: DecisionTreeRegressorParameters, parameters: DecisionTreeRegressorParameters,
) -> DecisionTreeRegressor<T> { ) -> Result<DecisionTreeRegressor<T>, Failed> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape(); let (_, y_ncols) = y_m.shape();
let (_, num_attributes) = x.shape(); let (_, num_attributes) = x.shape();
let classes = y_m.unique();
let k = classes.len();
if k < 2 {
panic!("Incorrect number of classes: {}. Should be >= 2.", k);
}
let mut nodes: Vec<Node<T>> = Vec::new(); let mut nodes: Vec<Node<T>> = Vec::new();
@@ -257,12 +253,12 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
}; };
} }
tree Ok(tree)
} }
/// Predict regression value for `x`. /// Predict regression value for `x`.
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let mut result = M::zeros(1, x.shape().0); let mut result = M::zeros(1, x.shape().0);
let (n, _) = x.shape(); let (n, _) = x.shape();
@@ -271,7 +267,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
result.set(0, i, self.predict_for_row(x, i)); result.set(0, i, self.predict_for_row(x, i));
} }
result.to_row_vector() Ok(result.to_row_vector())
} }
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T { pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
@@ -498,7 +494,9 @@ mod tests {
114.2, 115.7, 116.9, 114.2, 115.7, 116.9,
]; ];
let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x); let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default())
.and_then(|t| t.predict(&x))
.unwrap();
for i in 0..y_hat.len() { for i in 0..y_hat.len() {
assert!((y_hat[i] - y[i]).abs() < 0.1); assert!((y_hat[i] - y[i]).abs() < 0.1);
@@ -517,7 +515,8 @@ mod tests {
min_samples_split: 6, min_samples_split: 6,
}, },
) )
.predict(&x); .and_then(|t| t.predict(&x))
.unwrap();
for i in 0..y_hat.len() { for i in 0..y_hat.len() {
assert!((y_hat[i] - expected_y[i]).abs() < 0.1); assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
@@ -536,7 +535,8 @@ mod tests {
min_samples_split: 3, min_samples_split: 3,
}, },
) )
.predict(&x); .and_then(|t| t.predict(&x))
.unwrap();
for i in 0..y_hat.len() { for i in 0..y_hat.len() {
assert!((y_hat[i] - expected_y[i]).abs() < 0.1); assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
@@ -568,7 +568,7 @@ mod tests {
114.2, 115.7, 116.9, 114.2, 115.7, 116.9,
]; ];
let tree = DecisionTreeRegressor::fit(&x, &y, Default::default()); let tree = DecisionTreeRegressor::fit(&x, &y, Default::default()).unwrap();
let deserialized_tree: DecisionTreeRegressor<f64> = let deserialized_tree: DecisionTreeRegressor<f64> =
bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap(); bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();