Fix #245: return error for NaN in naive bayes

* Implement error handling for NaN values in NBayes predict:
  * general behaviour has been kept unchanged, as covered by the original tests in `mod.rs`
  * i.e. an error is returned only if all the predicted probabilities are NaN
* Add tests
* Add test with static values
* Add test for numerical stability with numpy
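In caller terms, the new contract looks roughly like the sketch below — a minimal illustration assuming the in-tree `DenseMatrix`/`GaussianNB` API that this commit's own tests use; the data is made up. Rows whose log-probabilities are all NaN no longer panic inside `max_by`; `predict` returns `Err` only when every probability for the whole input is NaN.

```rust
// Sketch only: illustrates the error contract introduced by this commit.
// Assumes the in-tree smartcore API (DenseMatrix, GaussianNB) as exercised
// by the tests added below.
use smartcore::linalg::basic::matrix::DenseMatrix;
use smartcore::naive_bayes::gaussian::GaussianNB;

fn main() {
    // Two well-separated classes, two samples each (non-zero variances).
    let x = DenseMatrix::from_2d_array(&[
        &[1.0, 2.0],
        &[1.1, 2.1],
        &[9.0, 9.5],
        &[9.1, 9.4],
    ])
    .unwrap();
    let y: Vec<u32> = vec![0, 0, 1, 1];
    let gnb = GaussianNB::fit(&x, &y, Default::default()).unwrap();

    // Well-behaved input still predicts as before.
    assert!(gnb.predict(&x).is_ok());

    // All-NaN input drives every class log-probability to NaN;
    // predict now returns an error instead of panicking.
    let nan_x = DenseMatrix::from_2d_array(&[&[f64::NAN, f64::NAN]]).unwrap();
    assert!(gnb.predict(&nan_x).is_err());
}
```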
@@ -7,7 +7,6 @@
     clippy::approx_constant
 )]
 #![warn(missing_docs)]
-#![warn(rustdoc::missing_doc_code_examples)]
 
 //! # smartcore
 //!
+472 -35
@@ -40,7 +40,7 @@ use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
 use crate::numbers::basenum::Number;
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};
-use std::{cmp::Ordering, marker::PhantomData};
+use std::marker::PhantomData;
 
 /// Distribution used in the Naive Bayes classifier.
 pub(crate) trait NBDistribution<X: Number, Y: Number>: Clone {
@@ -93,42 +93,42 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
     /// Returns a vector of size N with class estimates.
     pub fn predict(&self, x: &X) -> Result<Y, Failed> {
         let y_classes = self.distribution.classes();
-        let predictions = x
-            .row_iter()
-            .map(|row| {
-                y_classes
-                    .iter()
-                    .enumerate()
-                    .map(|(class_index, class)| {
-                        (
-                            class,
-                            self.distribution.log_likelihood(class_index, &row)
-                                + self.distribution.prior(class_index).ln(),
-                        )
-                    })
-                    // For some reason, the max_by method cannot use NaNs for finding the maximum value, it panics.
-                    // NaN must be considered as minimum values,
-                    // therefore it's like NaNs would not be considered for choosing the maximum value.
-                    // So we need to handle this case for avoiding panicking by using `Option::unwrap`.
-                    .max_by(|(_, p1), (_, p2)| match p1.partial_cmp(p2) {
-                        Some(ordering) => ordering,
-                        None => {
-                            if p1.is_nan() {
-                                Ordering::Less
-                            } else if p2.is_nan() {
-                                Ordering::Greater
-                            } else {
-                                Ordering::Equal
-                            }
-                        }
-                    })
-                    .map(|(prediction, _probability)| *prediction)
-                    .ok_or_else(|| Failed::predict("Failed to predict, there is no result"))
-            })
-            .collect::<Result<Vec<TY>, Failed>>()?;
-        let y_hat = Y::from_vec_slice(&predictions);
-        Ok(y_hat)
+        if y_classes.is_empty() {
+            return Err(Failed::predict("Failed to predict, no classes available"));
+        }
+
+        let (rows, _) = x.shape();
+        let mut predictions = Vec::with_capacity(rows);
+        let mut all_probs_nan = true;
+
+        for row_index in 0..rows {
+            let row = x.get_row(row_index);
+            let mut max_log_prob = f64::NEG_INFINITY;
+            let mut max_class = None;
+
+            for (class_index, class) in y_classes.iter().enumerate() {
+                let log_likelihood = self.distribution.log_likelihood(class_index, &row);
+                let log_prob = log_likelihood + self.distribution.prior(class_index).ln();
+
+                if !log_prob.is_nan() && log_prob > max_log_prob {
+                    max_log_prob = log_prob;
+                    max_class = Some(*class);
+                    all_probs_nan = false;
+                }
+            }
+
+            predictions.push(max_class.unwrap_or(y_classes[0]));
+        }
+
+        if all_probs_nan {
+            Err(Failed::predict(
+                "Failed to predict, all probabilities were NaN",
+            ))
+        } else {
+            Ok(Y::from_vec_slice(&predictions))
+        }
     }
 }
 pub mod bernoulli;
 pub mod categorical;
@@ -177,7 +177,7 @@ mod tests {
             Ok(_) => panic!("Should return error in case of empty classes"),
             Err(err) => assert_eq!(
                 err.to_string(),
-                "Predict failed: Failed to predict, there is no result"
+                "Predict failed: Failed to predict, no classes available"
             ),
         }
@@ -193,4 +193,441 @@ mod tests {
             Err(_) => panic!("Should success in normal case without NaNs"),
         }
     }
+
+    // A simple test distribution using float
+    #[derive(Debug, PartialEq, Clone)]
+    struct TestDistributionAgain {
+        classes: Vec<u32>,
+        probs: Vec<f64>,
+    }
+
+    impl NBDistribution<f64, u32> for TestDistributionAgain {
+        fn classes(&self) -> &Vec<u32> {
+            &self.classes
+        }
+        fn prior(&self, class_index: usize) -> f64 {
+            self.probs[class_index]
+        }
+        fn log_likelihood<'a>(
+            &'a self,
+            class_index: usize,
+            _j: &'a Box<dyn ArrayView1<f64> + 'a>,
+        ) -> f64 {
+            self.probs[class_index].ln()
+        }
+    }
+
+    type TestNB = BaseNaiveBayes<f64, u32, DenseMatrix<f64>, Vec<u32>, TestDistributionAgain>;
+
+    #[test]
+    fn test_predict_empty_classes() {
+        let dist = TestDistributionAgain {
+            classes: vec![],
+            probs: vec![],
+        };
+        let nb = TestNB::fit(dist).unwrap();
+        let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
+        assert!(nb.predict(&x).is_err());
+    }
+
+    #[test]
+    fn test_predict_single_class() {
+        let dist = TestDistributionAgain {
+            classes: vec![1],
+            probs: vec![1.0],
+        };
+        let nb = TestNB::fit(dist).unwrap();
+        let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
+        let result = nb.predict(&x).unwrap();
+        assert_eq!(result, vec![1, 1]);
+    }
+
+    #[test]
+    fn test_predict_multiple_classes() {
+        let dist = TestDistributionAgain {
+            classes: vec![1, 2, 3],
+            probs: vec![0.2, 0.5, 0.3],
+        };
+        let nb = TestNB::fit(dist).unwrap();
+        let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]]).unwrap();
+        let result = nb.predict(&x).unwrap();
+        assert_eq!(result, vec![2, 2, 2]);
+    }
+
+    #[test]
+    fn test_predict_with_nans() {
+        let dist = TestDistributionAgain {
+            classes: vec![1, 2],
+            probs: vec![f64::NAN, 0.5],
+        };
+        let nb = TestNB::fit(dist).unwrap();
+        let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
+        let result = nb.predict(&x).unwrap();
+        assert_eq!(result, vec![2, 2]);
+    }
+
+    #[test]
+    fn test_predict_all_nans() {
+        let dist = TestDistributionAgain {
+            classes: vec![1, 2],
+            probs: vec![f64::NAN, f64::NAN],
+        };
+        let nb = TestNB::fit(dist).unwrap();
+        let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
+        assert!(nb.predict(&x).is_err());
+    }
+
+    #[test]
+    fn test_predict_extreme_probabilities() {
+        let dist = TestDistributionAgain {
+            classes: vec![1, 2],
+            probs: vec![1e-300, 1e-301],
+        };
+        let nb = TestNB::fit(dist).unwrap();
+        let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
+        let result = nb.predict(&x).unwrap();
+        assert_eq!(result, vec![1, 1]);
+    }
+
+    #[test]
+    fn test_predict_with_infinity() {
+        let dist = TestDistributionAgain {
+            classes: vec![1, 2, 3],
+            probs: vec![f64::INFINITY, 1.0, 2.0],
+        };
+        let nb = TestNB::fit(dist).unwrap();
+        let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
+        let result = nb.predict(&x).unwrap();
+        assert_eq!(result, vec![1, 1]);
+    }
+
+    #[test]
+    fn test_predict_with_negative_infinity() {
+        let dist = TestDistributionAgain {
+            classes: vec![1, 2, 3],
+            probs: vec![f64::NEG_INFINITY, 1.0, 2.0],
+        };
+        let nb = TestNB::fit(dist).unwrap();
+        let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
+        let result = nb.predict(&x).unwrap();
+        assert_eq!(result, vec![3, 3]);
+    }
+
+    #[test]
+    fn test_gaussian_naive_bayes_numerical_stability() {
+        #[derive(Debug, PartialEq, Clone)]
+        struct GaussianTestDistribution {
+            classes: Vec<u32>,
+            means: Vec<Vec<f64>>,
+            variances: Vec<Vec<f64>>,
+            priors: Vec<f64>,
+        }
+
+        impl NBDistribution<f64, u32> for GaussianTestDistribution {
+            fn classes(&self) -> &Vec<u32> {
+                &self.classes
+            }
+
+            fn prior(&self, class_index: usize) -> f64 {
+                self.priors[class_index]
+            }
+
+            fn log_likelihood<'a>(
+                &'a self,
+                class_index: usize,
+                j: &'a Box<dyn ArrayView1<f64> + 'a>,
+            ) -> f64 {
+                let means = &self.means[class_index];
+                let variances = &self.variances[class_index];
+                j.iterator(0)
+                    .enumerate()
+                    .map(|(i, &xi)| {
+                        let mean = means[i];
+                        let var = variances[i] + 1e-9; // Small smoothing for numerical stability
+                        let coeff = -0.5 * (2.0 * std::f64::consts::PI * var).ln();
+                        let exponent = -(xi - mean).powi(2) / (2.0 * var);
+                        coeff + exponent
+                    })
+                    .sum()
+            }
+        }
+
+        fn train_distribution(x: &DenseMatrix<f64>, y: &[u32]) -> GaussianTestDistribution {
+            let mut classes: Vec<u32> = y
+                .iter()
+                .cloned()
+                .collect::<std::collections::HashSet<u32>>()
+                .into_iter()
+                .collect();
+            classes.sort();
+            let n_classes = classes.len();
+            let n_features = x.shape().1;
+
+            let mut means = vec![vec![0.0; n_features]; n_classes];
+            let mut variances = vec![vec![0.0; n_features]; n_classes];
+            let mut class_counts = vec![0; n_classes];
+
+            // Calculate means and count samples per class
+            for (sample, &class) in x.row_iter().zip(y.iter()) {
+                let class_idx = classes.iter().position(|&c| c == class).unwrap();
+                class_counts[class_idx] += 1;
+                for (i, &value) in sample.iterator(0).enumerate() {
+                    means[class_idx][i] += value;
+                }
+            }
+
+            // Normalize means
+            for (class_idx, mean) in means.iter_mut().enumerate() {
+                for value in mean.iter_mut() {
+                    *value /= class_counts[class_idx] as f64;
+                }
+            }
+
+            // Calculate variances
+            for (sample, &class) in x.row_iter().zip(y.iter()) {
+                let class_idx = classes.iter().position(|&c| c == class).unwrap();
+                for (i, &value) in sample.iterator(0).enumerate() {
+                    let diff = value - means[class_idx][i];
+                    variances[class_idx][i] += diff * diff;
+                }
+            }
+
+            // Normalize variances and add small epsilon to avoid zero variance
+            let epsilon = 1e-9;
+            for (class_idx, variance) in variances.iter_mut().enumerate() {
+                for value in variance.iter_mut() {
+                    *value = *value / class_counts[class_idx] as f64 + epsilon;
+                }
+            }
+
+            // Calculate priors
+            let total_samples = y.len() as f64;
+            let priors: Vec<f64> = class_counts
+                .iter()
+                .map(|&count| count as f64 / total_samples)
+                .collect();
+
+            GaussianTestDistribution {
+                classes,
+                means,
+                variances,
+                priors,
+            }
+        }
+
+        type TestNBGaussian =
+            BaseNaiveBayes<f64, u32, DenseMatrix<f64>, Vec<u32>, GaussianTestDistribution>;
+
+        // Create a constant training dataset
+        let n_samples = 1000;
+        let n_features = 5;
+        let n_classes = 4;
+
+        let mut x_data = Vec::with_capacity(n_samples * n_features);
+        let mut y_data = Vec::with_capacity(n_samples);
+
+        for i in 0..n_samples {
+            for j in 0..n_features {
+                x_data.push((i * j) as f64 % 10.0);
+            }
+            y_data.push((i % n_classes) as u32);
+        }
+
+        let x = DenseMatrix::new(n_samples, n_features, x_data, true).unwrap();
+        let y = y_data;
+
+        // Train the model
+        let dist = train_distribution(&x, &y);
+        let nb = TestNBGaussian::fit(dist).unwrap();
+
+        // Create constant test data
+        let n_test_samples = 100;
+        let mut test_x_data = Vec::with_capacity(n_test_samples * n_features);
+        for i in 0..n_test_samples {
+            for j in 0..n_features {
+                test_x_data.push((i * j * 2) as f64 % 15.0);
+            }
+        }
+        let test_x = DenseMatrix::new(n_test_samples, n_features, test_x_data, true).unwrap();
+
+        // Make predictions
+        let predictions = nb
+            .predict(&test_x)
+            .map_err(|e| format!("Prediction failed: {}", e))
+            .unwrap();
+
+        // Check numerical stability
+        assert_eq!(
+            predictions.len(),
+            n_test_samples,
+            "Number of predictions should match number of test samples"
+        );
+
+        // Check that all predictions are valid class labels
+        for &pred in predictions.iter() {
+            assert!(pred < n_classes as u32, "Predicted class should be valid");
+        }
+
+        // Check consistency of predictions
+        let repeated_predictions = nb
+            .predict(&test_x)
+            .map_err(|e| format!("Repeated prediction failed: {}", e))
+            .unwrap();
+        assert_eq!(
+            predictions, repeated_predictions,
+            "Predictions should be consistent when repeated"
+        );
+
+        // Check extreme values
+        let extreme_x =
+            DenseMatrix::new(2, n_features, vec![f64::MAX; n_features * 2], true).unwrap();
+        let extreme_predictions = nb.predict(&extreme_x);
+        assert!(
+            extreme_predictions.is_err(),
+            "Extreme value input should result in an error"
+        );
+        assert_eq!(
+            extreme_predictions.unwrap_err().to_string(),
+            "Predict failed: Failed to predict, all probabilities were NaN",
+            "Incorrect error message for extreme values"
+        );
+
+        // Check for NaN handling
+        let nan_x = DenseMatrix::new(2, n_features, vec![f64::NAN; n_features * 2], true).unwrap();
+        let nan_predictions = nb.predict(&nan_x);
+        assert!(
+            nan_predictions.is_err(),
+            "NaN input should result in an error"
+        );
+
+        // Check for very small values
+        let small_x =
+            DenseMatrix::new(2, n_features, vec![f64::MIN_POSITIVE; n_features * 2], true).unwrap();
+        let small_predictions = nb
+            .predict(&small_x)
+            .map_err(|e| format!("Small value prediction failed: {}", e))
+            .unwrap();
+        for &pred in small_predictions.iter() {
+            assert!(
+                pred < n_classes as u32,
+                "Predictions for very small values should be valid"
+            );
+        }
+
+        // Check for values close to zero
+        let near_zero_x =
+            DenseMatrix::new(2, n_features, vec![1e-300; n_features * 2], true).unwrap();
+        let near_zero_predictions = nb
+            .predict(&near_zero_x)
+            .map_err(|e| format!("Near-zero value prediction failed: {}", e))
+            .unwrap();
+        for &pred in near_zero_predictions.iter() {
+            assert!(
+                pred < n_classes as u32,
+                "Predictions for near-zero values should be valid"
+            );
+        }
+
+        println!("All numerical stability checks passed!");
+    }
+
+    #[test]
+    fn test_gaussian_naive_bayes_numerical_stability_random_data() {
+        #[derive(Debug)]
+        struct MySimpleRng {
+            state: u64,
+        }
+
+        impl MySimpleRng {
+            fn new(seed: u64) -> Self {
+                MySimpleRng { state: seed }
+            }
+
+            /// Get the next u64 in the sequence.
+            fn next_u64(&mut self) -> u64 {
+                // LCG parameters; these are somewhat arbitrary but commonly used.
+                // Feel free to tweak the multiplier/adder etc.
+                self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1);
+                self.state
+            }
+
+            /// Get an f64 in the range [min, max).
+            fn next_f64(&mut self, min: f64, max: f64) -> f64 {
+                let fraction = (self.next_u64() as f64) / (u64::MAX as f64);
+                min + fraction * (max - min)
+            }
+
+            /// Get a usize in the range [min, max). This floors the floating result.
+            fn gen_range_usize(&mut self, min: usize, max: usize) -> usize {
+                let v = self.next_f64(min as f64, max as f64);
+                // Truncate into the integer range. Because of floating inexactness,
+                // ensure we also clamp.
+                let int_v = v.floor() as isize;
+                // simple clamp to avoid any float rounding out of range
+                let clamped = int_v.max(min as isize).min((max - 1) as isize);
+                clamped as usize
+            }
+        }
+        use crate::naive_bayes::gaussian::GaussianNB;
+        // We will generate random data in a reproducible way (using a fixed seed).
+        // We will generate random data in a reproducible way:
+        let mut rng = MySimpleRng::new(42);
+
+        let n_samples = 1000;
+        let n_features = 5;
+        let n_classes = 4;
+
+        // Our feature matrix and label vector
+        let mut x_data = Vec::with_capacity(n_samples * n_features);
+        let mut y_data = Vec::with_capacity(n_samples);
+
+        // Fill x_data with random values and y_data with random class labels.
+        for _i in 0..n_samples {
+            for _j in 0..n_features {
+                // We'll pick random values in [-10, 10).
+                x_data.push(rng.next_f64(-10.0, 10.0));
+            }
+            let class = rng.gen_range_usize(0, n_classes) as u32;
+            y_data.push(class);
+        }
+
+        // Create DenseMatrix from x_data
+        let x = DenseMatrix::new(n_samples, n_features, x_data, true).unwrap();
+
+        // Train GaussianNB
+        let gnb = GaussianNB::fit(&x, &y_data, Default::default())
+            .expect("Fitting GaussianNB with random data failed.");
+
+        // Predict on the same training data to verify no numerical instability
+        let predictions = gnb.predict(&x).expect("Prediction on random data failed.");
+
+        // Basic sanity checks
+        assert_eq!(
+            predictions.len(),
+            n_samples,
+            "Prediction size must match n_samples"
+        );
+        for &pred_class in &predictions {
+            assert!(
+                (pred_class as usize) < n_classes,
+                "Predicted class {} is out of range [0..n_classes).",
+                pred_class
+            );
+        }
+
+        // If you want to compare with scikit-learn, you can do something like:
+        // println!("X = {:?}", &x);
+        // println!("Y = {:?}", &y_data);
+        // println!("predictions = {:?}", &predictions);
+        // and then in Python:
+        // import numpy as np
+        // from sklearn.naive_bayes import GaussianNB
+        // X = np.reshape(np.array(x), (1000, 5), order='F')
+        // Y = np.array(y)
+        // gnb = GaussianNB().fit(X, Y)
+        // preds = gnb.predict(X)
+        // expected = np.array(predictions)
+        // assert expected == preds
+        // They should match closely (or exactly) depending on floating rounding.
+    }
 }