feat: refactoring, adds Result to most public API

This commit is contained in:
Volodymyr Orlov
2020-09-18 15:20:32 -07:00
parent 4921ae76f5
commit a9db970195
24 changed files with 389 additions and 298 deletions
+7 -12
View File
@@ -61,7 +61,7 @@ use std::iter::Sum;
use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::bbd_tree::BBDTree;
use crate::error::{FitFailedError, PredictFailedError};
use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::distance::euclidian::*;
use crate::math::num::RealNumber;
@@ -122,19 +122,16 @@ impl<T: RealNumber + Sum> KMeans<T> {
data: &M,
k: usize,
parameters: KMeansParameters,
) -> Result<KMeans<T>, FitFailedError> {
) -> Result<KMeans<T>, Failed> {
let bbd = BBDTree::new(data);
if k < 2 {
return Err(FitFailedError::new(&format!(
"Invalid number of clusters: {}",
k
)));
return Err(Failed::fit(&format!("invalid number of clusters: {}", k)));
}
if parameters.max_iter <= 0 {
return Err(FitFailedError::new(&format!(
"Invalid maximum number of iterations: {}",
return Err(Failed::fit(&format!(
"invalid maximum number of iterations: {}",
parameters.max_iter
)));
}
@@ -191,7 +188,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
/// Predict clusters for `x`
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, PredictFailedError> {
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, _) = x.shape();
let mut result = M::zeros(1, n);
@@ -274,11 +271,9 @@ mod tests {
fn invalid_k() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
println!("{:?}", KMeans::fit(&x, 0, Default::default()));
assert!(KMeans::fit(&x, 0, Default::default()).is_err());
assert_eq!(
"Invalid number of clusters: 1",
"Fit failed: invalid number of clusters: 1",
KMeans::fit(&x, 1, Default::default())
.unwrap_err()
.to_string()