diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs index 3b2d2da..f232631 100644 --- a/src/cluster/kmeans.rs +++ b/src/cluster/kmeans.rs @@ -60,8 +60,8 @@ use std::iter::Sum; use serde::{Deserialize, Serialize}; -use crate::error::{FitFailedError, PredictFailedError}; use crate::algorithm::neighbour::bbd_tree::BBDTree; +use crate::error::{FitFailedError, PredictFailedError}; use crate::linalg::Matrix; use crate::math::distance::euclidian::*; use crate::math::num::RealNumber; @@ -118,15 +118,23 @@ impl KMeans { /// * `data` - training instances to cluster /// * `k` - number of clusters /// * `parameters` - cluster parameters - pub fn fit>(data: &M, k: usize, parameters: KMeansParameters) -> Result, FitFailedError> { + pub fn fit>( + data: &M, + k: usize, + parameters: KMeansParameters, + ) -> Result, FitFailedError> { let bbd = BBDTree::new(data); if k < 2 { - return Err(FitFailedError::new(&format!("Invalid number of clusters: {}", k))); + return Err(FitFailedError::new(&format!( + "Invalid number of clusters: {}", + k + ))); } if parameters.max_iter <= 0 { - return Err(FitFailedError::new(&format!("Invalid maximum number of iterations: {}", + return Err(FitFailedError::new(&format!( + "Invalid maximum number of iterations: {}", parameters.max_iter ))); } @@ -264,16 +272,17 @@ mod tests { #[test] fn invalid_k() { - let x = DenseMatrix::from_2d_array(&[ - &[1., 2., 3.], - &[4., 5., 6.], - ]); + let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]); println!("{:?}", KMeans::fit(&x, 0, Default::default())); - assert!(KMeans::fit(&x, 0, Default::default()).is_err()); - assert_eq!("Invalid number of clusters: 1", KMeans::fit(&x, 1, Default::default()).unwrap_err().to_string()); - + assert!(KMeans::fit(&x, 0, Default::default()).is_err()); + assert_eq!( + "Invalid number of clusters: 1", + KMeans::fit(&x, 1, Default::default()) + .unwrap_err() + .to_string() + ); } #[test] diff --git a/src/error/mod.rs b/src/error/mod.rs index 0c6d44f..a644a79 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -5,26 +5,28 @@ use std::fmt; /// Error to be raised when model does not fits data. #[derive(Debug)] pub struct FitFailedError { - details: String + details: String, } /// Error to be raised when model prediction cannot be calculated. #[derive(Debug)] pub struct PredictFailedError { - details: String + details: String, } impl FitFailedError { /// Creates new instance of `FitFailedError` /// * `msg` - description of the error pub fn new(msg: &str) -> FitFailedError { - FitFailedError{details: msg.to_string()} + FitFailedError { + details: msg.to_string(), + } } } impl fmt::Display for FitFailedError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f,"{}",self.details) + write!(f, "{}", self.details) } } @@ -38,13 +40,15 @@ impl PredictFailedError { /// Creates new instance of `PredictFailedError` /// * `msg` - description of the error pub fn new(msg: &str) -> PredictFailedError { - PredictFailedError{details: msg.to_string()} + PredictFailedError { + details: msg.to_string(), + } } } impl fmt::Display for PredictFailedError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f,"{}",self.details) + write!(f, "{}", self.details) } } diff --git a/src/lib.rs b/src/lib.rs index aa2d7c7..5bd22ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,6 +75,7 @@ pub mod dataset; pub mod decomposition; /// Ensemble methods, including Random Forest classifier and regressor pub mod ensemble; +pub mod error; /// Diverse collection of linear algebra abstractions and methods that power SmartCore algorithms pub mod linalg; /// Supervised classification and regression models that assume linear relationship between dependent and explanatory variables. @@ -89,4 +90,3 @@ pub mod neighbors; pub(crate) mod optimization; /// Supervised tree-based learning methods pub mod tree; -pub mod error;