feat: refactoring, adds Result to most public API
This commit is contained in:
@@ -25,8 +25,8 @@
|
||||
//! &[9., 10.]]);
|
||||
//! let y = vec![2., 2., 2., 3., 3.]; //your class labels
|
||||
//!
|
||||
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||
//! let y_hat = knn.predict(&x);
|
||||
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
//! let y_hat = knn.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! variable `y_hat` will hold a vector with estimates of class labels
|
||||
@@ -34,6 +34,7 @@
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::{row_iter, Matrix};
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -106,7 +107,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
y: &M::RowVector,
|
||||
distance: D,
|
||||
parameters: KNNClassifierParameters,
|
||||
) -> KNNClassifier<T, D> {
|
||||
) -> Result<KNNClassifier<T, D>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
let (_, y_n) = y_m.shape();
|
||||
@@ -122,43 +123,44 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
||||
}
|
||||
|
||||
assert!(
|
||||
x_n == y_n,
|
||||
format!(
|
||||
if x_n != y_n {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Size of x should equal size of y; |x|=[{}], |y|=[{}]",
|
||||
x_n, y_n
|
||||
)
|
||||
);
|
||||
)));
|
||||
}
|
||||
|
||||
assert!(
|
||||
parameters.k > 1,
|
||||
format!("k should be > 1, k=[{}]", parameters.k)
|
||||
);
|
||||
if parameters.k <= 1 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"k should be > 1, k=[{}]",
|
||||
parameters.k
|
||||
)));
|
||||
}
|
||||
|
||||
KNNClassifier {
|
||||
Ok(KNNClassifier {
|
||||
classes: classes,
|
||||
y: yi,
|
||||
k: parameters.k,
|
||||
knn_algorithm: parameters.algorithm.fit(data, distance),
|
||||
knn_algorithm: parameters.algorithm.fit(data, distance)?,
|
||||
weight: parameters.weight,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Estimates the class labels for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let mut result = M::zeros(1, x.shape().0);
|
||||
|
||||
row_iter(x)
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| result.set(0, i, self.classes[self.predict_for_row(x)]));
|
||||
for (i, x) in row_iter(x).enumerate() {
|
||||
result.set(0, i, self.classes[self.predict_for_row(x)?]);
|
||||
}
|
||||
|
||||
result.to_row_vector()
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
fn predict_for_row(&self, x: Vec<T>) -> usize {
|
||||
let search_result = self.knn_algorithm.find(&x, self.k);
|
||||
fn predict_for_row(&self, x: Vec<T>) -> Result<usize, Failed> {
|
||||
let search_result = self.knn_algorithm.find(&x, self.k)?;
|
||||
|
||||
let weights = self
|
||||
.weight
|
||||
@@ -176,7 +178,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
max_i
|
||||
Ok(max_i)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,8 +193,8 @@ mod tests {
|
||||
let x =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||
let y_hat = knn.predict(&x);
|
||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
let y_hat = knn.predict(&x).unwrap();
|
||||
assert_eq!(5, Vec::len(&y_hat));
|
||||
assert_eq!(y.to_vec(), y_hat);
|
||||
}
|
||||
@@ -210,8 +212,9 @@ mod tests {
|
||||
algorithm: KNNAlgorithmName::LinearSearch,
|
||||
weight: KNNWeightFunction::Distance,
|
||||
},
|
||||
);
|
||||
let y_hat = knn.predict(&DenseMatrix::from_2d_array(&[&[4.1]]));
|
||||
)
|
||||
.unwrap();
|
||||
let y_hat = knn.predict(&DenseMatrix::from_2d_array(&[&[4.1]])).unwrap();
|
||||
assert_eq!(vec![3.0], y_hat);
|
||||
}
|
||||
|
||||
@@ -221,7 +224,7 @@ mod tests {
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
|
||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
|
||||
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user