Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
20ca5c9647 | ||
|
|
3fe916988f |
+2
-3
@@ -33,8 +33,7 @@ itertools = "0.10.3"
|
|||||||
getrandom = { version = "0.2", features = ["js"] }
|
getrandom = { version = "0.2", features = ["js"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
smartcore = { path = ".", features = ["fp_bench"] }
|
criterion = "0.3"
|
||||||
criterion = { version = "0.4", default-features = false }
|
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
bincode = "1.3.1"
|
bincode = "1.3.1"
|
||||||
|
|
||||||
@@ -53,4 +52,4 @@ required-features = ["ndarray-bindings", "nalgebra-bindings"]
|
|||||||
[[bench]]
|
[[bench]]
|
||||||
name = "fastpair"
|
name = "fastpair"
|
||||||
harness = false
|
harness = false
|
||||||
required-features = ["fp_bench"]
|
required-features = ["fp_bench"]
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
/// # Hierarchical clustering
|
||||||
|
///
|
||||||
|
/// Implement hierarchical clustering methods:
|
||||||
|
/// * Agglomerative clustering (current)
|
||||||
|
/// * Bisecting K-Means (future)
|
||||||
|
/// * Fastcluster (future)
|
||||||
|
///
|
||||||
|
|
||||||
|
/*
|
||||||
|
class AgglomerativeClustering():
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
n_clusters : int or None, default=2
|
||||||
|
The number of clusters to find. It must be ``None`` if
|
||||||
|
``distance_threshold`` is not ``None``.
|
||||||
|
affinity : str or callable, default='euclidean'
|
||||||
|
If linkage is "ward", only "euclidean" is accepted.
|
||||||
|
linkage : {'ward',}, default='ward'
|
||||||
|
Which linkage criterion to use. The linkage criterion determines which
|
||||||
|
distance to use between sets of observation. The algorithm will merge
|
||||||
|
the pairs of cluster that minimize this criterion.
|
||||||
|
- 'ward' minimizes the variance of the clusters being merged.
|
||||||
|
compute_distances : bool, default=False
|
||||||
|
Computes distances between clusters even if `distance_threshold` is not
|
||||||
|
used. This can be used to make dendrogram visualization, but introduces
|
||||||
|
a computational and memory overhead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def fit(X):
|
||||||
|
# compute tree
|
||||||
|
# <https://github.com/scikit-learn/scikit-learn/blob/02ebf9e68fe1fc7687d9e1047b9e465ae0fd945e/sklearn/cluster/_agglomerative.py#L172>
|
||||||
|
parents, children = ward_tree(X, ....)
|
||||||
|
# compute clusters
|
||||||
|
# <https://github.com/scikit-learn/scikit-learn/blob/70c495250fea7fa3c8c1a4631e6ddcddc9f22451/sklearn/cluster/_hierarchical_fast.pyx#L98>
|
||||||
|
labels = _hierarchical.hc_get_heads(parents)
|
||||||
|
# assign cluster numbers
|
||||||
|
self.labels_ = np.searchsorted(np.unique(labels), labels)
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
// implement ward tree
|
||||||
|
// use scipy.cluster.hierarchy.ward
|
||||||
|
// <https://github.com/scipy/scipy/blob/main/scipy/cluster/hierarchy.py#L738>
|
||||||
|
// use linkage
|
||||||
|
// <https://github.com/scipy/scipy/blob/main/scipy/cluster/hierarchy.py#L837>
|
||||||
|
// use nn_chain
|
||||||
|
// <https://github.com/scipy/scipy/blob/main/scipy/cluster/_hierarchy.pyx#L906>
|
||||||
|
|
||||||
|
// implement hc_get_heads
|
||||||
|
|
||||||
|
|
||||||
|
mod tests {
|
||||||
|
// >>> from sklearn.cluster import AgglomerativeClustering
|
||||||
|
// >>> import numpy as np
|
||||||
|
// >>> X = np.array([[1, 2], [1, 4], [1, 0],
|
||||||
|
// ... [4, 2], [4, 4], [4, 0]])
|
||||||
|
// >>> clustering = AgglomerativeClustering().fit(X)
|
||||||
|
// >>> clustering
|
||||||
|
// AgglomerativeClustering()
|
||||||
|
// >>> clustering.labels_
|
||||||
|
// array([1, 1, 1, 0, 0, 0])
|
||||||
|
}
|
||||||
@@ -55,8 +55,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
use crate::error::{Failed, FailedError};
|
use crate::error::{Failed, FailedError};
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::linalg::{BaseMatrix, Matrix};
|
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::tree::decision_tree_classifier::{
|
use crate::tree::decision_tree_classifier::{
|
||||||
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
||||||
@@ -317,37 +316,6 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
which_max(&result)
|
which_max(&result)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict the per-class probabilities for each observation.
|
|
||||||
/// The probability is calculated as the fraction of trees that predicted a given class
|
|
||||||
pub fn predict_probs<M: Matrix<T>>(&self, x: &M) -> Result<DenseMatrix<f64>, Failed> {
|
|
||||||
let mut result = DenseMatrix::<f64>::zeros(x.shape().0, self.classes.len());
|
|
||||||
|
|
||||||
let (n, _) = x.shape();
|
|
||||||
|
|
||||||
for i in 0..n {
|
|
||||||
let row_probs = self.predict_probs_for_row(x, i);
|
|
||||||
|
|
||||||
for (j, item) in row_probs.iter().enumerate() {
|
|
||||||
result.set(i, j, *item);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn predict_probs_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> Vec<f64> {
|
|
||||||
let mut result = vec![0; self.classes.len()];
|
|
||||||
|
|
||||||
for tree in self.trees.iter() {
|
|
||||||
result[tree.predict_for_row(x, row)] += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
result
|
|
||||||
.iter()
|
|
||||||
.map(|n| *n as f64 / self.trees.len() as f64)
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
|
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||||
let class_weight = vec![1.; num_classes];
|
let class_weight = vec![1.; num_classes];
|
||||||
let nrows = y.len();
|
let nrows = y.len();
|
||||||
@@ -373,7 +341,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests_prob {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
use crate::metrics::*;
|
use crate::metrics::*;
|
||||||
@@ -514,69 +482,4 @@ mod tests_prob {
|
|||||||
|
|
||||||
assert_eq!(forest, deserialized_forest);
|
assert_eq!(forest, deserialized_forest);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn fit_predict_probabilities() {
|
|
||||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
|
||||||
&[4.9, 3.0, 1.4, 0.2],
|
|
||||||
&[4.7, 3.2, 1.3, 0.2],
|
|
||||||
&[4.6, 3.1, 1.5, 0.2],
|
|
||||||
&[5.0, 3.6, 1.4, 0.2],
|
|
||||||
&[5.4, 3.9, 1.7, 0.4],
|
|
||||||
&[4.6, 3.4, 1.4, 0.3],
|
|
||||||
&[5.0, 3.4, 1.5, 0.2],
|
|
||||||
&[4.4, 2.9, 1.4, 0.2],
|
|
||||||
&[4.9, 3.1, 1.5, 0.1],
|
|
||||||
&[7.0, 3.2, 4.7, 1.4],
|
|
||||||
&[6.4, 3.2, 4.5, 1.5],
|
|
||||||
&[6.9, 3.1, 4.9, 1.5],
|
|
||||||
&[5.5, 2.3, 4.0, 1.3],
|
|
||||||
&[6.5, 2.8, 4.6, 1.5],
|
|
||||||
&[5.7, 2.8, 4.5, 1.3],
|
|
||||||
&[6.3, 3.3, 4.7, 1.6],
|
|
||||||
&[4.9, 2.4, 3.3, 1.0],
|
|
||||||
&[6.6, 2.9, 4.6, 1.3],
|
|
||||||
&[5.2, 2.7, 3.9, 1.4],
|
|
||||||
]);
|
|
||||||
let y = vec![
|
|
||||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
|
||||||
];
|
|
||||||
|
|
||||||
let classifier = RandomForestClassifier::fit(
|
|
||||||
&x,
|
|
||||||
&y,
|
|
||||||
RandomForestClassifierParameters {
|
|
||||||
criterion: SplitCriterion::Gini,
|
|
||||||
max_depth: None,
|
|
||||||
min_samples_leaf: 1,
|
|
||||||
min_samples_split: 2,
|
|
||||||
n_trees: 100,
|
|
||||||
m: Option::None,
|
|
||||||
keep_samples: false,
|
|
||||||
seed: 87,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
println!("{:?}", classifier.classes);
|
|
||||||
|
|
||||||
let results = classifier.predict_probs(&x).unwrap();
|
|
||||||
println!("{:?}", x.shape());
|
|
||||||
println!("{:?}", results);
|
|
||||||
println!("{:?}", results.shape());
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
results,
|
|
||||||
DenseMatrix::<f64>::from_array(
|
|
||||||
20,
|
|
||||||
2,
|
|
||||||
&[
|
|
||||||
1.0, 0.0, 0.78, 0.22, 0.95, 0.05, 0.82, 0.18, 1.0, 0.0, 0.92, 0.08, 0.99, 0.01,
|
|
||||||
0.96, 0.04, 0.36, 0.64, 0.33, 0.67, 0.02, 0.98, 0.02, 0.98, 0.0, 1.0, 0.0, 1.0,
|
|
||||||
0.0, 1.0, 0.0, 1.0, 0.03, 0.97, 0.05, 0.95, 0.0, 1.0, 0.02, 0.98
|
|
||||||
]
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-12
@@ -46,11 +46,8 @@ pub trait RealNumber:
|
|||||||
self * self
|
self * self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Raw transmutation to u32
|
|
||||||
fn to_f32_bits(self) -> u32;
|
|
||||||
|
|
||||||
/// Raw transmutation to u64
|
/// Raw transmutation to u32
|
||||||
fn to_f64_bits(self) -> u64;
|
fn to_f32_bits(self) -> u32;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RealNumber for f64 {
|
impl RealNumber for f64 {
|
||||||
@@ -92,10 +89,6 @@ impl RealNumber for f64 {
|
|||||||
fn to_f32_bits(self) -> u32 {
|
fn to_f32_bits(self) -> u32 {
|
||||||
self.to_bits() as u32
|
self.to_bits() as u32
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_f64_bits(self) -> u64 {
|
|
||||||
self.to_bits()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RealNumber for f32 {
|
impl RealNumber for f32 {
|
||||||
@@ -137,10 +130,6 @@ impl RealNumber for f32 {
|
|||||||
fn to_f32_bits(self) -> u32 {
|
fn to_f32_bits(self) -> u32 {
|
||||||
self.to_bits()
|
self.to_bits()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_f64_bits(self) -> u64 {
|
|
||||||
self.to_bits() as u64
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
+22
-42
@@ -18,8 +18,6 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::collections::HashSet;
|
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
@@ -44,33 +42,34 @@ impl Precision {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut classes = HashSet::new();
|
|
||||||
for i in 0..y_true.len() {
|
|
||||||
classes.insert(y_true.get(i).to_f64_bits());
|
|
||||||
}
|
|
||||||
let classes = classes.len();
|
|
||||||
|
|
||||||
let mut tp = 0;
|
let mut tp = 0;
|
||||||
let mut fp = 0;
|
let mut p = 0;
|
||||||
for i in 0..y_true.len() {
|
let n = y_true.len();
|
||||||
if y_pred.get(i) == y_true.get(i) {
|
for i in 0..n {
|
||||||
if classes == 2 {
|
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
|
||||||
if y_true.get(i) == T::one() {
|
panic!(
|
||||||
tp += 1;
|
"Precision can only be applied to binary classification: {}",
|
||||||
}
|
y_true.get(i)
|
||||||
} else {
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
|
||||||
|
panic!(
|
||||||
|
"Precision can only be applied to binary classification: {}",
|
||||||
|
y_pred.get(i)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if y_pred.get(i) == T::one() {
|
||||||
|
p += 1;
|
||||||
|
|
||||||
|
if y_true.get(i) == T::one() {
|
||||||
tp += 1;
|
tp += 1;
|
||||||
}
|
}
|
||||||
} else if classes == 2 {
|
|
||||||
if y_true.get(i) == T::one() {
|
|
||||||
fp += 1;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fp += 1;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fp).unwrap())
|
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,24 +88,5 @@ mod tests {
|
|||||||
|
|
||||||
assert!((score1 - 0.5).abs() < 1e-8);
|
assert!((score1 - 0.5).abs() < 1e-8);
|
||||||
assert!((score2 - 1.0).abs() < 1e-8);
|
assert!((score2 - 1.0).abs() < 1e-8);
|
||||||
|
|
||||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
|
||||||
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
|
|
||||||
|
|
||||||
let score3: f64 = Precision {}.get_score(&y_pred, &y_true);
|
|
||||||
assert!((score3 - 0.5).abs() < 1e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
|
||||||
#[test]
|
|
||||||
fn precision_multiclass() {
|
|
||||||
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
|
|
||||||
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];
|
|
||||||
|
|
||||||
let score1: f64 = Precision {}.get_score(&y_pred, &y_true);
|
|
||||||
let score2: f64 = Precision {}.get_score(&y_pred, &y_pred);
|
|
||||||
|
|
||||||
assert!((score1 - 0.333333333).abs() < 1e-8);
|
|
||||||
assert!((score2 - 1.0).abs() < 1e-8);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+23
-43
@@ -18,9 +18,6 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::collections::HashSet;
|
|
||||||
use std::convert::TryInto;
|
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
@@ -45,32 +42,34 @@ impl Recall {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut classes = HashSet::new();
|
|
||||||
for i in 0..y_true.len() {
|
|
||||||
classes.insert(y_true.get(i).to_f64_bits());
|
|
||||||
}
|
|
||||||
let classes: i64 = classes.len().try_into().unwrap();
|
|
||||||
|
|
||||||
let mut tp = 0;
|
let mut tp = 0;
|
||||||
let mut fne = 0;
|
let mut p = 0;
|
||||||
for i in 0..y_true.len() {
|
let n = y_true.len();
|
||||||
if y_pred.get(i) == y_true.get(i) {
|
for i in 0..n {
|
||||||
if classes == 2 {
|
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
|
||||||
if y_true.get(i) == T::one() {
|
panic!(
|
||||||
tp += 1;
|
"Recall can only be applied to binary classification: {}",
|
||||||
}
|
y_true.get(i)
|
||||||
} else {
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
|
||||||
|
panic!(
|
||||||
|
"Recall can only be applied to binary classification: {}",
|
||||||
|
y_pred.get(i)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if y_true.get(i) == T::one() {
|
||||||
|
p += 1;
|
||||||
|
|
||||||
|
if y_pred.get(i) == T::one() {
|
||||||
tp += 1;
|
tp += 1;
|
||||||
}
|
}
|
||||||
} else if classes == 2 {
|
|
||||||
if y_true.get(i) != T::one() {
|
|
||||||
fne += 1;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fne += 1;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fne).unwrap())
|
|
||||||
|
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,24 +88,5 @@ mod tests {
|
|||||||
|
|
||||||
assert!((score1 - 0.5).abs() < 1e-8);
|
assert!((score1 - 0.5).abs() < 1e-8);
|
||||||
assert!((score2 - 1.0).abs() < 1e-8);
|
assert!((score2 - 1.0).abs() < 1e-8);
|
||||||
|
|
||||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
|
||||||
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
|
|
||||||
|
|
||||||
let score3: f64 = Recall {}.get_score(&y_pred, &y_true);
|
|
||||||
assert!((score3 - 0.66666666).abs() < 1e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
|
||||||
#[test]
|
|
||||||
fn recall_multiclass() {
|
|
||||||
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
|
|
||||||
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];
|
|
||||||
|
|
||||||
let score1: f64 = Recall {}.get_score(&y_pred, &y_true);
|
|
||||||
let score2: f64 = Recall {}.get_score(&y_pred, &y_pred);
|
|
||||||
|
|
||||||
assert!((score1 - 0.333333333).abs() < 1e-8);
|
|
||||||
assert!((score2 - 1.0).abs() < 1e-8);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,11 +32,7 @@ use crate::error::{Failed, FailedError};
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
/// Configure Behaviour of `StandardScaler`.
|
/// Configure Behaviour of `StandardScaler`.
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
|
||||||
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
|
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
|
||||||
pub struct StandardScalerParameters {
|
pub struct StandardScalerParameters {
|
||||||
/// Optionally adjust mean to be zero.
|
/// Optionally adjust mean to be zero.
|
||||||
@@ -58,7 +54,6 @@ impl Default for StandardScalerParameters {
|
|||||||
/// deviation of one. This can improve model training for
|
/// deviation of one. This can improve model training for
|
||||||
/// scaling sensitive models like neural network or nearest
|
/// scaling sensitive models like neural network or nearest
|
||||||
/// neighbors based models.
|
/// neighbors based models.
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
|
||||||
#[derive(Clone, Debug, Default, Eq, PartialEq)]
|
#[derive(Clone, Debug, Default, Eq, PartialEq)]
|
||||||
pub struct StandardScaler<T: RealNumber> {
|
pub struct StandardScaler<T: RealNumber> {
|
||||||
means: Vec<T>,
|
means: Vec<T>,
|
||||||
@@ -405,43 +400,5 @@ mod tests {
|
|||||||
Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]))
|
Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]))
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Same as `fit_for_random_values` test, but using a `StandardScaler` that has been
|
|
||||||
/// serialized and deserialized.
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
|
||||||
#[test]
|
|
||||||
#[cfg(feature = "serde")]
|
|
||||||
fn serde_fit_for_random_values() {
|
|
||||||
let fitted_scaler = StandardScaler::fit(
|
|
||||||
&DenseMatrix::from_2d_array(&[
|
|
||||||
&[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
|
|
||||||
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
|
|
||||||
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
|
|
||||||
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
|
|
||||||
]),
|
|
||||||
StandardScalerParameters::default(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let deserialized_scaler: StandardScaler<f64> =
|
|
||||||
serde_json::from_str(&serde_json::to_string(&fitted_scaler).unwrap()).unwrap();
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
deserialized_scaler.means,
|
|
||||||
vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(
|
|
||||||
&DenseMatrix::from_2d_vec(&vec![deserialized_scaler.stds]).approximate_eq(
|
|
||||||
&DenseMatrix::from_2d_array(&[&[
|
|
||||||
0.29426447500954,
|
|
||||||
0.16758497615485,
|
|
||||||
0.20820945786863,
|
|
||||||
0.23329718831165
|
|
||||||
],]),
|
|
||||||
0.00000000000001
|
|
||||||
)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user