fix: formatting
This commit is contained in:
+19
-10
@@ -60,8 +60,8 @@ use std::iter::Sum;
|
|||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::error::{FitFailedError, PredictFailedError};
|
|
||||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||||
|
use crate::error::{FitFailedError, PredictFailedError};
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::distance::euclidian::*;
|
use crate::math::distance::euclidian::*;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
@@ -118,15 +118,23 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
|||||||
/// * `data` - training instances to cluster
|
/// * `data` - training instances to cluster
|
||||||
/// * `k` - number of clusters
|
/// * `k` - number of clusters
|
||||||
/// * `parameters` - cluster parameters
|
/// * `parameters` - cluster parameters
|
||||||
pub fn fit<M: Matrix<T>>(data: &M, k: usize, parameters: KMeansParameters) -> Result<KMeans<T>, FitFailedError> {
|
pub fn fit<M: Matrix<T>>(
|
||||||
|
data: &M,
|
||||||
|
k: usize,
|
||||||
|
parameters: KMeansParameters,
|
||||||
|
) -> Result<KMeans<T>, FitFailedError> {
|
||||||
let bbd = BBDTree::new(data);
|
let bbd = BBDTree::new(data);
|
||||||
|
|
||||||
if k < 2 {
|
if k < 2 {
|
||||||
return Err(FitFailedError::new(&format!("Invalid number of clusters: {}", k)));
|
return Err(FitFailedError::new(&format!(
|
||||||
|
"Invalid number of clusters: {}",
|
||||||
|
k
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if parameters.max_iter <= 0 {
|
if parameters.max_iter <= 0 {
|
||||||
return Err(FitFailedError::new(&format!("Invalid maximum number of iterations: {}",
|
return Err(FitFailedError::new(&format!(
|
||||||
|
"Invalid maximum number of iterations: {}",
|
||||||
parameters.max_iter
|
parameters.max_iter
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
@@ -264,16 +272,17 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn invalid_k() {
|
fn invalid_k() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||||
&[1., 2., 3.],
|
|
||||||
&[4., 5., 6.],
|
|
||||||
]);
|
|
||||||
|
|
||||||
println!("{:?}", KMeans::fit(&x, 0, Default::default()));
|
println!("{:?}", KMeans::fit(&x, 0, Default::default()));
|
||||||
|
|
||||||
assert!(KMeans::fit(&x, 0, Default::default()).is_err());
|
assert!(KMeans::fit(&x, 0, Default::default()).is_err());
|
||||||
assert_eq!("Invalid number of clusters: 1", KMeans::fit(&x, 1, Default::default()).unwrap_err().to_string());
|
assert_eq!(
|
||||||
|
"Invalid number of clusters: 1",
|
||||||
|
KMeans::fit(&x, 1, Default::default())
|
||||||
|
.unwrap_err()
|
||||||
|
.to_string()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
+8
-4
@@ -5,20 +5,22 @@ use std::fmt;
|
|||||||
/// Error to be raised when model does not fits data.
|
/// Error to be raised when model does not fits data.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct FitFailedError {
|
pub struct FitFailedError {
|
||||||
details: String
|
details: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Error to be raised when model prediction cannot be calculated.
|
/// Error to be raised when model prediction cannot be calculated.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct PredictFailedError {
|
pub struct PredictFailedError {
|
||||||
details: String
|
details: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FitFailedError {
|
impl FitFailedError {
|
||||||
/// Creates new instance of `FitFailedError`
|
/// Creates new instance of `FitFailedError`
|
||||||
/// * `msg` - description of the error
|
/// * `msg` - description of the error
|
||||||
pub fn new(msg: &str) -> FitFailedError {
|
pub fn new(msg: &str) -> FitFailedError {
|
||||||
FitFailedError{details: msg.to_string()}
|
FitFailedError {
|
||||||
|
details: msg.to_string(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,7 +40,9 @@ impl PredictFailedError {
|
|||||||
/// Creates new instance of `PredictFailedError`
|
/// Creates new instance of `PredictFailedError`
|
||||||
/// * `msg` - description of the error
|
/// * `msg` - description of the error
|
||||||
pub fn new(msg: &str) -> PredictFailedError {
|
pub fn new(msg: &str) -> PredictFailedError {
|
||||||
PredictFailedError{details: msg.to_string()}
|
PredictFailedError {
|
||||||
|
details: msg.to_string(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -75,6 +75,7 @@ pub mod dataset;
|
|||||||
pub mod decomposition;
|
pub mod decomposition;
|
||||||
/// Ensemble methods, including Random Forest classifier and regressor
|
/// Ensemble methods, including Random Forest classifier and regressor
|
||||||
pub mod ensemble;
|
pub mod ensemble;
|
||||||
|
pub mod error;
|
||||||
/// Diverse collection of linear algebra abstractions and methods that power SmartCore algorithms
|
/// Diverse collection of linear algebra abstractions and methods that power SmartCore algorithms
|
||||||
pub mod linalg;
|
pub mod linalg;
|
||||||
/// Supervised classification and regression models that assume linear relationship between dependent and explanatory variables.
|
/// Supervised classification and regression models that assume linear relationship between dependent and explanatory variables.
|
||||||
@@ -89,4 +90,3 @@ pub mod neighbors;
|
|||||||
pub(crate) mod optimization;
|
pub(crate) mod optimization;
|
||||||
/// Supervised tree-based learning methods
|
/// Supervised tree-based learning methods
|
||||||
pub mod tree;
|
pub mod tree;
|
||||||
pub mod error;
|
|
||||||
|
|||||||
Reference in New Issue
Block a user