feat: adds FitFailedError and PredictFailedError
This commit is contained in:
+28
-14
@@ -43,8 +43,8 @@
|
|||||||
//! &[5.2, 2.7, 3.9, 1.4],
|
//! &[5.2, 2.7, 3.9, 1.4],
|
||||||
//! ]);
|
//! ]);
|
||||||
//!
|
//!
|
||||||
//! let kmeans = KMeans::new(&x, 2, Default::default()); // Fit to data, 2 clusters
|
//! let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap(); // Fit to data, 2 clusters
|
||||||
//! let y_hat = kmeans.predict(&x); // use the same points for prediction
|
//! let y_hat = kmeans.predict(&x).unwrap(); // use the same points for prediction
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! ## References:
|
//! ## References:
|
||||||
@@ -60,6 +60,7 @@ use std::iter::Sum;
|
|||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::error::{FitFailedError, PredictFailedError};
|
||||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::distance::euclidian::*;
|
use crate::math::distance::euclidian::*;
|
||||||
@@ -117,18 +118,17 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
|||||||
/// * `data` - training instances to cluster
|
/// * `data` - training instances to cluster
|
||||||
/// * `k` - number of clusters
|
/// * `k` - number of clusters
|
||||||
/// * `parameters` - cluster parameters
|
/// * `parameters` - cluster parameters
|
||||||
pub fn new<M: Matrix<T>>(data: &M, k: usize, parameters: KMeansParameters) -> KMeans<T> {
|
pub fn fit<M: Matrix<T>>(data: &M, k: usize, parameters: KMeansParameters) -> Result<KMeans<T>, FitFailedError> {
|
||||||
let bbd = BBDTree::new(data);
|
let bbd = BBDTree::new(data);
|
||||||
|
|
||||||
if k < 2 {
|
if k < 2 {
|
||||||
panic!("Invalid number of clusters: {}", k);
|
return Err(FitFailedError::new(&format!("Invalid number of clusters: {}", k)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if parameters.max_iter <= 0 {
|
if parameters.max_iter <= 0 {
|
||||||
panic!(
|
return Err(FitFailedError::new(&format!("Invalid maximum number of iterations: {}",
|
||||||
"Invalid maximum number of iterations: {}",
|
|
||||||
parameters.max_iter
|
parameters.max_iter
|
||||||
);
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let (n, d) = data.shape();
|
let (n, d) = data.shape();
|
||||||
@@ -172,18 +172,18 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
KMeans {
|
Ok(KMeans {
|
||||||
k: k,
|
k: k,
|
||||||
y: y,
|
y: y,
|
||||||
size: size,
|
size: size,
|
||||||
distortion: distortion,
|
distortion: distortion,
|
||||||
centroids: centroids,
|
centroids: centroids,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict clusters for `x`
|
/// Predict clusters for `x`
|
||||||
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
|
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
|
||||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, PredictFailedError> {
|
||||||
let (n, _) = x.shape();
|
let (n, _) = x.shape();
|
||||||
let mut result = M::zeros(1, n);
|
let mut result = M::zeros(1, n);
|
||||||
|
|
||||||
@@ -201,7 +201,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
|||||||
result.set(0, i, T::from(best_cluster).unwrap());
|
result.set(0, i, T::from(best_cluster).unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
result.to_row_vector()
|
Ok(result.to_row_vector())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
|
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
|
||||||
@@ -262,6 +262,20 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_k() {
|
||||||
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
|
&[1., 2., 3.],
|
||||||
|
&[4., 5., 6.],
|
||||||
|
]);
|
||||||
|
|
||||||
|
println!("{:?}", KMeans::fit(&x, 0, Default::default()));
|
||||||
|
|
||||||
|
assert!(KMeans::fit(&x, 0, Default::default()).is_err());
|
||||||
|
assert_eq!("Invalid number of clusters: 1", KMeans::fit(&x, 1, Default::default()).unwrap_err().to_string());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_iris() {
|
fn fit_predict_iris() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -287,9 +301,9 @@ mod tests {
|
|||||||
&[5.2, 2.7, 3.9, 1.4],
|
&[5.2, 2.7, 3.9, 1.4],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let kmeans = KMeans::new(&x, 2, Default::default());
|
let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap();
|
||||||
|
|
||||||
let y = kmeans.predict(&x);
|
let y = kmeans.predict(&x).unwrap();
|
||||||
|
|
||||||
for i in 0..y.len() {
|
for i in 0..y.len() {
|
||||||
assert_eq!(y[i] as usize, kmeans.y[i]);
|
assert_eq!(y[i] as usize, kmeans.y[i]);
|
||||||
@@ -321,7 +335,7 @@ mod tests {
|
|||||||
&[5.2, 2.7, 3.9, 1.4],
|
&[5.2, 2.7, 3.9, 1.4],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let kmeans = KMeans::new(&x, 2, Default::default());
|
let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap();
|
||||||
|
|
||||||
let deserialized_kmeans: KMeans<f64> =
|
let deserialized_kmeans: KMeans<f64> =
|
||||||
serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();
|
serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
//! # Custom warnings and errors
|
||||||
|
use std::error::Error;
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
/// Error to be raised when model does not fits data.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct FitFailedError {
|
||||||
|
details: String
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Error to be raised when model prediction cannot be calculated.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct PredictFailedError {
|
||||||
|
details: String
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FitFailedError {
|
||||||
|
/// Creates new instance of `FitFailedError`
|
||||||
|
/// * `msg` - description of the error
|
||||||
|
pub fn new(msg: &str) -> FitFailedError {
|
||||||
|
FitFailedError{details: msg.to_string()}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for FitFailedError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
write!(f,"{}",self.details)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for FitFailedError {
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
&self.details
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PredictFailedError {
|
||||||
|
/// Creates new instance of `PredictFailedError`
|
||||||
|
/// * `msg` - description of the error
|
||||||
|
pub fn new(msg: &str) -> PredictFailedError {
|
||||||
|
PredictFailedError{details: msg.to_string()}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for PredictFailedError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
write!(f,"{}",self.details)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for PredictFailedError {
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
&self.details
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -89,3 +89,4 @@ pub mod neighbors;
|
|||||||
pub(crate) mod optimization;
|
pub(crate) mod optimization;
|
||||||
/// Supervised tree-based learning methods
|
/// Supervised tree-based learning methods
|
||||||
pub mod tree;
|
pub mod tree;
|
||||||
|
pub mod error;
|
||||||
|
|||||||
Reference in New Issue
Block a user