Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
765fab659c | ||
|
|
0df91706f2 | ||
|
|
2e5f88fad8 | ||
|
|
e445f0d558 | ||
|
|
4d5f64c758 | ||
|
|
28c81eb358 | ||
|
|
7f7b2edca0 | ||
|
|
d46b830bcd | ||
|
|
b6fb8191eb | ||
|
|
61db4ebd90 | ||
|
|
2603a1f42b | ||
|
|
663db0334d |
+2
-1
@@ -33,7 +33,8 @@ itertools = "0.10.3"
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3"
|
||||
smartcore = { path = ".", features = ["fp_bench"] }
|
||||
criterion = { version = "0.4", default-features = false }
|
||||
serde_json = "1.0"
|
||||
bincode = "1.3.1"
|
||||
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
/// # Hierarchical clustering
|
||||
///
|
||||
/// Implement hierarchical clustering methods:
|
||||
/// * Agglomerative clustering (current)
|
||||
/// * Bisecting K-Means (future)
|
||||
/// * Fastcluster (future)
|
||||
///
|
||||
|
||||
/*
|
||||
class AgglomerativeClustering():
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
n_clusters : int or None, default=2
|
||||
The number of clusters to find. It must be ``None`` if
|
||||
``distance_threshold`` is not ``None``.
|
||||
affinity : str or callable, default='euclidean'
|
||||
If linkage is "ward", only "euclidean" is accepted.
|
||||
linkage : {'ward',}, default='ward'
|
||||
Which linkage criterion to use. The linkage criterion determines which
|
||||
distance to use between sets of observations. The algorithm will merge
|
||||
the pair of clusters that minimizes this criterion.
|
||||
- 'ward' minimizes the variance of the clusters being merged.
|
||||
compute_distances : bool, default=False
|
||||
Computes distances between clusters even if `distance_threshold` is not
|
||||
used. This can be used to make dendrogram visualization, but introduces
|
||||
a computational and memory overhead.
|
||||
"""
|
||||
|
||||
def fit(X):
|
||||
# compute tree
|
||||
# <https://github.com/scikit-learn/scikit-learn/blob/02ebf9e68fe1fc7687d9e1047b9e465ae0fd945e/sklearn/cluster/_agglomerative.py#L172>
|
||||
parents, children = ward_tree(X, ....)
|
||||
# compute clusters
|
||||
# <https://github.com/scikit-learn/scikit-learn/blob/70c495250fea7fa3c8c1a4631e6ddcddc9f22451/sklearn/cluster/_hierarchical_fast.pyx#L98>
|
||||
labels = _hierarchical.hc_get_heads(parents)
|
||||
# assign cluster numbers
|
||||
self.labels_ = np.searchsorted(np.unique(labels), labels)
|
||||
|
||||
*/
|
||||
|
||||
// implement ward tree
|
||||
// use scipy.cluster.hierarchy.ward
|
||||
// <https://github.com/scipy/scipy/blob/main/scipy/cluster/hierarchy.py#L738>
|
||||
// use linkage
|
||||
// <https://github.com/scipy/scipy/blob/main/scipy/cluster/hierarchy.py#L837>
|
||||
// use nn_chain
|
||||
// <https://github.com/scipy/scipy/blob/main/scipy/cluster/_hierarchy.pyx#L906>
|
||||
|
||||
// implement hc_get_heads
|
||||
|
||||
|
||||
mod tests {
|
||||
// >>> from sklearn.cluster import AgglomerativeClustering
|
||||
// >>> import numpy as np
|
||||
// >>> X = np.array([[1, 2], [1, 4], [1, 0],
|
||||
// ... [4, 2], [4, 4], [4, 0]])
|
||||
// >>> clustering = AgglomerativeClustering().fit(X)
|
||||
// >>> clustering
|
||||
// AgglomerativeClustering()
|
||||
// >>> clustering.labels_
|
||||
// array([1, 1, 1, 0, 0, 0])
|
||||
}
|
||||
@@ -55,7 +55,8 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::{BaseMatrix, Matrix};
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::tree::decision_tree_classifier::{
|
||||
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
||||
@@ -316,6 +317,37 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
which_max(&result)
|
||||
}
|
||||
|
||||
/// Predict the per-class probabilities for each observation.
|
||||
/// The probability is calculated as the fraction of trees that predicted a given class
|
||||
///
/// * `x` - matrix of observations; one output row is produced per row of `x`.
///
/// Returns a `DenseMatrix<f64>` with `x.shape().0` rows and `self.classes.len()` columns.
/// Entry `(i, j)` is the `j`-th value returned by `predict_probs_for_row` for row `i`.
/// No error path is visible in this body: the result is always wrapped in `Ok`.
pub fn predict_probs<M: Matrix<T>>(&self, x: &M) -> Result<DenseMatrix<f64>, Failed> {
|
||||
// Output starts as all zeros: one row per observation, one column per entry of `self.classes`.
let mut result = DenseMatrix::<f64>::zeros(x.shape().0, self.classes.len());
|
||||
|
||||
// Only the row count of `x` is needed here; the column count is ignored.
let (n, _) = x.shape();
|
||||
|
||||
for i in 0..n {
|
||||
// Per-class vote fractions for observation `i`.
let row_probs = self.predict_probs_for_row(x, i);
|
||||
|
||||
// Copy the fractions into row `i` of the output, column by column.
for (j, item) in row_probs.iter().enumerate() {
|
||||
result.set(i, j, *item);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Compute, for a single row of `x`, the fraction of trees that voted for each class index.
///
/// * `x` - matrix of observations.
/// * `row` - index of the row of `x` to evaluate.
///
/// Returns a vector of length `self.classes.len()`; element `k` is the number of trees whose
/// `predict_for_row` returned `k`, divided by `self.trees.len()`.
///
/// NOTE(review): the value returned by `predict_for_row` is used directly as an index into the
/// tally vector, so it is presumably a class index below `self.classes.len()` — confirm; an
/// out-of-range value would panic on indexing. With zero trees the division is `0.0 / 0.0`,
/// which yields NaN for every class.
fn predict_probs_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> Vec<f64> {
|
||||
// Integer vote tally, one slot per entry of `self.classes`.
let mut result = vec![0; self.classes.len()];
|
||||
|
||||
// Each tree contributes exactly one vote for the index it predicts.
for tree in self.trees.iter() {
|
||||
result[tree.predict_for_row(x, row)] += 1;
|
||||
}
|
||||
|
||||
// Normalise the tallies by the number of trees to obtain fractions.
result
|
||||
.iter()
|
||||
.map(|n| *n as f64 / self.trees.len() as f64)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||
let class_weight = vec![1.; num_classes];
|
||||
let nrows = y.len();
|
||||
@@ -341,7 +373,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
mod tests_prob {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::metrics::*;
|
||||
@@ -482,4 +514,69 @@ mod tests {
|
||||
|
||||
assert_eq!(forest, deserialized_forest);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fit_predict_probabilities() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
let y = vec![
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let classifier = RandomForestClassifier::fit(
|
||||
&x,
|
||||
&y,
|
||||
RandomForestClassifierParameters {
|
||||
criterion: SplitCriterion::Gini,
|
||||
max_depth: None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 100,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 87,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
println!("{:?}", classifier.classes);
|
||||
|
||||
let results = classifier.predict_probs(&x).unwrap();
|
||||
println!("{:?}", x.shape());
|
||||
println!("{:?}", results);
|
||||
println!("{:?}", results.shape());
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
DenseMatrix::<f64>::from_array(
|
||||
20,
|
||||
2,
|
||||
&[
|
||||
1.0, 0.0, 0.78, 0.22, 0.95, 0.05, 0.82, 0.18, 1.0, 0.0, 0.92, 0.08, 0.99, 0.01,
|
||||
0.96, 0.04, 0.36, 0.64, 0.33, 0.67, 0.02, 0.98, 0.02, 0.98, 0.0, 1.0, 0.0, 1.0,
|
||||
0.0, 1.0, 0.0, 1.0, 0.03, 0.97, 0.05, 0.95, 0.0, 1.0, 0.02, 0.98
|
||||
]
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
+12
-1
@@ -46,8 +46,11 @@ pub trait RealNumber:
|
||||
self * self
|
||||
}
|
||||
|
||||
/// Raw transmutation to u64
|
||||
/// Raw transmutation to u32
|
||||
fn to_f32_bits(self) -> u32;
|
||||
|
||||
/// Raw transmutation to u64
|
||||
fn to_f64_bits(self) -> u64;
|
||||
}
|
||||
|
||||
impl RealNumber for f64 {
|
||||
@@ -89,6 +92,10 @@ impl RealNumber for f64 {
|
||||
// `self` is an f64 here, so `to_bits()` yields its 64-bit pattern and the `as u32` cast
// keeps only the low 32 bits of that pattern.
// NOTE(review): this is not the bit pattern of the value converted to f32 — confirm that
// callers only need an opaque key and not a faithful f32 representation.
fn to_f32_bits(self) -> u32 {
|
||||
self.to_bits() as u32
|
||||
}
|
||||
|
||||
// `self` is an f64 here: returns its raw 64-bit pattern unchanged, as produced by
// `f64::to_bits`.
fn to_f64_bits(self) -> u64 {
|
||||
self.to_bits()
|
||||
}
|
||||
}
|
||||
|
||||
impl RealNumber for f32 {
|
||||
@@ -130,6 +137,10 @@ impl RealNumber for f32 {
|
||||
// `self` is an f32 here: returns its raw 32-bit pattern unchanged, as produced by
// `f32::to_bits`.
fn to_f32_bits(self) -> u32 {
|
||||
self.to_bits()
|
||||
}
|
||||
|
||||
// `self` is an f32 here, so `to_bits()` yields a 32-bit pattern that the `as u64` cast
// zero-extends to 64 bits.
// NOTE(review): the result differs from `(self as f64).to_bits()`; it is a widened f32
// pattern, not the f64 encoding of the same value — confirm this is what callers expect.
fn to_f64_bits(self) -> u64 {
|
||||
self.to_bits() as u64
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
+42
-22
@@ -18,6 +18,8 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -42,34 +44,33 @@ impl Precision {
|
||||
);
|
||||
}
|
||||
|
||||
let mut classes = HashSet::new();
|
||||
for i in 0..y_true.len() {
|
||||
classes.insert(y_true.get(i).to_f64_bits());
|
||||
}
|
||||
let classes = classes.len();
|
||||
|
||||
let mut tp = 0;
|
||||
let mut p = 0;
|
||||
let n = y_true.len();
|
||||
for i in 0..n {
|
||||
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
|
||||
panic!(
|
||||
"Precision can only be applied to binary classification: {}",
|
||||
y_true.get(i)
|
||||
);
|
||||
}
|
||||
|
||||
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
|
||||
panic!(
|
||||
"Precision can only be applied to binary classification: {}",
|
||||
y_pred.get(i)
|
||||
);
|
||||
}
|
||||
|
||||
if y_pred.get(i) == T::one() {
|
||||
p += 1;
|
||||
|
||||
if y_true.get(i) == T::one() {
|
||||
let mut fp = 0;
|
||||
for i in 0..y_true.len() {
|
||||
if y_pred.get(i) == y_true.get(i) {
|
||||
if classes == 2 {
|
||||
if y_true.get(i) == T::one() {
|
||||
tp += 1;
|
||||
}
|
||||
} else {
|
||||
tp += 1;
|
||||
}
|
||||
} else if classes == 2 {
|
||||
if y_true.get(i) == T::one() {
|
||||
fp += 1;
|
||||
}
|
||||
} else {
|
||||
fp += 1;
|
||||
}
|
||||
}
|
||||
|
||||
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
|
||||
T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fp).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,5 +89,24 @@ mod tests {
|
||||
|
||||
assert!((score1 - 0.5).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
|
||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
|
||||
|
||||
let score3: f64 = Precision {}.get_score(&y_pred, &y_true);
|
||||
assert!((score3 - 0.5).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn precision_multiclass() {
|
||||
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
|
||||
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];
|
||||
|
||||
let score1: f64 = Precision {}.get_score(&y_pred, &y_true);
|
||||
let score2: f64 = Precision {}.get_score(&y_pred, &y_pred);
|
||||
|
||||
assert!((score1 - 0.333333333).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
}
|
||||
}
|
||||
|
||||
+43
-23
@@ -18,6 +18,9 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::collections::HashSet;
|
||||
use std::convert::TryInto;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -42,34 +45,32 @@ impl Recall {
|
||||
);
|
||||
}
|
||||
|
||||
let mut classes = HashSet::new();
|
||||
for i in 0..y_true.len() {
|
||||
classes.insert(y_true.get(i).to_f64_bits());
|
||||
}
|
||||
let classes: i64 = classes.len().try_into().unwrap();
|
||||
|
||||
let mut tp = 0;
|
||||
let mut p = 0;
|
||||
let n = y_true.len();
|
||||
for i in 0..n {
|
||||
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
|
||||
panic!(
|
||||
"Recall can only be applied to binary classification: {}",
|
||||
y_true.get(i)
|
||||
);
|
||||
}
|
||||
|
||||
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
|
||||
panic!(
|
||||
"Recall can only be applied to binary classification: {}",
|
||||
y_pred.get(i)
|
||||
);
|
||||
}
|
||||
|
||||
if y_true.get(i) == T::one() {
|
||||
p += 1;
|
||||
|
||||
if y_pred.get(i) == T::one() {
|
||||
let mut fne = 0;
|
||||
for i in 0..y_true.len() {
|
||||
if y_pred.get(i) == y_true.get(i) {
|
||||
if classes == 2 {
|
||||
if y_true.get(i) == T::one() {
|
||||
tp += 1;
|
||||
}
|
||||
} else {
|
||||
tp += 1;
|
||||
}
|
||||
} else if classes == 2 {
|
||||
if y_true.get(i) != T::one() {
|
||||
fne += 1;
|
||||
}
|
||||
} else {
|
||||
fne += 1;
|
||||
}
|
||||
}
|
||||
|
||||
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
|
||||
T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fne).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,5 +89,24 @@ mod tests {
|
||||
|
||||
assert!((score1 - 0.5).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
|
||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
|
||||
|
||||
let score3: f64 = Recall {}.get_score(&y_pred, &y_true);
|
||||
assert!((score3 - 0.66666666).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn recall_multiclass() {
|
||||
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
|
||||
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];
|
||||
|
||||
let score1: f64 = Recall {}.get_score(&y_pred, &y_true);
|
||||
let score2: f64 = Recall {}.get_score(&y_pred, &y_pred);
|
||||
|
||||
assert!((score1 - 0.333333333).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,11 @@ use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configure Behaviour of `StandardScaler`.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
|
||||
pub struct StandardScalerParameters {
|
||||
/// Optionally adjust mean to be zero.
|
||||
@@ -54,6 +58,7 @@ impl Default for StandardScalerParameters {
|
||||
/// deviation of one. This can improve model training for
|
||||
/// scaling sensitive models like neural network or nearest
|
||||
/// neighbors based models.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Clone, Debug, Default, Eq, PartialEq)]
|
||||
pub struct StandardScaler<T: RealNumber> {
|
||||
means: Vec<T>,
|
||||
@@ -400,5 +405,43 @@ mod tests {
|
||||
Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]))
|
||||
)
|
||||
}
|
||||
|
||||
/// Same as `fit_for_random_values` test, but using a `StandardScaler` that has been
|
||||
/// serialized and deserialized.
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde_fit_for_random_values() {
|
||||
let fitted_scaler = StandardScaler::fit(
|
||||
&DenseMatrix::from_2d_array(&[
|
||||
&[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
|
||||
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
|
||||
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
|
||||
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
|
||||
]),
|
||||
StandardScalerParameters::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let deserialized_scaler: StandardScaler<f64> =
|
||||
serde_json::from_str(&serde_json::to_string(&fitted_scaler).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
deserialized_scaler.means,
|
||||
vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
|
||||
);
|
||||
|
||||
assert!(
|
||||
&DenseMatrix::from_2d_vec(&vec![deserialized_scaler.stds]).approximate_eq(
|
||||
&DenseMatrix::from_2d_array(&[&[
|
||||
0.29426447500954,
|
||||
0.16758497615485,
|
||||
0.20820945786863,
|
||||
0.23329718831165
|
||||
],]),
|
||||
0.00000000000001
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user