fix: formatting
This commit is contained in:
+20
-11
@@ -60,8 +60,8 @@ use std::iter::Sum;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::{FitFailedError, PredictFailedError};
|
||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||
use crate::error::{FitFailedError, PredictFailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian::*;
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -118,15 +118,23 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
/// * `data` - training instances to cluster
|
||||
/// * `k` - number of clusters
|
||||
/// * `parameters` - cluster parameters
|
||||
pub fn fit<M: Matrix<T>>(data: &M, k: usize, parameters: KMeansParameters) -> Result<KMeans<T>, FitFailedError> {
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
data: &M,
|
||||
k: usize,
|
||||
parameters: KMeansParameters,
|
||||
) -> Result<KMeans<T>, FitFailedError> {
|
||||
let bbd = BBDTree::new(data);
|
||||
|
||||
if k < 2 {
|
||||
return Err(FitFailedError::new(&format!("Invalid number of clusters: {}", k)));
|
||||
return Err(FitFailedError::new(&format!(
|
||||
"Invalid number of clusters: {}",
|
||||
k
|
||||
)));
|
||||
}
|
||||
|
||||
if parameters.max_iter <= 0 {
|
||||
return Err(FitFailedError::new(&format!("Invalid maximum number of iterations: {}",
|
||||
return Err(FitFailedError::new(&format!(
|
||||
"Invalid maximum number of iterations: {}",
|
||||
parameters.max_iter
|
||||
)));
|
||||
}
|
||||
@@ -264,16 +272,17 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn invalid_k() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[1., 2., 3.],
|
||||
&[4., 5., 6.],
|
||||
]);
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||
|
||||
println!("{:?}", KMeans::fit(&x, 0, Default::default()));
|
||||
|
||||
assert!(KMeans::fit(&x, 0, Default::default()).is_err());
|
||||
assert_eq!("Invalid number of clusters: 1", KMeans::fit(&x, 1, Default::default()).unwrap_err().to_string());
|
||||
|
||||
assert!(KMeans::fit(&x, 0, Default::default()).is_err());
|
||||
assert_eq!(
|
||||
"Invalid number of clusters: 1",
|
||||
KMeans::fit(&x, 1, Default::default())
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user