This commit is contained in:
Lorenzo Mec-iS
2025-01-20 15:27:39 +00:00
parent 609f8024bc
commit fc7f2e61d9
5 changed files with 31 additions and 30 deletions
+3 -1
View File
@@ -212,7 +212,9 @@ mod tests_fastpair {
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix}; use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
/// Brute force algorithm, used only for comparison and testing /// Brute force algorithm, used only for comparison and testing
pub fn closest_pair_brute(fastpair: &FastPair<'_, f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> { pub fn closest_pair_brute(
fastpair: &FastPair<'_, f64, DenseMatrix<f64>>,
) -> PairwiseDistance<f64> {
use itertools::Itertools; use itertools::Itertools;
let m = fastpair.samples.shape().0; let m = fastpair.samples.shape().0;
+1 -3
View File
@@ -579,9 +579,7 @@ impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix
} }
} }
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
for DenseMatrixMutView<'_, T>
{
fn set(&mut self, pos: (usize, usize), x: T) { fn set(&mut self, pos: (usize, usize), x: T) {
if self.column_major { if self.column_major {
self.values[pos.0 + pos.1 * self.stride] = x; self.values[pos.0 + pos.1 * self.stride] = x;
+2 -6
View File
@@ -146,9 +146,7 @@ impl<T: Number + RealNumber> SVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2>
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {} impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {}
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
for ArrayViewMut<'_, T, Ix2>
{
fn get(&self, pos: (usize, usize)) -> &T { fn get(&self, pos: (usize, usize)) -> &T {
&self[[pos.0, pos.1]] &self[[pos.0, pos.1]]
} }
@@ -175,9 +173,7 @@ impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
} }
} }
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
for ArrayViewMut<'_, T, Ix2>
{
fn set(&mut self, pos: (usize, usize), x: T) { fn set(&mut self, pos: (usize, usize), x: T) {
self[[pos.0, pos.1]] = x self[[pos.0, pos.1]] = x
} }
+8 -6
View File
@@ -172,12 +172,14 @@ where
T: Number + RealNumber, T: Number + RealNumber,
M: Array2<T>, M: Array2<T>,
{ {
columns.first().cloned().map(|output_matrix| columns columns.first().cloned().map(|output_matrix| {
.iter() columns
.skip(1) .iter()
.fold(output_matrix, |current_matrix, new_colum| { .skip(1)
current_matrix.h_stack(new_colum) .fold(output_matrix, |current_matrix, new_colum| {
})) current_matrix.h_stack(new_colum)
})
})
} }
#[cfg(test)] #[cfg(test)]
+17 -14
View File
@@ -77,9 +77,9 @@ use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed; use crate::error::Failed;
use crate::linalg::basic::arrays::MutArray;
use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1}; use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
use crate::linalg::basic::matrix::DenseMatrix; use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::basic::arrays::MutArray;
use crate::numbers::basenum::Number; use crate::numbers::basenum::Number;
use crate::rand_custom::get_rng_impl; use crate::rand_custom::get_rng_impl;
@@ -890,7 +890,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
importances importances
} }
/// Predict class probabilities for the input samples. /// Predict class probabilities for the input samples.
/// ///
/// # Arguments /// # Arguments
@@ -933,7 +932,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
/// of the input sample belonging to each class. /// of the input sample belonging to each class.
fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> { fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
let mut node = 0; let mut node = 0;
while let Some(current_node) = self.nodes().get(node) { while let Some(current_node) = self.nodes().get(node) {
if current_node.true_child.is_none() && current_node.false_child.is_none() { if current_node.true_child.is_none() && current_node.false_child.is_none() {
// Leaf node reached // Leaf node reached
@@ -941,17 +940,17 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
probs[current_node.output] = 1.0; probs[current_node.output] = 1.0;
return probs; return probs;
} }
let split_feature = current_node.split_feature; let split_feature = current_node.split_feature;
let split_value = current_node.split_value.unwrap_or(f64::NAN); let split_value = current_node.split_value.unwrap_or(f64::NAN);
if x.get((row, split_feature)).to_f64().unwrap() <= split_value { if x.get((row, split_feature)).to_f64().unwrap() <= split_value {
node = current_node.true_child.unwrap(); node = current_node.true_child.unwrap();
} else { } else {
node = current_node.false_child.unwrap(); node = current_node.false_child.unwrap();
} }
} }
// This should never happen if the tree is properly constructed // This should never happen if the tree is properly constructed
vec![0.0; self.classes().len()] vec![0.0; self.classes().len()]
} }
@@ -960,8 +959,8 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::basic::arrays::Array; use crate::linalg::basic::arrays::Array;
use crate::linalg::basic::matrix::DenseMatrix;
#[test] #[test]
fn search_parameters() { fn search_parameters() {
@@ -1020,24 +1019,28 @@ mod tests {
&[6.9, 3.1, 4.9, 1.5], &[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3], &[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5], &[6.5, 2.8, 4.6, 1.5],
]).unwrap(); ])
.unwrap();
let y: Vec<usize> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; let y: Vec<usize> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap(); let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
let probabilities = tree.predict_proba(&x).unwrap(); let probabilities = tree.predict_proba(&x).unwrap();
assert_eq!(probabilities.shape(), (10, 2)); assert_eq!(probabilities.shape(), (10, 2));
for row in 0..10 { for row in 0..10 {
let row_sum: f64 = probabilities.get_row(row).sum(); let row_sum: f64 = probabilities.get_row(row).sum();
assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1"); assert!(
(row_sum - 1.0).abs() < 1e-6,
"Row probabilities should sum to 1"
);
} }
// Check if the first 5 samples have higher probability for class 0 // Check if the first 5 samples have higher probability for class 0
for i in 0..5 { for i in 0..5 {
assert!(probabilities.get((i, 0)) > probabilities.get((i, 1))); assert!(probabilities.get((i, 0)) > probabilities.get((i, 1)));
} }
// Check if the last 5 samples have higher probability for class 1 // Check if the last 5 samples have higher probability for class 1
for i in 5..10 { for i in 5..10 {
assert!(probabilities.get((i, 1)) > probabilities.get((i, 0))); assert!(probabilities.get((i, 1)) > probabilities.get((i, 0)));