Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
61db4ebd90 | ||
|
|
2603a1f42b | ||
|
|
663db0334d | ||
|
|
b4a807eb9f | ||
|
|
ff456df0a4 | ||
|
|
322610c7fb | ||
|
|
70df9a8b49 |
+6
-6
@@ -20,12 +20,12 @@ datasets = []
|
||||
|
||||
[dependencies]
|
||||
ndarray = { version = "0.15", optional = true }
|
||||
nalgebra = { version = "0.23.0", optional = true }
|
||||
num-traits = "0.2.12"
|
||||
num = "0.4.0"
|
||||
rand = "0.8.3"
|
||||
rand_distr = "0.4.0"
|
||||
serde = { version = "1.0.115", features = ["derive"], optional = true }
|
||||
nalgebra = { version = "0.31", optional = true }
|
||||
num-traits = "0.2"
|
||||
num = "0.4"
|
||||
rand = "0.8"
|
||||
rand_distr = "0.4"
|
||||
serde = { version = "1", features = ["derive"], optional = true }
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
|
||||
@@ -55,7 +55,8 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::{BaseMatrix, Matrix};
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::tree::decision_tree_classifier::{
|
||||
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
||||
@@ -316,6 +317,37 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
which_max(&result)
|
||||
}
|
||||
|
||||
/// Predict the per-class probabilties for each observation.
|
||||
/// The probability is calculated as the fraction of trees that predicted a given class
|
||||
pub fn predict_probs<M: Matrix<T>>(&self, x: &M) -> Result<DenseMatrix<f64>, Failed> {
|
||||
let mut result = DenseMatrix::<f64>::zeros(x.shape().0, self.classes.len());
|
||||
|
||||
let (n, _) = x.shape();
|
||||
|
||||
for i in 0..n {
|
||||
let row_probs = self.predict_probs_for_row(x, i);
|
||||
|
||||
for (j, item) in row_probs.iter().enumerate() {
|
||||
result.set(i, j, *item);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Computes the share of trees voting for each class index for a single row of `x`.
///
/// * `x` - matrix of observations.
/// * `row` - index of the row in `x` to evaluate.
///
/// Returns a vector of length `self.classes.len()`; element `k` is the number of trees
/// whose prediction for this row was class index `k`, divided by the number of trees.
fn predict_probs_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> Vec<f64> {
    let n_trees = self.trees.len() as f64;

    // Tally one vote per tree, indexed by the class index the tree predicts.
    let mut votes = vec![0; self.classes.len()];
    for tree in self.trees.iter() {
        let class_idx = tree.predict_for_row(x, row);
        votes[class_idx] += 1;
    }

    // Normalise the tallies into fractions of the forest.
    votes
        .into_iter()
        .map(|count| count as f64 / n_trees)
        .collect()
}
|
||||
|
||||
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||
let class_weight = vec![1.; num_classes];
|
||||
let nrows = y.len();
|
||||
@@ -341,7 +373,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
mod tests_prob {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::metrics::*;
|
||||
@@ -482,4 +514,71 @@ mod tests {
|
||||
|
||||
assert_eq!(forest, deserialized_forest);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_probabilities() {
    // Fits a seeded 100-tree forest on a small two-class data set and checks the
    // per-class probabilities returned by `predict_probs` for the training rows.
    let x = DenseMatrix::<f64>::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[5.4, 3.9, 1.7, 0.4],
        &[4.6, 3.4, 1.4, 0.3],
        &[5.0, 3.4, 1.5, 0.2],
        &[4.4, 2.9, 1.4, 0.2],
        &[4.9, 3.1, 1.5, 0.1],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
        &[5.7, 2.8, 4.5, 1.3],
        &[6.3, 3.3, 4.7, 1.6],
        &[4.9, 2.4, 3.3, 1.0],
        &[6.6, 2.9, 4.6, 1.3],
        &[5.2, 2.7, 3.9, 1.4],
    ]);
    let y = vec![
        0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
    ];

    // The fixed seed keeps the forest, and therefore the expected values, reproducible.
    let classifier = RandomForestClassifier::fit(
        &x,
        &y,
        RandomForestClassifierParameters {
            criterion: SplitCriterion::Gini,
            max_depth: None,
            min_samples_leaf: 1,
            min_samples_split: 2,
            n_trees: 100,
            m: Option::None,
            keep_samples: false,
            seed: 87,
        },
    )
    .unwrap();

    let results = classifier.predict_probs(&x).unwrap();

    // NOTE(review): entries `i` and `i + 20` of the expected data sum to 1, which
    // suggests the slice is laid out column by column — confirm against
    // `DenseMatrix::from_array`.
    assert_eq!(
        results,
        DenseMatrix::<f64>::from_array(
            20,
            2,
            &[
                1.0, 0.78, 0.95, 0.82, 1.0, 0.92, 0.99, 0.96, 0.36, 0.33, 0.02, 0.02, 0.0, 0.0,
                0.0, 0.0, 0.03, 0.05, 0.0, 0.02, 0.0, 0.22, 0.05, 0.18, 0.0, 0.08, 0.01, 0.04,
                0.64, 0.67, 0.98, 0.98, 1.0, 1.0, 1.0, 1.0, 0.97, 0.95, 1.0, 0.98
            ]
        )
    );
}
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
use std::iter::Sum;
|
||||
use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign};
|
||||
|
||||
use nalgebra::{DMatrix, Dynamic, Matrix, MatrixMN, RowDVector, Scalar, VecStorage, U1};
|
||||
use nalgebra::{Const, DMatrix, Dynamic, Matrix, OMatrix, RowDVector, Scalar, VecStorage, U1};
|
||||
|
||||
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
|
||||
use crate::linalg::evd::EVDDecomposableMatrix;
|
||||
@@ -53,7 +53,7 @@ use crate::linalg::Matrix as SmartCoreMatrix;
|
||||
use crate::linalg::{BaseMatrix, BaseVector};
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
impl<T: RealNumber + 'static> BaseVector<T> for MatrixMN<T, U1, Dynamic> {
|
||||
impl<T: RealNumber + 'static> BaseVector<T> for OMatrix<T, U1, Dynamic> {
|
||||
fn get(&self, i: usize) -> T {
|
||||
*self.get((0, i)).unwrap()
|
||||
}
|
||||
@@ -198,7 +198,7 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
|
||||
|
||||
fn to_row_vector(self) -> Self::RowVector {
|
||||
let (nrows, ncols) = self.shape();
|
||||
self.reshape_generic(U1, Dynamic::new(nrows * ncols))
|
||||
self.reshape_generic(Const::<1>, Dynamic::new(nrows * ncols))
|
||||
}
|
||||
|
||||
fn get(&self, row: usize, col: usize) -> T {
|
||||
@@ -955,7 +955,7 @@ mod tests {
|
||||
#[test]
|
||||
fn pow_mut() {
|
||||
let mut a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
|
||||
a.pow_mut(3.);
|
||||
BaseMatrix::pow_mut(&mut a, 3.);
|
||||
assert_eq!(a, DMatrix::from_row_slice(1, 3, &[1., 8., 27.]));
|
||||
}
|
||||
|
||||
|
||||
+57
-10
@@ -263,21 +263,33 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
|
||||
/// Predicts estimated class labels from `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
let mut y_hat = self.decision_function(x)?;
|
||||
|
||||
let mut y_hat = M::RowVector::zeros(n);
|
||||
|
||||
for i in 0..n {
|
||||
let cls_idx = match self.predict_for_row(x.get_row(i)) == T::one() {
|
||||
for i in 0..y_hat.len() {
|
||||
let cls_idx = match y_hat.get(i) > T::zero() {
|
||||
false => self.classes[0],
|
||||
true => self.classes[1],
|
||||
};
|
||||
|
||||
y_hat.set(i, cls_idx);
|
||||
}
|
||||
|
||||
Ok(y_hat)
|
||||
}
|
||||
|
||||
/// Evaluates the decision function for the rows in `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn decision_function(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
let mut y_hat = M::RowVector::zeros(n);
|
||||
|
||||
for i in 0..n {
|
||||
y_hat.set(i, self.predict_for_row(x.get_row(i)));
|
||||
}
|
||||
|
||||
Ok(y_hat)
|
||||
}
|
||||
|
||||
fn predict_for_row(&self, x: M::RowVector) -> T {
|
||||
let mut f = self.b;
|
||||
|
||||
@@ -285,11 +297,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
|
||||
f += self.w[i] * self.kernel.apply(&x, &self.instances[i]);
|
||||
}
|
||||
|
||||
if f > T::zero() {
|
||||
T::one()
|
||||
} else {
|
||||
-T::one()
|
||||
}
|
||||
f
|
||||
}
|
||||
}
|
||||
|
||||
@@ -772,6 +780,45 @@ mod tests {
|
||||
assert!(accuracy(&y_hat, &y) >= 0.9);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn svc_fit_decision_function() {
    // Training data: two points per class, one pair on each axis.
    let x = DenseMatrix::from_2d_array(&[&[4.0, 0.0], &[0.0, 4.0], &[8.0, 0.0], &[0.0, 8.0]]);
    let y: Vec<f64> = vec![0., 0., 1., 1.];

    // Evaluation points, all on the diagonal through the origin.
    let x2 = DenseMatrix::from_2d_array(&[
        &[3.0, 3.0],
        &[4.0, 4.0],
        &[6.0, 6.0],
        &[10.0, 10.0],
        &[1.0, 1.0],
        &[0.0, 0.0],
    ]);

    let scores = SVC::fit(
        &x,
        &y,
        SVCParameters::default()
            .with_c(200.0)
            .with_kernel(Kernels::linear()),
    )
    .and_then(|svc| svc.decision_function(&x2))
    .unwrap();

    // x can be classified by a straight line through [6.0, 0.0] and [0.0, 6.0],
    // so the score should increase as points get further away from that line
    assert!(
        scores[1] < scores[2],
        "score should grow from [4, 4] to [6, 6]"
    );
    assert!(
        scores[2] < scores[3],
        "score should grow from [6, 6] to [10, 10]"
    );

    // for negative scores the score should decrease
    assert!(
        scores[4] > scores[5],
        "score should shrink from [1, 1] to [0, 0]"
    );

    // scores[0] is on the line, so its score should be close to 0
    assert!(
        scores[0].abs() <= 0.1,
        "score on the separating line should be close to 0"
    );
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn svc_fit_predict_rbf() {
|
||||
|
||||
Reference in New Issue
Block a user