feat: refactoring, adds Result to most public API
This commit is contained in:
@@ -25,8 +25,8 @@
|
||||
//! &[9., 10.]]);
|
||||
//! let y = vec![2., 2., 2., 3., 3.]; //your class labels
|
||||
//!
|
||||
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||
//! let y_hat = knn.predict(&x);
|
||||
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
//! let y_hat = knn.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! variable `y_hat` will hold a vector with estimates of class labels
|
||||
@@ -34,6 +34,7 @@
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::{row_iter, Matrix};
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -106,7 +107,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
y: &M::RowVector,
|
||||
distance: D,
|
||||
parameters: KNNClassifierParameters,
|
||||
) -> KNNClassifier<T, D> {
|
||||
) -> Result<KNNClassifier<T, D>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
let (_, y_n) = y_m.shape();
|
||||
@@ -122,43 +123,44 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
||||
}
|
||||
|
||||
assert!(
|
||||
x_n == y_n,
|
||||
format!(
|
||||
if x_n != y_n {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Size of x should equal size of y; |x|=[{}], |y|=[{}]",
|
||||
x_n, y_n
|
||||
)
|
||||
);
|
||||
)));
|
||||
}
|
||||
|
||||
assert!(
|
||||
parameters.k > 1,
|
||||
format!("k should be > 1, k=[{}]", parameters.k)
|
||||
);
|
||||
if parameters.k <= 1 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"k should be > 1, k=[{}]",
|
||||
parameters.k
|
||||
)));
|
||||
}
|
||||
|
||||
KNNClassifier {
|
||||
Ok(KNNClassifier {
|
||||
classes: classes,
|
||||
y: yi,
|
||||
k: parameters.k,
|
||||
knn_algorithm: parameters.algorithm.fit(data, distance),
|
||||
knn_algorithm: parameters.algorithm.fit(data, distance)?,
|
||||
weight: parameters.weight,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Estimates the class labels for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let mut result = M::zeros(1, x.shape().0);
|
||||
|
||||
row_iter(x)
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| result.set(0, i, self.classes[self.predict_for_row(x)]));
|
||||
for (i, x) in row_iter(x).enumerate() {
|
||||
result.set(0, i, self.classes[self.predict_for_row(x)?]);
|
||||
}
|
||||
|
||||
result.to_row_vector()
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
fn predict_for_row(&self, x: Vec<T>) -> usize {
|
||||
let search_result = self.knn_algorithm.find(&x, self.k);
|
||||
fn predict_for_row(&self, x: Vec<T>) -> Result<usize, Failed> {
|
||||
let search_result = self.knn_algorithm.find(&x, self.k)?;
|
||||
|
||||
let weights = self
|
||||
.weight
|
||||
@@ -176,7 +178,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
max_i
|
||||
Ok(max_i)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,8 +193,8 @@ mod tests {
|
||||
let x =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||
let y_hat = knn.predict(&x);
|
||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
let y_hat = knn.predict(&x).unwrap();
|
||||
assert_eq!(5, Vec::len(&y_hat));
|
||||
assert_eq!(y.to_vec(), y_hat);
|
||||
}
|
||||
@@ -210,8 +212,9 @@ mod tests {
|
||||
algorithm: KNNAlgorithmName::LinearSearch,
|
||||
weight: KNNWeightFunction::Distance,
|
||||
},
|
||||
);
|
||||
let y_hat = knn.predict(&DenseMatrix::from_2d_array(&[&[4.1]]));
|
||||
)
|
||||
.unwrap();
|
||||
let y_hat = knn.predict(&DenseMatrix::from_2d_array(&[&[4.1]])).unwrap();
|
||||
assert_eq!(vec![3.0], y_hat);
|
||||
}
|
||||
|
||||
@@ -221,7 +224,7 @@ mod tests {
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
|
||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
|
||||
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||
|
||||
|
||||
@@ -27,8 +27,8 @@
|
||||
//! &[5., 5.]]);
|
||||
//! let y = vec![1., 2., 3., 4., 5.]; //your target values
|
||||
//!
|
||||
//! let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||
//! let y_hat = knn.predict(&x);
|
||||
//! let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
//! let y_hat = knn.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! variable `y_hat` will hold predicted value
|
||||
@@ -36,6 +36,7 @@
|
||||
//!
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::{row_iter, BaseVector, Matrix};
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -99,7 +100,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
||||
y: &M::RowVector,
|
||||
distance: D,
|
||||
parameters: KNNRegressorParameters,
|
||||
) -> KNNRegressor<T, D> {
|
||||
) -> Result<KNNRegressor<T, D>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
let (_, y_n) = y_m.shape();
|
||||
@@ -107,42 +108,43 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
||||
|
||||
let data = row_iter(x).collect();
|
||||
|
||||
assert!(
|
||||
x_n == y_n,
|
||||
format!(
|
||||
if x_n != y_n {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Size of x should equal size of y; |x|=[{}], |y|=[{}]",
|
||||
x_n, y_n
|
||||
)
|
||||
);
|
||||
)));
|
||||
}
|
||||
|
||||
assert!(
|
||||
parameters.k > 1,
|
||||
format!("k should be > 1, k=[{}]", parameters.k)
|
||||
);
|
||||
if parameters.k <= 1 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"k should be > 1, k=[{}]",
|
||||
parameters.k
|
||||
)));
|
||||
}
|
||||
|
||||
KNNRegressor {
|
||||
Ok(KNNRegressor {
|
||||
y: y.to_vec(),
|
||||
k: parameters.k,
|
||||
knn_algorithm: parameters.algorithm.fit(data, distance),
|
||||
knn_algorithm: parameters.algorithm.fit(data, distance)?,
|
||||
weight: parameters.weight,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict the target for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
/// Returns a vector of size N with estimates.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let mut result = M::zeros(1, x.shape().0);
|
||||
|
||||
row_iter(x)
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| result.set(0, i, self.predict_for_row(x)));
|
||||
for (i, x) in row_iter(x).enumerate() {
|
||||
result.set(0, i, self.predict_for_row(x)?);
|
||||
}
|
||||
|
||||
result.to_row_vector()
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
fn predict_for_row(&self, x: Vec<T>) -> T {
|
||||
let search_result = self.knn_algorithm.find(&x, self.k);
|
||||
fn predict_for_row(&self, x: Vec<T>) -> Result<T, Failed> {
|
||||
let search_result = self.knn_algorithm.find(&x, self.k)?;
|
||||
let mut result = T::zero();
|
||||
|
||||
let weights = self
|
||||
@@ -154,7 +156,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
||||
result = result + self.y[r.0] * (*w / w_sum);
|
||||
}
|
||||
|
||||
result
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,8 +181,9 @@ mod tests {
|
||||
algorithm: KNNAlgorithmName::LinearSearch,
|
||||
weight: KNNWeightFunction::Distance,
|
||||
},
|
||||
);
|
||||
let y_hat = knn.predict(&x);
|
||||
)
|
||||
.unwrap();
|
||||
let y_hat = knn.predict(&x).unwrap();
|
||||
assert_eq!(5, Vec::len(&y_hat));
|
||||
for i in 0..y_hat.len() {
|
||||
assert!((y_hat[i] - y_exp[i]).abs() < std::f64::EPSILON);
|
||||
@@ -193,8 +196,8 @@ mod tests {
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
||||
let y_exp = vec![2., 2., 3., 4., 4.];
|
||||
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||
let y_hat = knn.predict(&x);
|
||||
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
let y_hat = knn.predict(&x).unwrap();
|
||||
assert_eq!(5, Vec::len(&y_hat));
|
||||
for i in 0..y_hat.len() {
|
||||
assert!((y_hat[i] - y_exp[i]).abs() < 1e-7);
|
||||
@@ -207,7 +210,7 @@ mod tests {
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y = vec![1., 2., 3., 4., 5.];
|
||||
|
||||
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
|
||||
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
|
||||
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
||||
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||
use crate::error::Failed;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -93,18 +94,20 @@ impl KNNAlgorithmName {
|
||||
&self,
|
||||
data: Vec<Vec<T>>,
|
||||
distance: D,
|
||||
) -> KNNAlgorithm<T, D> {
|
||||
) -> Result<KNNAlgorithm<T, D>, Failed> {
|
||||
match *self {
|
||||
KNNAlgorithmName::LinearSearch => {
|
||||
KNNAlgorithm::LinearSearch(LinearKNNSearch::new(data, distance))
|
||||
LinearKNNSearch::new(data, distance).map(|a| KNNAlgorithm::LinearSearch(a))
|
||||
}
|
||||
KNNAlgorithmName::CoverTree => {
|
||||
CoverTree::new(data, distance).map(|a| KNNAlgorithm::CoverTree(a))
|
||||
}
|
||||
KNNAlgorithmName::CoverTree => KNNAlgorithm::CoverTree(CoverTree::new(data, distance)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
|
||||
fn find(&self, from: &Vec<T>, k: usize) -> Vec<(usize, T)> {
|
||||
fn find(&self, from: &Vec<T>, k: usize) -> Result<Vec<(usize, T)>, Failed> {
|
||||
match *self {
|
||||
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
|
||||
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
|
||||
|
||||
Reference in New Issue
Block a user