implemented multiclass for svc (#308)
* implemented multiclass for svc * modified the multiclass svc so it doesnt modify the current api
This commit is contained in:
+375
-61
@@ -58,10 +58,11 @@
|
||||
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
//!
|
||||
//! let knl = Kernels::linear();
|
||||
//! let params = &SVCParameters::default().with_c(200.0).with_kernel(knl);
|
||||
//! let svc = SVC::fit(&x, &y, params).unwrap();
|
||||
//! let parameters = &SVCParameters::default().with_c(200.0).with_kernel(knl);
|
||||
//! let svc = SVC::fit(&x, &y, parameters).unwrap();
|
||||
//!
|
||||
//! let y_hat = svc.predict(&x).unwrap();
|
||||
//!
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
@@ -84,12 +85,194 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{PredictorBorrow, SupervisedEstimatorBorrow};
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::basic::arrays::{Array1, Array2, MutArray};
|
||||
use crate::linalg::basic::arrays::{Array, Array1, Array2, MutArray};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
use crate::rand_custom::get_rng_impl;
|
||||
use crate::svm::Kernel;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// Configuration for a multi-class Support Vector Machine (SVM) classifier.
|
||||
/// This struct holds the indices of the data points relevant to a specific binary
|
||||
/// classification problem within a multi-class context, and the two classes
|
||||
/// being discriminated.
|
||||
struct MultiClassConfig<TY: Number + Ord> {
|
||||
/// The indices of the data points from the original dataset that belong to the two `classes`.
|
||||
indices: Vec<usize>,
|
||||
/// A tuple representing the two classes that this configuration is designed to distinguish.
|
||||
classes: (TY, TY),
|
||||
}
|
||||
|
||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<TX, TY, X, Y>>
|
||||
for MultiClassSVC<'a, TX, TY, X, Y>
|
||||
{
|
||||
/// Creates a new, empty `MultiClassSVC` instance.
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
classifiers: Option::None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Fits the `MultiClassSVC` model to the provided data and parameters.
|
||||
///
|
||||
/// This method delegates the fitting process to the inherent `MultiClassSVC::fit` method.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `x` - A reference to the input features (2D array).
|
||||
/// * `y` - A reference to the target labels (1D array).
|
||||
/// * `parameters` - A reference to the `SVCParameters` controlling the SVM training.
|
||||
///
|
||||
/// # Returns
|
||||
/// A `Result` indicating success (`Self`) or failure (`Failed`).
|
||||
fn fit(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
) -> Result<Self, Failed> {
|
||||
MultiClassSVC::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
PredictorBorrow<'a, X, TX> for MultiClassSVC<'a, TX, TY, X, Y>
|
||||
{
|
||||
/// Predicts the class labels for new data points.
|
||||
///
|
||||
/// This method delegates the prediction process to the inherent `MultiClassSVC::predict` method.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `x` - A reference to the input features (2D array) for which to make predictions.
|
||||
///
|
||||
/// # Returns
|
||||
/// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
|
||||
fn predict(&self, x: &'a X) -> Result<Vec<TX>, Failed> {
|
||||
Ok(self.predict(x).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
/// A multi-class Support Vector Machine (SVM) classifier.
|
||||
///
|
||||
/// This struct implements a multi-class SVM using the "one-vs-one" strategy,
|
||||
/// where a separate binary SVC classifier is trained for every pair of classes.
|
||||
///
|
||||
/// # Type Parameters
|
||||
/// * `'a` - Lifetime parameter for borrowed data.
|
||||
/// * `TX` - The numeric type of the input features (must implement `Number` and `RealNumber`).
|
||||
/// * `TY` - The numeric type of the target labels (must implement `Number` and `Ord`).
|
||||
/// * `X` - The type representing the 2D array of input features (e.g., a matrix).
|
||||
/// * `Y` - The type representing the 1D array of target labels (e.g., a vector).
|
||||
pub struct MultiClassSVC<
|
||||
'a,
|
||||
TX: Number + RealNumber,
|
||||
TY: Number + Ord,
|
||||
X: Array2<TX>,
|
||||
Y: Array1<TY>,
|
||||
> {
|
||||
/// An optional vector of binary `SVC` classifiers.
|
||||
classifiers: Option<Vec<SVC<'a, TX, TY, X, Y>>>,
|
||||
}
|
||||
|
||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
MultiClassSVC<'a, TX, TY, X, Y>
|
||||
{
|
||||
/// Fits the `MultiClassSVC` model to the provided data using a one-vs-one strategy.
|
||||
///
|
||||
/// This method identifies all unique classes in the target labels `y` and then
|
||||
/// trains a binary `SVC` for every unique pair of classes. For each pair, it
|
||||
/// extracts the relevant data points and their labels, and then trains a
|
||||
/// specialized `SVC` for that binary classification task.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `x` - A reference to the input features (2D array).
|
||||
/// * `y` - A reference to the target labels (1D array).
|
||||
/// * `parameters` - A reference to the `SVCParameters` controlling the SVM training for each individual binary classifier.
|
||||
///
|
||||
///
|
||||
/// # Returns
|
||||
/// A `Result` indicating success (`MultiClassSVC`) or failure (`Failed`).
|
||||
pub fn fit(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
) -> Result<MultiClassSVC<'a, TX, TY, X, Y>, Failed> {
|
||||
let unique_classes = y.unique();
|
||||
let mut classifiers = Vec::new();
|
||||
// Iterate through all unique pairs of classes (one-vs-one strategy)
|
||||
for i in 0..unique_classes.len() {
|
||||
for j in i..unique_classes.len() {
|
||||
if i == j {
|
||||
continue;
|
||||
}
|
||||
let class0 = unique_classes[j];
|
||||
let class1 = unique_classes[i];
|
||||
|
||||
let mut indices = Vec::new();
|
||||
// Collect indices of data points belonging to the current pair of classes
|
||||
for (index, v) in y.iterator(0).enumerate() {
|
||||
if *v == class0 || *v == class1 {
|
||||
indices.push(index)
|
||||
}
|
||||
}
|
||||
let classes = (class0, class1);
|
||||
let multiclass_config = MultiClassConfig { classes, indices };
|
||||
// Fit a binary SVC for the current pair of classes
|
||||
let svc = SVC::multiclass_fit(x, y, parameters, multiclass_config).unwrap();
|
||||
classifiers.push(svc);
|
||||
}
|
||||
}
|
||||
Ok(Self {
|
||||
classifiers: Some(classifiers),
|
||||
})
|
||||
}
|
||||
|
||||
/// Predicts the class labels for new data points using the trained multi-class SVM.
|
||||
///
|
||||
/// This method uses a "voting" scheme (majority vote) among all the binary
|
||||
/// classifiers to determine the final prediction for each data point.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `x` - A reference to the input features (2D array) for which to make predictions.
|
||||
///
|
||||
/// # Returns
|
||||
/// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
|
||||
///
|
||||
pub fn predict(&self, x: &X) -> Result<Vec<TX>, Failed> {
|
||||
// Initialize a HashMap for each data point to store votes for each class
|
||||
let mut polls = vec![HashMap::new(); x.shape().0];
|
||||
// Retrieve the trained binary classifiers
|
||||
let classifiers = self.classifiers.as_ref().unwrap();
|
||||
|
||||
// Iterate through each binary classifier
|
||||
for i in 0..classifiers.len() {
|
||||
let svc = classifiers.get(i).unwrap();
|
||||
let predictions = svc.predict(x).unwrap(); // call SVC::predict for each binary classifier
|
||||
|
||||
// For each prediction from the current binary classifier
|
||||
for (j, prediction) in predictions.iter().enumerate() {
|
||||
let prediction = prediction.to_i32().unwrap();
|
||||
let poll = polls.get_mut(j).unwrap(); // Get the poll for the current data point
|
||||
// Increment the vote for the predicted class
|
||||
if let Some(count) = poll.get_mut(&prediction) {
|
||||
*count += 1
|
||||
} else {
|
||||
poll.insert(prediction, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the final prediction for each data point based on majority vote
|
||||
Ok(polls
|
||||
.iter()
|
||||
.map(|v| {
|
||||
// Find the class with the maximum votes for each data point
|
||||
TX::from(*v.iter().max_by_key(|(_, class)| *class).unwrap().0).unwrap()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// SVC Parameters
|
||||
@@ -123,7 +306,7 @@ pub struct SVCParameters<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX
|
||||
)]
|
||||
/// Support Vector Classifier
|
||||
pub struct SVC<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
|
||||
classes: Option<Vec<TY>>,
|
||||
classes: Option<(TY, TY)>,
|
||||
instances: Option<Vec<Vec<TX>>>,
|
||||
#[cfg_attr(feature = "serde", serde(skip))]
|
||||
parameters: Option<&'a SVCParameters<TX, TY, X, Y>>,
|
||||
@@ -152,7 +335,9 @@ struct Cache<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1
|
||||
struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
indices: Option<Vec<usize>>,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
classes: &'a (TY, TY),
|
||||
svmin: usize,
|
||||
svmax: usize,
|
||||
gmin: TX,
|
||||
@@ -180,12 +365,12 @@ impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
self.tol = tol;
|
||||
self
|
||||
}
|
||||
|
||||
/// The kernel function.
|
||||
pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
|
||||
self.kernel = Some(Box::new(kernel));
|
||||
self
|
||||
}
|
||||
|
||||
/// Seed for the pseudo random number generator.
|
||||
pub fn with_seed(mut self, seed: Option<u64>) -> Self {
|
||||
self.seed = seed;
|
||||
@@ -241,17 +426,98 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array1<TY> + 'a>
|
||||
SVC<'a, TX, TY, X, Y>
|
||||
{
|
||||
/// Fits SVC to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - class labels
|
||||
/// * `parameters` - optional parameters, use `Default::default()` to set parameters to default values.
|
||||
/// Fits a binary Support Vector Classifier (SVC) to the provided data.
|
||||
///
|
||||
/// This is the primary `fit` method for a standalone binary SVC. It expects
|
||||
/// the target labels `y` to contain exactly two unique classes. If more or
|
||||
/// fewer than two classes are found, it returns an error. It then extracts
|
||||
/// these two classes and proceeds to optimize and fit the SVC model.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `x` - A reference to the input features (2D array) of the training data.
|
||||
/// * `y` - A reference to the target labels (1D array) of the training data. `y` must contain exactly two unique class labels.
|
||||
/// * `parameters` - A reference to the `SVCParameters` controlling the training process.
|
||||
///
|
||||
/// # Returns
|
||||
/// A `Result` which is:
|
||||
/// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance.
|
||||
/// - `Err(Failed)`: If the number of unique classes in `y` is not exactly two, or if the underlying optimization fails.
|
||||
pub fn fit(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
let classes = y.unique();
|
||||
// Validate that there are exactly two unique classes in the target labels.
|
||||
if classes.len() != 2 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Incorrect number of classes: {}. A binary SVC requires exactly two classes.",
|
||||
classes.len()
|
||||
)));
|
||||
}
|
||||
let classes = (classes[0], classes[1]);
|
||||
let svc = Self::optimize_and_fit(x, y, parameters, classes, None);
|
||||
svc
|
||||
}
|
||||
|
||||
/// Fits a binary Support Vector Classifier (SVC) specifically for multi-class scenarios.
|
||||
///
|
||||
/// This function is intended to be called by a multi-class strategy (e.g., one-vs-one)
|
||||
/// to train individual binary SVCs. It takes a `MultiClassConfig` which specifies
|
||||
/// the two classes this SVC should discriminate and the subset of data indices
|
||||
/// relevant to these classes. It then delegates the actual optimization and fitting
|
||||
/// to `optimize_and_fit`.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `x` - A reference to the input features (2D array) of the training data.
|
||||
/// * `y` - A reference to the target labels (1D array) of the training data.
|
||||
/// * `parameters` - A reference to the `SVCParameters` controlling the training process (e.g., kernel, C-value, tolerance).
|
||||
/// * `multiclass_config` - A `MultiClassConfig` struct containing:
|
||||
/// - `classes`: A tuple `(class0, class1)` specifying the two classes this SVC should distinguish.
|
||||
/// - `indices`: A `Vec<usize>` containing the indices of the data points in `x` and `y that belong to either `class0` or `class1`.`
|
||||
///
|
||||
/// # Returns
|
||||
/// A `Result` which is:
|
||||
/// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance.
|
||||
/// - `Err(Failed)`: If the fitting process encounters an error (e.g., invalid parameters).
|
||||
fn multiclass_fit(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
multiclass_config: MultiClassConfig<TY>,
|
||||
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
|
||||
let classes = multiclass_config.classes;
|
||||
let indices = multiclass_config.indices;
|
||||
let svc = Self::optimize_and_fit(x, y, parameters, classes, Some(indices));
|
||||
svc
|
||||
}
|
||||
|
||||
/// Internal function to optimize and fit the Support Vector Classifier.
|
||||
///
|
||||
/// This is the core logic for training a binary SVC. It performs several checks
|
||||
/// (e.g., kernel presence, data shape consistency) and then initializes an
|
||||
/// `Optimizer` to find the support vectors, weights (`w`), and bias (`b`).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `x` - A reference to the input features (2D array) of the training data.
|
||||
/// * `y` - A reference to the target labels (1D array) of the training data.
|
||||
/// * `parameters` - A reference to the `SVCParameters` defining the SVM model's configuration.
|
||||
/// * `classes` - A tuple `(class0, class1)` representing the two distinct class labels that the SVC will learn to separate.
|
||||
/// * `indices` - An `Option<Vec<usize>>`. If `Some`, it contains the specific indices of data points from `x` and `y` that should be used for training this binary classifier. If `None`, all data points in `x` and `y` are considered.
|
||||
/// # Returns
|
||||
/// A `Result` which is:
|
||||
/// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new `SVC` instance populated with the learned model components (support vectors, weights, bias).
|
||||
/// - `Err(Failed)`: If any of the validation checks fail (e.g., missing kernel, mismatched data shapes), or if the optimization process fails.
|
||||
fn optimize_and_fit(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
classes: (TY, TY),
|
||||
indices: Option<Vec<usize>>,
|
||||
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
|
||||
let (n_samples, _) = x.shape();
|
||||
|
||||
// Validate that a kernel has been defined in the parameters.
|
||||
if parameters.kernel.is_none() {
|
||||
return Err(Failed::because(
|
||||
FailedError::ParametersError,
|
||||
@@ -259,55 +525,39 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
|
||||
));
|
||||
}
|
||||
|
||||
if n != y.shape() {
|
||||
// Validate that the number of samples in X matches the number of labels in Y.
|
||||
if n_samples != y.shape() {
|
||||
return Err(Failed::fit(
|
||||
"Number of rows of X doesn\'t match number of rows of Y",
|
||||
"Number of rows of X doesn't match number of rows of Y",
|
||||
));
|
||||
}
|
||||
|
||||
let classes = y.unique();
|
||||
|
||||
if classes.len() != 2 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Incorrect number of classes: {}",
|
||||
classes.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Make sure class labels are either 1 or -1
|
||||
for e in y.iterator(0) {
|
||||
let y_v = e.to_i32().unwrap();
|
||||
if y_v != -1 && y_v != 1 {
|
||||
return Err(Failed::because(
|
||||
FailedError::ParametersError,
|
||||
"Class labels must be 1 or -1",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let optimizer: Optimizer<'_, TX, TY, X, Y> = Optimizer::new(x, y, parameters);
|
||||
let optimizer: Optimizer<'_, TX, TY, X, Y> =
|
||||
Optimizer::new(x, y, indices, parameters, &classes);
|
||||
|
||||
// Perform the optimization to find the support vectors, weight vector, and bias.
|
||||
// This is where the core SVM algorithm (e.g., SMO) would run.
|
||||
let (support_vectors, weight, b) = optimizer.optimize();
|
||||
|
||||
// Construct and return the fitted SVC model.
|
||||
Ok(SVC::<'a> {
|
||||
classes: Some(classes),
|
||||
instances: Some(support_vectors),
|
||||
parameters: Some(parameters),
|
||||
w: Some(weight),
|
||||
b: Some(b),
|
||||
phantomdata: PhantomData,
|
||||
classes: Some(classes), // Store the two classes the SVC was trained on.
|
||||
instances: Some(support_vectors), // Store the data points that are support vectors.
|
||||
parameters: Some(parameters), // Reference to the parameters used for fitting.
|
||||
w: Some(weight), // The learned weight vector (for linear kernels).
|
||||
b: Some(b), // The learned bias term.
|
||||
phantomdata: PhantomData, // Placeholder for type parameters not directly stored.
|
||||
})
|
||||
}
|
||||
|
||||
/// Predicts estimated class labels from `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &'a X) -> Result<Vec<TX>, Failed> {
|
||||
let mut y_hat: Vec<TX> = self.decision_function(x)?;
|
||||
|
||||
for i in 0..y_hat.len() {
|
||||
let cls_idx = match *y_hat.get(i).unwrap() > TX::zero() {
|
||||
false => TX::from(self.classes.as_ref().unwrap()[0]).unwrap(),
|
||||
true => TX::from(self.classes.as_ref().unwrap()[1]).unwrap(),
|
||||
let cls_idx = match *y_hat.get(i) > TX::zero() {
|
||||
false => TX::from(self.classes.as_ref().unwrap().0).unwrap(),
|
||||
true => TX::from(self.classes.as_ref().unwrap().1).unwrap(),
|
||||
};
|
||||
|
||||
y_hat.set(i, cls_idx);
|
||||
@@ -445,14 +695,18 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
fn new(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
indices: Option<Vec<usize>>,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
classes: &'a (TY, TY),
|
||||
) -> Optimizer<'a, TX, TY, X, Y> {
|
||||
let (n, _) = x.shape();
|
||||
|
||||
Optimizer {
|
||||
x,
|
||||
y,
|
||||
indices,
|
||||
parameters,
|
||||
classes,
|
||||
svmin: 0,
|
||||
svmax: 0,
|
||||
gmin: <TX as Bounded>::max_value(),
|
||||
@@ -478,7 +732,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
for i in self.permutate(n) {
|
||||
x.clear();
|
||||
x.extend(self.x.get_row(i).iterator(0).take(n).copied());
|
||||
self.process(i, &x, *self.y.get(i), &mut cache);
|
||||
let y = if *self.y.get(i) == self.classes.1 {
|
||||
1
|
||||
} else {
|
||||
-1
|
||||
} as f64;
|
||||
self.process(i, &x, y, &mut cache);
|
||||
loop {
|
||||
self.reprocess(tol, &mut cache);
|
||||
self.find_min_max_gradient();
|
||||
@@ -514,14 +773,16 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
for i in self.permutate(n) {
|
||||
x.clear();
|
||||
x.extend(self.x.get_row(i).iterator(0).take(n).copied());
|
||||
if *self.y.get(i) == TY::one() && cp < few {
|
||||
if self.process(i, &x, *self.y.get(i), cache) {
|
||||
let y = if *self.y.get(i) == self.classes.1 {
|
||||
1
|
||||
} else {
|
||||
-1
|
||||
} as f64;
|
||||
if y == 1.0 && cp < few {
|
||||
if self.process(i, &x, y, cache) {
|
||||
cp += 1;
|
||||
}
|
||||
} else if *self.y.get(i) == TY::from(-1).unwrap()
|
||||
&& cn < few
|
||||
&& self.process(i, &x, *self.y.get(i), cache)
|
||||
{
|
||||
} else if y == -1.0 && cn < few && self.process(i, &x, y, cache) {
|
||||
cn += 1;
|
||||
}
|
||||
|
||||
@@ -531,14 +792,14 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
}
|
||||
}
|
||||
|
||||
fn process(&mut self, i: usize, x: &[TX], y: TY, cache: &mut Cache<TX, TY, X, Y>) -> bool {
|
||||
fn process(&mut self, i: usize, x: &[TX], y: f64, cache: &mut Cache<TX, TY, X, Y>) -> bool {
|
||||
for j in 0..self.sv.len() {
|
||||
if self.sv[j].index == i {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
let mut g: f64 = y.to_f64().unwrap();
|
||||
let mut g = y;
|
||||
|
||||
let mut cache_values: Vec<((usize, usize), TX)> = Vec::new();
|
||||
|
||||
@@ -559,8 +820,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
self.find_min_max_gradient();
|
||||
|
||||
if self.gmin < self.gmax
|
||||
&& ((y > TY::zero() && g < self.gmin.to_f64().unwrap())
|
||||
|| (y < TY::zero() && g > self.gmax.to_f64().unwrap()))
|
||||
&& ((y > 0.0 && g < self.gmin.to_f64().unwrap())
|
||||
|| (y < 0.0 && g > self.gmax.to_f64().unwrap()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -590,7 +851,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
),
|
||||
);
|
||||
|
||||
if y > TY::zero() {
|
||||
if y > 0.0 {
|
||||
self.smo(None, Some(0), TX::zero(), cache);
|
||||
} else {
|
||||
self.smo(Some(0), None, TX::zero(), cache);
|
||||
@@ -647,7 +908,6 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
let gmin = self.gmin;
|
||||
|
||||
let mut idxs_to_drop: HashSet<usize> = HashSet::new();
|
||||
|
||||
self.sv.retain(|v| {
|
||||
if v.alpha == 0f64
|
||||
&& ((TX::from(v.grad).unwrap() >= gmax && TX::zero() >= TX::from(v.cmax).unwrap())
|
||||
@@ -666,7 +926,11 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
|
||||
fn permutate(&self, n: usize) -> Vec<usize> {
|
||||
let mut rng = get_rng_impl(self.parameters.seed);
|
||||
let mut range: Vec<usize> = (0..n).collect();
|
||||
let mut range = if let Some(indices) = self.indices.clone() {
|
||||
indices
|
||||
} else {
|
||||
(0..n).collect::<Vec<usize>>()
|
||||
};
|
||||
range.shuffle(&mut rng);
|
||||
range
|
||||
}
|
||||
@@ -965,12 +1229,12 @@ mod tests {
|
||||
];
|
||||
|
||||
let knl = Kernels::linear();
|
||||
let params = SVCParameters::default()
|
||||
let parameters = SVCParameters::default()
|
||||
.with_c(200.0)
|
||||
.with_kernel(knl)
|
||||
.with_seed(Some(100));
|
||||
|
||||
let y_hat = SVC::fit(&x, &y, &params)
|
||||
let y_hat = SVC::fit(&x, &y, &parameters)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect()));
|
||||
@@ -1070,6 +1334,56 @@ mod tests {
|
||||
assert!(acc >= 0.9, "accuracy ({acc}) is not larger or equal to 0.9");
|
||||
}
|
||||
|
||||
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn svc_multiclass_fit_predict() {
    // Twenty four-feature samples, labelled with three classes below.
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[5.4, 3.9, 1.7, 0.4],
        &[4.6, 3.4, 1.4, 0.3],
        &[5.0, 3.4, 1.5, 0.2],
        &[4.4, 2.9, 1.4, 0.2],
        &[4.9, 3.1, 1.5, 0.1],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
        &[5.7, 2.8, 4.5, 1.3],
        &[6.3, 3.3, 4.7, 1.6],
        &[4.9, 2.4, 3.3, 1.0],
        &[6.6, 2.9, 4.6, 1.3],
        &[5.2, 2.7, 3.9, 1.4],
    ])
    .unwrap();

    let y: Vec<i32> = vec![0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2];

    let knl = Kernels::linear();
    // Fixed seed so the shuffled training order is reproducible.
    let parameters = SVCParameters::default()
        .with_c(200.0)
        .with_kernel(knl)
        .with_seed(Some(100));

    // Fit on the training set and predict the same samples back.
    let y_hat = MultiClassSVC::fit(&x, &y, &parameters)
        .and_then(|lr| lr.predict(&x))
        .unwrap();

    let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect()));

    assert!(
        acc >= 0.9,
        "Multiclass accuracy ({acc}) is not larger or equal to 0.9"
    );
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
@@ -1106,8 +1420,8 @@ mod tests {
|
||||
];
|
||||
|
||||
let knl = Kernels::linear();
|
||||
let params = SVCParameters::default().with_kernel(knl);
|
||||
let svc = SVC::fit(&x, &y, &params).unwrap();
|
||||
let parameters = SVCParameters::default().with_kernel(knl);
|
||||
let svc = SVC::fit(&x, &y, &parameters).unwrap();
|
||||
|
||||
// serialization
|
||||
let deserialized_svc: SVC<'_, f64, i32, _, _> =
|
||||
|
||||
Reference in New Issue
Block a user