diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs index ecd13bc..3b2d2da 100644 --- a/src/cluster/kmeans.rs +++ b/src/cluster/kmeans.rs @@ -43,8 +43,8 @@ //! &[5.2, 2.7, 3.9, 1.4], //! ]); //! -//! let kmeans = KMeans::new(&x, 2, Default::default()); // Fit to data, 2 clusters -//! let y_hat = kmeans.predict(&x); // use the same points for prediction +//! let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap(); // Fit to data, 2 clusters +//! let y_hat = kmeans.predict(&x).unwrap(); // use the same points for prediction //! ``` //! //! ## References: @@ -60,6 +60,7 @@ use std::iter::Sum; use serde::{Deserialize, Serialize}; +use crate::error::{FitFailedError, PredictFailedError}; use crate::algorithm::neighbour::bbd_tree::BBDTree; use crate::linalg::Matrix; use crate::math::distance::euclidian::*; @@ -117,18 +118,17 @@ impl KMeans { /// * `data` - training instances to cluster /// * `k` - number of clusters /// * `parameters` - cluster parameters - pub fn new>(data: &M, k: usize, parameters: KMeansParameters) -> KMeans { + pub fn fit>(data: &M, k: usize, parameters: KMeansParameters) -> Result, FitFailedError> { let bbd = BBDTree::new(data); if k < 2 { - panic!("Invalid number of clusters: {}", k); + return Err(FitFailedError::new(&format!("Invalid number of clusters: {}", k))); } if parameters.max_iter <= 0 { - panic!( - "Invalid maximum number of iterations: {}", + return Err(FitFailedError::new(&format!("Invalid maximum number of iterations: {}", parameters.max_iter - ); + ))); } let (n, d) = data.shape(); @@ -172,18 +172,18 @@ impl KMeans { } } - KMeans { + Ok(KMeans { k: k, y: y, size: size, distortion: distortion, centroids: centroids, - } + }) } /// Predict clusters for `x` /// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features. 
- pub fn predict>(&self, x: &M) -> M::RowVector { + pub fn predict>(&self, x: &M) -> Result { let (n, _) = x.shape(); let mut result = M::zeros(1, n); @@ -201,7 +201,7 @@ impl KMeans { result.set(0, i, T::from(best_cluster).unwrap()); } - result.to_row_vector() + Ok(result.to_row_vector()) } fn kmeans_plus_plus>(data: &M, k: usize) -> Vec { @@ -262,6 +262,20 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[test] + fn invalid_k() { + let x = DenseMatrix::from_2d_array(&[ + &[1., 2., 3.], + &[4., 5., 6.], + ]); + + println!("{:?}", KMeans::fit(&x, 0, Default::default())); + + assert!(KMeans::fit(&x, 0, Default::default()).is_err()); + assert_eq!("Invalid number of clusters: 1", KMeans::fit(&x, 1, Default::default()).unwrap_err().to_string()); + + } + #[test] fn fit_predict_iris() { let x = DenseMatrix::from_2d_array(&[ @@ -287,9 +301,9 @@ mod tests { &[5.2, 2.7, 3.9, 1.4], ]); - let kmeans = KMeans::new(&x, 2, Default::default()); + let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap(); - let y = kmeans.predict(&x); + let y = kmeans.predict(&x).unwrap(); for i in 0..y.len() { assert_eq!(y[i] as usize, kmeans.y[i]); @@ -321,7 +335,7 @@ mod tests { &[5.2, 2.7, 3.9, 1.4], ]); - let kmeans = KMeans::new(&x, 2, Default::default()); + let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap(); let deserialized_kmeans: KMeans = serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap(); diff --git a/src/error/mod.rs b/src/error/mod.rs new file mode 100644 index 0000000..0c6d44f --- /dev/null +++ b/src/error/mod.rs @@ -0,0 +1,55 @@ +//! # Custom warnings and errors +use std::error::Error; +use std::fmt; + +/// Error to be raised when a model cannot be fitted to the data. +#[derive(Debug)] +pub struct FitFailedError { + details: String +} + +/// Error to be raised when model prediction cannot be calculated. 
+#[derive(Debug)] +pub struct PredictFailedError { + details: String +} + +impl FitFailedError { + /// Creates a new instance of `FitFailedError` + /// * `msg` - description of the error + pub fn new(msg: &str) -> FitFailedError { + FitFailedError{details: msg.to_string()} + } +} + +impl fmt::Display for FitFailedError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f,"{}",self.details) + } +} + +impl Error for FitFailedError { + fn description(&self) -> &str { + &self.details + } +} + +impl PredictFailedError { + /// Creates a new instance of `PredictFailedError` + /// * `msg` - description of the error + pub fn new(msg: &str) -> PredictFailedError { + PredictFailedError{details: msg.to_string()} + } +} + +impl fmt::Display for PredictFailedError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f,"{}",self.details) + } +} + +impl Error for PredictFailedError { + fn description(&self) -> &str { + &self.details + } +} diff --git a/src/lib.rs b/src/lib.rs index c21b989..aa2d7c7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -89,3 +89,4 @@ pub mod neighbors; pub(crate) mod optimization; /// Supervised tree-based learning methods pub mod tree; +pub mod error;