format
This commit is contained in:
@@ -212,7 +212,9 @@ mod tests_fastpair {
|
|||||||
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
|
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
|
||||||
|
|
||||||
/// Brute force algorithm, used only for comparison and testing
|
/// Brute force algorithm, used only for comparison and testing
|
||||||
pub fn closest_pair_brute(fastpair: &FastPair<'_, f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> {
|
pub fn closest_pair_brute(
|
||||||
|
fastpair: &FastPair<'_, f64, DenseMatrix<f64>>,
|
||||||
|
) -> PairwiseDistance<f64> {
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
let m = fastpair.samples.shape().0;
|
let m = fastpair.samples.shape().0;
|
||||||
|
|
||||||
|
|||||||
@@ -579,9 +579,7 @@ impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
|
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
|
||||||
for DenseMatrixMutView<'_, T>
|
|
||||||
{
|
|
||||||
fn set(&mut self, pos: (usize, usize), x: T) {
|
fn set(&mut self, pos: (usize, usize), x: T) {
|
||||||
if self.column_major {
|
if self.column_major {
|
||||||
self.values[pos.0 + pos.1 * self.stride] = x;
|
self.values[pos.0 + pos.1 * self.stride] = x;
|
||||||
|
|||||||
@@ -146,9 +146,7 @@ impl<T: Number + RealNumber> SVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
|||||||
|
|
||||||
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {}
|
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {}
|
||||||
|
|
||||||
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
|
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
|
||||||
for ArrayViewMut<'_, T, Ix2>
|
|
||||||
{
|
|
||||||
fn get(&self, pos: (usize, usize)) -> &T {
|
fn get(&self, pos: (usize, usize)) -> &T {
|
||||||
&self[[pos.0, pos.1]]
|
&self[[pos.0, pos.1]]
|
||||||
}
|
}
|
||||||
@@ -175,9 +173,7 @@ impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
|
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
|
||||||
for ArrayViewMut<'_, T, Ix2>
|
|
||||||
{
|
|
||||||
fn set(&mut self, pos: (usize, usize), x: T) {
|
fn set(&mut self, pos: (usize, usize), x: T) {
|
||||||
self[[pos.0, pos.1]] = x
|
self[[pos.0, pos.1]] = x
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -172,12 +172,14 @@ where
|
|||||||
T: Number + RealNumber,
|
T: Number + RealNumber,
|
||||||
M: Array2<T>,
|
M: Array2<T>,
|
||||||
{
|
{
|
||||||
columns.first().cloned().map(|output_matrix| columns
|
columns.first().cloned().map(|output_matrix| {
|
||||||
.iter()
|
columns
|
||||||
.skip(1)
|
.iter()
|
||||||
.fold(output_matrix, |current_matrix, new_colum| {
|
.skip(1)
|
||||||
current_matrix.h_stack(new_colum)
|
.fold(output_matrix, |current_matrix, new_colum| {
|
||||||
}))
|
current_matrix.h_stack(new_colum)
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -77,9 +77,9 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
use crate::error::Failed;
|
use crate::error::Failed;
|
||||||
|
use crate::linalg::basic::arrays::MutArray;
|
||||||
use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
|
use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
|
||||||
use crate::linalg::basic::matrix::DenseMatrix;
|
use crate::linalg::basic::matrix::DenseMatrix;
|
||||||
use crate::linalg::basic::arrays::MutArray;
|
|
||||||
use crate::numbers::basenum::Number;
|
use crate::numbers::basenum::Number;
|
||||||
use crate::rand_custom::get_rng_impl;
|
use crate::rand_custom::get_rng_impl;
|
||||||
|
|
||||||
@@ -890,7 +890,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
importances
|
importances
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Predict class probabilities for the input samples.
|
/// Predict class probabilities for the input samples.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
@@ -933,7 +932,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
/// of the input sample belonging to each class.
|
/// of the input sample belonging to each class.
|
||||||
fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
|
fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
|
||||||
let mut node = 0;
|
let mut node = 0;
|
||||||
|
|
||||||
while let Some(current_node) = self.nodes().get(node) {
|
while let Some(current_node) = self.nodes().get(node) {
|
||||||
if current_node.true_child.is_none() && current_node.false_child.is_none() {
|
if current_node.true_child.is_none() && current_node.false_child.is_none() {
|
||||||
// Leaf node reached
|
// Leaf node reached
|
||||||
@@ -941,17 +940,17 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
probs[current_node.output] = 1.0;
|
probs[current_node.output] = 1.0;
|
||||||
return probs;
|
return probs;
|
||||||
}
|
}
|
||||||
|
|
||||||
let split_feature = current_node.split_feature;
|
let split_feature = current_node.split_feature;
|
||||||
let split_value = current_node.split_value.unwrap_or(f64::NAN);
|
let split_value = current_node.split_value.unwrap_or(f64::NAN);
|
||||||
|
|
||||||
if x.get((row, split_feature)).to_f64().unwrap() <= split_value {
|
if x.get((row, split_feature)).to_f64().unwrap() <= split_value {
|
||||||
node = current_node.true_child.unwrap();
|
node = current_node.true_child.unwrap();
|
||||||
} else {
|
} else {
|
||||||
node = current_node.false_child.unwrap();
|
node = current_node.false_child.unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This should never happen if the tree is properly constructed
|
// This should never happen if the tree is properly constructed
|
||||||
vec![0.0; self.classes().len()]
|
vec![0.0; self.classes().len()]
|
||||||
}
|
}
|
||||||
@@ -960,8 +959,8 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::basic::matrix::DenseMatrix;
|
|
||||||
use crate::linalg::basic::arrays::Array;
|
use crate::linalg::basic::arrays::Array;
|
||||||
|
use crate::linalg::basic::matrix::DenseMatrix;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn search_parameters() {
|
fn search_parameters() {
|
||||||
@@ -1020,24 +1019,28 @@ mod tests {
|
|||||||
&[6.9, 3.1, 4.9, 1.5],
|
&[6.9, 3.1, 4.9, 1.5],
|
||||||
&[5.5, 2.3, 4.0, 1.3],
|
&[5.5, 2.3, 4.0, 1.3],
|
||||||
&[6.5, 2.8, 4.6, 1.5],
|
&[6.5, 2.8, 4.6, 1.5],
|
||||||
]).unwrap();
|
])
|
||||||
|
.unwrap();
|
||||||
let y: Vec<usize> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
|
let y: Vec<usize> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
|
||||||
|
|
||||||
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
|
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||||
let probabilities = tree.predict_proba(&x).unwrap();
|
let probabilities = tree.predict_proba(&x).unwrap();
|
||||||
|
|
||||||
assert_eq!(probabilities.shape(), (10, 2));
|
assert_eq!(probabilities.shape(), (10, 2));
|
||||||
|
|
||||||
for row in 0..10 {
|
for row in 0..10 {
|
||||||
let row_sum: f64 = probabilities.get_row(row).sum();
|
let row_sum: f64 = probabilities.get_row(row).sum();
|
||||||
assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1");
|
assert!(
|
||||||
|
(row_sum - 1.0).abs() < 1e-6,
|
||||||
|
"Row probabilities should sum to 1"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the first 5 samples have higher probability for class 0
|
// Check if the first 5 samples have higher probability for class 0
|
||||||
for i in 0..5 {
|
for i in 0..5 {
|
||||||
assert!(probabilities.get((i, 0)) > probabilities.get((i, 1)));
|
assert!(probabilities.get((i, 0)) > probabilities.get((i, 1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the last 5 samples have higher probability for class 1
|
// Check if the last 5 samples have higher probability for class 1
|
||||||
for i in 5..10 {
|
for i in 5..10 {
|
||||||
assert!(probabilities.get((i, 1)) > probabilities.get((i, 0)));
|
assert!(probabilities.get((i, 1)) > probabilities.get((i, 0)));
|
||||||
|
|||||||
Reference in New Issue
Block a user