feat: serialization/deserialization with Serde

This commit is contained in:
Volodymyr Orlov
2020-03-31 18:19:20 -07:00
parent 1257d2c19b
commit 8bb6013430
8 changed files with 281 additions and 28 deletions
+58 -1
View File
@@ -4,12 +4,14 @@ use rand::Rng;
use std::iter::Sum;
use std::fmt::Debug;
use serde::{Serialize, Deserialize};
use crate::math::num::FloatExt;
use crate::linalg::Matrix;
use crate::math::distance::euclidian;
use crate::algorithm::neighbour::bbd_tree::BBDTree;
#[derive(Debug)]
#[derive(Serialize, Deserialize, Debug)]
pub struct KMeans<T: FloatExt> {
k: usize,
y: Vec<usize>,
@@ -18,6 +20,29 @@ pub struct KMeans<T: FloatExt> {
centroids: Vec<Vec<T>>
}
impl<T: FloatExt> PartialEq for KMeans<T> {
fn eq(&self, other: &Self) -> bool {
if self.k != other.k ||
self.size != other.size ||
self.centroids.len() != other.centroids.len() {
false
} else {
let n_centroids = self.centroids.len();
for i in 0..n_centroids{
if self.centroids[i].len() != other.centroids[i].len(){
return false
}
for j in 0..self.centroids[i].len() {
if (self.centroids[i][j] - other.centroids[i][j]).abs() > T::epsilon() {
return false
}
}
}
true
}
}
}
#[derive(Debug, Clone)]
pub struct KMeansParameters {
pub max_iter: usize
@@ -210,5 +235,37 @@ mod tests {
}
}
#[test]
fn serde() {
let x = DenseMatrix::from_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4]]);
let kmeans = KMeans::new(&x, 2, Default::default());
let deserialized_kmeans: KMeans<f64> = serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();
assert_eq!(kmeans, deserialized_kmeans);
}
}