feat: refactoring, adds Result to most public API
This commit is contained in:
@@ -16,7 +16,7 @@
|
|||||||
//!
|
//!
|
||||||
//! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points
|
//! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points
|
||||||
//!
|
//!
|
||||||
//! let mut tree = CoverTree::new(data, SimpleDistance {});
|
//! let mut tree = CoverTree::new(data, SimpleDistance {}).unwrap();
|
||||||
//!
|
//!
|
||||||
//! tree.find(&5, 3); // find 3 knn points from 5
|
//! tree.find(&5, 3); // find 3 knn points from 5
|
||||||
//!
|
//!
|
||||||
@@ -26,6 +26,7 @@ use std::fmt::Debug;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::sort::heap_select::HeapSelection;
|
use crate::algorithm::sort::heap_select::HeapSelection;
|
||||||
|
use crate::error::{Failed, FailedError};
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
@@ -73,7 +74,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
/// Construct a cover tree.
|
/// Construct a cover tree.
|
||||||
/// * `data` - vector of data points to search for.
|
/// * `data` - vector of data points to search for.
|
||||||
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
|
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
|
||||||
pub fn new(data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
|
pub fn new(data: Vec<T>, distance: D) -> Result<CoverTree<T, F, D>, Failed> {
|
||||||
let base = F::from_f64(1.3).unwrap();
|
let base = F::from_f64(1.3).unwrap();
|
||||||
let root = Node {
|
let root = Node {
|
||||||
idx: 0,
|
idx: 0,
|
||||||
@@ -93,19 +94,22 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
|
|
||||||
tree.build_cover_tree();
|
tree.build_cover_tree();
|
||||||
|
|
||||||
tree
|
Ok(tree)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Find k nearest neighbors of `p`
|
/// Find k nearest neighbors of `p`
|
||||||
/// * `p` - look for k nearest points to `p`
|
/// * `p` - look for k nearest points to `p`
|
||||||
/// * `k` - the number of nearest neighbors to return
|
/// * `k` - the number of nearest neighbors to return
|
||||||
pub fn find(&self, p: &T, k: usize) -> Vec<(usize, F)> {
|
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
|
||||||
if k <= 0 {
|
if k <= 0 {
|
||||||
panic!("k should be > 0");
|
return Err(Failed::because(FailedError::FindFailed, "k should be > 0"));
|
||||||
}
|
}
|
||||||
|
|
||||||
if k > self.data.len() {
|
if k > self.data.len() {
|
||||||
panic!("k is > than the dataset size");
|
return Err(Failed::because(
|
||||||
|
FailedError::FindFailed,
|
||||||
|
"k is > than the dataset size",
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let e = self.get_data_value(self.root.idx);
|
let e = self.get_data_value(self.root.idx);
|
||||||
@@ -171,7 +175,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
neighbors.into_iter().take(k).collect()
|
Ok(neighbors.into_iter().take(k).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_leaf(&self, idx: usize) -> Node<F> {
|
fn new_leaf(&self, idx: usize) -> Node<F> {
|
||||||
@@ -407,9 +411,9 @@ mod tests {
|
|||||||
fn cover_tree_test() {
|
fn cover_tree_test() {
|
||||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||||
|
|
||||||
let tree = CoverTree::new(data, SimpleDistance {});
|
let tree = CoverTree::new(data, SimpleDistance {}).unwrap();
|
||||||
|
|
||||||
let mut knn = tree.find(&5, 3);
|
let mut knn = tree.find(&5, 3).unwrap();
|
||||||
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||||
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
|
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
|
||||||
assert_eq!(vec!(3, 4, 5), knn);
|
assert_eq!(vec!(3, 4, 5), knn);
|
||||||
@@ -425,9 +429,9 @@ mod tests {
|
|||||||
vec![9., 10.],
|
vec![9., 10.],
|
||||||
];
|
];
|
||||||
|
|
||||||
let tree = CoverTree::new(data, Distances::euclidian());
|
let tree = CoverTree::new(data, Distances::euclidian()).unwrap();
|
||||||
|
|
||||||
let mut knn = tree.find(&vec![1., 2.], 3);
|
let mut knn = tree.find(&vec![1., 2.], 3).unwrap();
|
||||||
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||||
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
|
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
|
||||||
|
|
||||||
@@ -438,7 +442,7 @@ mod tests {
|
|||||||
fn serde() {
|
fn serde() {
|
||||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||||
|
|
||||||
let tree = CoverTree::new(data, SimpleDistance {});
|
let tree = CoverTree::new(data, SimpleDistance {}).unwrap();
|
||||||
|
|
||||||
let deserialized_tree: CoverTree<i32, f64, SimpleDistance> =
|
let deserialized_tree: CoverTree<i32, f64, SimpleDistance> =
|
||||||
serde_json::from_str(&serde_json::to_string(&tree).unwrap()).unwrap();
|
serde_json::from_str(&serde_json::to_string(&tree).unwrap()).unwrap();
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
//!
|
//!
|
||||||
//! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points
|
//! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points
|
||||||
//!
|
//!
|
||||||
//! let knn = LinearKNNSearch::new(data, SimpleDistance {});
|
//! let knn = LinearKNNSearch::new(data, SimpleDistance {}).unwrap();
|
||||||
//!
|
//!
|
||||||
//! knn.find(&5, 3); // find 3 knn points from 5
|
//! knn.find(&5, 3); // find 3 knn points from 5
|
||||||
//!
|
//!
|
||||||
@@ -26,6 +26,7 @@ use std::cmp::{Ordering, PartialOrd};
|
|||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use crate::algorithm::sort::heap_select::HeapSelection;
|
use crate::algorithm::sort::heap_select::HeapSelection;
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
@@ -41,18 +42,18 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
|||||||
/// Initializes algorithm.
|
/// Initializes algorithm.
|
||||||
/// * `data` - vector of data points to search for.
|
/// * `data` - vector of data points to search for.
|
||||||
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
|
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
|
||||||
pub fn new(data: Vec<T>, distance: D) -> LinearKNNSearch<T, F, D> {
|
pub fn new(data: Vec<T>, distance: D) -> Result<LinearKNNSearch<T, F, D>, Failed> {
|
||||||
LinearKNNSearch {
|
Ok(LinearKNNSearch {
|
||||||
data: data,
|
data: data,
|
||||||
distance: distance,
|
distance: distance,
|
||||||
f: PhantomData,
|
f: PhantomData,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Find k nearest neighbors
|
/// Find k nearest neighbors
|
||||||
/// * `from` - look for k nearest points to `from`
|
/// * `from` - look for k nearest points to `from`
|
||||||
/// * `k` - the number of nearest neighbors to return
|
/// * `k` - the number of nearest neighbors to return
|
||||||
pub fn find(&self, from: &T, k: usize) -> Vec<(usize, F)> {
|
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
|
||||||
if k < 1 || k > self.data.len() {
|
if k < 1 || k > self.data.len() {
|
||||||
panic!("k should be >= 1 and <= length(data)");
|
panic!("k should be >= 1 and <= length(data)");
|
||||||
}
|
}
|
||||||
@@ -76,10 +77,11 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
heap.get()
|
Ok(heap
|
||||||
|
.get()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.flat_map(|x| x.index.map(|i| (i, x.distance)))
|
.flat_map(|x| x.index.map(|i| (i, x.distance)))
|
||||||
.collect()
|
.collect())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,9 +122,14 @@ mod tests {
|
|||||||
fn knn_find() {
|
fn knn_find() {
|
||||||
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
|
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
|
||||||
|
|
||||||
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
|
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {}).unwrap();
|
||||||
|
|
||||||
let mut found_idxs1: Vec<usize> = algorithm1.find(&2, 3).iter().map(|v| v.0).collect();
|
let mut found_idxs1: Vec<usize> = algorithm1
|
||||||
|
.find(&2, 3)
|
||||||
|
.unwrap()
|
||||||
|
.iter()
|
||||||
|
.map(|v| v.0)
|
||||||
|
.collect();
|
||||||
found_idxs1.sort();
|
found_idxs1.sort();
|
||||||
|
|
||||||
assert_eq!(vec!(0, 1, 2), found_idxs1);
|
assert_eq!(vec!(0, 1, 2), found_idxs1);
|
||||||
@@ -135,10 +142,11 @@ mod tests {
|
|||||||
vec![5., 5.],
|
vec![5., 5.],
|
||||||
];
|
];
|
||||||
|
|
||||||
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
|
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian()).unwrap();
|
||||||
|
|
||||||
let mut found_idxs2: Vec<usize> = algorithm2
|
let mut found_idxs2: Vec<usize> = algorithm2
|
||||||
.find(&vec![3., 3.], 3)
|
.find(&vec![3., 3.], 3)
|
||||||
|
.unwrap()
|
||||||
.iter()
|
.iter()
|
||||||
.map(|v| v.0)
|
.map(|v| v.0)
|
||||||
.collect();
|
.collect();
|
||||||
|
|||||||
+7
-12
@@ -61,7 +61,7 @@ use std::iter::Sum;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||||
use crate::error::{FitFailedError, PredictFailedError};
|
use crate::error::Failed;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::distance::euclidian::*;
|
use crate::math::distance::euclidian::*;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
@@ -122,19 +122,16 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
|||||||
data: &M,
|
data: &M,
|
||||||
k: usize,
|
k: usize,
|
||||||
parameters: KMeansParameters,
|
parameters: KMeansParameters,
|
||||||
) -> Result<KMeans<T>, FitFailedError> {
|
) -> Result<KMeans<T>, Failed> {
|
||||||
let bbd = BBDTree::new(data);
|
let bbd = BBDTree::new(data);
|
||||||
|
|
||||||
if k < 2 {
|
if k < 2 {
|
||||||
return Err(FitFailedError::new(&format!(
|
return Err(Failed::fit(&format!("invalid number of clusters: {}", k)));
|
||||||
"Invalid number of clusters: {}",
|
|
||||||
k
|
|
||||||
)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if parameters.max_iter <= 0 {
|
if parameters.max_iter <= 0 {
|
||||||
return Err(FitFailedError::new(&format!(
|
return Err(Failed::fit(&format!(
|
||||||
"Invalid maximum number of iterations: {}",
|
"invalid maximum number of iterations: {}",
|
||||||
parameters.max_iter
|
parameters.max_iter
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
@@ -191,7 +188,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
|||||||
|
|
||||||
/// Predict clusters for `x`
|
/// Predict clusters for `x`
|
||||||
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
|
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
|
||||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, PredictFailedError> {
|
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
let (n, _) = x.shape();
|
let (n, _) = x.shape();
|
||||||
let mut result = M::zeros(1, n);
|
let mut result = M::zeros(1, n);
|
||||||
|
|
||||||
@@ -274,11 +271,9 @@ mod tests {
|
|||||||
fn invalid_k() {
|
fn invalid_k() {
|
||||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||||
|
|
||||||
println!("{:?}", KMeans::fit(&x, 0, Default::default()));
|
|
||||||
|
|
||||||
assert!(KMeans::fit(&x, 0, Default::default()).is_err());
|
assert!(KMeans::fit(&x, 0, Default::default()).is_err());
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
"Invalid number of clusters: 1",
|
"Fit failed: invalid number of clusters: 1",
|
||||||
KMeans::fit(&x, 1, Default::default())
|
KMeans::fit(&x, 1, Default::default())
|
||||||
.unwrap_err()
|
.unwrap_err()
|
||||||
.to_string()
|
.to_string()
|
||||||
|
|||||||
+24
-18
@@ -37,9 +37,9 @@
|
|||||||
//! &[5.2, 2.7, 3.9, 1.4],
|
//! &[5.2, 2.7, 3.9, 1.4],
|
||||||
//! ]);
|
//! ]);
|
||||||
//!
|
//!
|
||||||
//! let pca = PCA::new(&iris, 2, Default::default()); // Reduce number of features to 2
|
//! let pca = PCA::fit(&iris, 2, Default::default()).unwrap(); // Reduce number of features to 2
|
||||||
//!
|
//!
|
||||||
//! let iris_reduced = pca.transform(&iris);
|
//! let iris_reduced = pca.transform(&iris).unwrap();
|
||||||
//!
|
//!
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
@@ -49,6 +49,7 @@ use std::fmt::Debug;
|
|||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
@@ -100,7 +101,11 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
|||||||
/// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
/// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||||
/// * `n_components` - number of components to keep.
|
/// * `n_components` - number of components to keep.
|
||||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||||
pub fn new(data: &M, n_components: usize, parameters: PCAParameters) -> PCA<T, M> {
|
pub fn fit(
|
||||||
|
data: &M,
|
||||||
|
n_components: usize,
|
||||||
|
parameters: PCAParameters,
|
||||||
|
) -> Result<PCA<T, M>, Failed> {
|
||||||
let (m, n) = data.shape();
|
let (m, n) = data.shape();
|
||||||
|
|
||||||
let mu = data.column_mean();
|
let mu = data.column_mean();
|
||||||
@@ -117,7 +122,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
|||||||
let mut eigenvectors;
|
let mut eigenvectors;
|
||||||
|
|
||||||
if m > n && !parameters.use_correlation_matrix {
|
if m > n && !parameters.use_correlation_matrix {
|
||||||
let svd = x.svd();
|
let svd = x.svd()?;
|
||||||
eigenvalues = svd.s;
|
eigenvalues = svd.s;
|
||||||
for i in 0..eigenvalues.len() {
|
for i in 0..eigenvalues.len() {
|
||||||
eigenvalues[i] = eigenvalues[i] * eigenvalues[i];
|
eigenvalues[i] = eigenvalues[i] * eigenvalues[i];
|
||||||
@@ -155,7 +160,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let evd = cov.evd(true);
|
let evd = cov.evd(true)?;
|
||||||
|
|
||||||
eigenvalues = evd.d;
|
eigenvalues = evd.d;
|
||||||
|
|
||||||
@@ -167,7 +172,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let evd = cov.evd(true);
|
let evd = cov.evd(true)?;
|
||||||
|
|
||||||
eigenvalues = evd.d;
|
eigenvalues = evd.d;
|
||||||
|
|
||||||
@@ -189,26 +194,26 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PCA {
|
Ok(PCA {
|
||||||
eigenvectors: eigenvectors,
|
eigenvectors: eigenvectors,
|
||||||
eigenvalues: eigenvalues,
|
eigenvalues: eigenvalues,
|
||||||
projection: projection.transpose(),
|
projection: projection.transpose(),
|
||||||
mu: mu,
|
mu: mu,
|
||||||
pmu: pmu,
|
pmu: pmu,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run dimensionality reduction for `x`
|
/// Run dimensionality reduction for `x`
|
||||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
pub fn transform(&self, x: &M) -> M {
|
pub fn transform(&self, x: &M) -> Result<M, Failed> {
|
||||||
let (nrows, ncols) = x.shape();
|
let (nrows, ncols) = x.shape();
|
||||||
let (_, n_components) = self.projection.shape();
|
let (_, n_components) = self.projection.shape();
|
||||||
if ncols != self.mu.len() {
|
if ncols != self.mu.len() {
|
||||||
panic!(
|
return Err(Failed::transform(&format!(
|
||||||
"Invalid input vector size: {}, expected: {}",
|
"Invalid input vector size: {}, expected: {}",
|
||||||
ncols,
|
ncols,
|
||||||
self.mu.len()
|
self.mu.len()
|
||||||
);
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut x_transformed = x.matmul(&self.projection);
|
let mut x_transformed = x.matmul(&self.projection);
|
||||||
@@ -217,7 +222,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
|||||||
x_transformed.sub_element_mut(r, c, self.pmu[c]);
|
x_transformed.sub_element_mut(r, c, self.pmu[c]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
x_transformed
|
Ok(x_transformed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -372,7 +377,7 @@ mod tests {
|
|||||||
302.04806302399646,
|
302.04806302399646,
|
||||||
];
|
];
|
||||||
|
|
||||||
let pca = PCA::new(&us_arrests, 4, Default::default());
|
let pca = PCA::fit(&us_arrests, 4, Default::default()).unwrap();
|
||||||
|
|
||||||
assert!(pca
|
assert!(pca
|
||||||
.eigenvectors
|
.eigenvectors
|
||||||
@@ -383,7 +388,7 @@ mod tests {
|
|||||||
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
|
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
let us_arrests_t = pca.transform(&us_arrests);
|
let us_arrests_t = pca.transform(&us_arrests).unwrap();
|
||||||
|
|
||||||
assert!(us_arrests_t
|
assert!(us_arrests_t
|
||||||
.abs()
|
.abs()
|
||||||
@@ -481,13 +486,14 @@ mod tests {
|
|||||||
0.1734300877298357,
|
0.1734300877298357,
|
||||||
];
|
];
|
||||||
|
|
||||||
let pca = PCA::new(
|
let pca = PCA::fit(
|
||||||
&us_arrests,
|
&us_arrests,
|
||||||
4,
|
4,
|
||||||
PCAParameters {
|
PCAParameters {
|
||||||
use_correlation_matrix: true,
|
use_correlation_matrix: true,
|
||||||
},
|
},
|
||||||
);
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(pca
|
assert!(pca
|
||||||
.eigenvectors
|
.eigenvectors
|
||||||
@@ -498,7 +504,7 @@ mod tests {
|
|||||||
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
|
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
let us_arrests_t = pca.transform(&us_arrests);
|
let us_arrests_t = pca.transform(&us_arrests).unwrap();
|
||||||
|
|
||||||
assert!(us_arrests_t
|
assert!(us_arrests_t
|
||||||
.abs()
|
.abs()
|
||||||
@@ -530,7 +536,7 @@ mod tests {
|
|||||||
&[5.2, 2.7, 3.9, 1.4],
|
&[5.2, 2.7, 3.9, 1.4],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let pca = PCA::new(&iris, 4, Default::default());
|
let pca = PCA::fit(&iris, 4, Default::default()).unwrap();
|
||||||
|
|
||||||
let deserialized_pca: PCA<f64, DenseMatrix<f64>> =
|
let deserialized_pca: PCA<f64, DenseMatrix<f64>> =
|
||||||
serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
|
serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
|
||||||
|
|||||||
@@ -39,8 +39,8 @@
|
|||||||
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
//! ];
|
//! ];
|
||||||
//!
|
//!
|
||||||
//! let classifier = RandomForestClassifier::fit(&x, &y, Default::default());
|
//! let classifier = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||||
//! let y_hat = classifier.predict(&x); // use the same data for prediction
|
//! let y_hat = classifier.predict(&x).unwrap(); // use the same data for prediction
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
@@ -53,6 +53,7 @@ use std::fmt::Debug;
|
|||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::tree::decision_tree_classifier::{
|
use crate::tree::decision_tree_classifier::{
|
||||||
@@ -126,7 +127,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
x: &M,
|
x: &M,
|
||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
parameters: RandomForestClassifierParameters,
|
parameters: RandomForestClassifierParameters,
|
||||||
) -> RandomForestClassifier<T> {
|
) -> Result<RandomForestClassifier<T>, Failed> {
|
||||||
let (_, num_attributes) = x.shape();
|
let (_, num_attributes) = x.shape();
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let (_, y_ncols) = y_m.shape();
|
let (_, y_ncols) = y_m.shape();
|
||||||
@@ -158,20 +159,20 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
min_samples_leaf: parameters.min_samples_leaf,
|
min_samples_leaf: parameters.min_samples_leaf,
|
||||||
min_samples_split: parameters.min_samples_split,
|
min_samples_split: parameters.min_samples_split,
|
||||||
};
|
};
|
||||||
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params);
|
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
|
||||||
trees.push(tree);
|
trees.push(tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
RandomForestClassifier {
|
Ok(RandomForestClassifier {
|
||||||
parameters: parameters,
|
parameters: parameters,
|
||||||
trees: trees,
|
trees: trees,
|
||||||
classes,
|
classes,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict class for `x`
|
/// Predict class for `x`
|
||||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
let mut result = M::zeros(1, x.shape().0);
|
let mut result = M::zeros(1, x.shape().0);
|
||||||
|
|
||||||
let (n, _) = x.shape();
|
let (n, _) = x.shape();
|
||||||
@@ -180,7 +181,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
result.set(0, i, self.classes[self.predict_for_row(x, i)]);
|
result.set(0, i, self.classes[self.predict_for_row(x, i)]);
|
||||||
}
|
}
|
||||||
|
|
||||||
result.to_row_vector()
|
Ok(result.to_row_vector())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||||
@@ -263,9 +264,10 @@ mod tests {
|
|||||||
n_trees: 100,
|
n_trees: 100,
|
||||||
m: Option::None,
|
m: Option::None,
|
||||||
},
|
},
|
||||||
);
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(accuracy(&y, &classifier.predict(&x)) >= 0.95);
|
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -296,7 +298,7 @@ mod tests {
|
|||||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
];
|
];
|
||||||
|
|
||||||
let forest = RandomForestClassifier::fit(&x, &y, Default::default());
|
let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
let deserialized_forest: RandomForestClassifier<f64> =
|
let deserialized_forest: RandomForestClassifier<f64> =
|
||||||
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
||||||
|
|||||||
@@ -35,9 +35,9 @@
|
|||||||
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
|
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
|
||||||
//! ];
|
//! ];
|
||||||
//!
|
//!
|
||||||
//! let regressor = RandomForestRegressor::fit(&x, &y, Default::default());
|
//! let regressor = RandomForestRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||||
//!
|
//!
|
||||||
//! let y_hat = regressor.predict(&x); // use the same data for prediction
|
//! let y_hat = regressor.predict(&x).unwrap(); // use the same data for prediction
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
@@ -50,6 +50,7 @@ use std::fmt::Debug;
|
|||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::tree::decision_tree_regressor::{
|
use crate::tree::decision_tree_regressor::{
|
||||||
@@ -114,7 +115,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
|||||||
x: &M,
|
x: &M,
|
||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
parameters: RandomForestRegressorParameters,
|
parameters: RandomForestRegressorParameters,
|
||||||
) -> RandomForestRegressor<T> {
|
) -> Result<RandomForestRegressor<T>, Failed> {
|
||||||
let (n_rows, num_attributes) = x.shape();
|
let (n_rows, num_attributes) = x.shape();
|
||||||
|
|
||||||
let mtry = parameters
|
let mtry = parameters
|
||||||
@@ -130,19 +131,19 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
|||||||
min_samples_leaf: parameters.min_samples_leaf,
|
min_samples_leaf: parameters.min_samples_leaf,
|
||||||
min_samples_split: parameters.min_samples_split,
|
min_samples_split: parameters.min_samples_split,
|
||||||
};
|
};
|
||||||
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params);
|
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
|
||||||
trees.push(tree);
|
trees.push(tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
RandomForestRegressor {
|
Ok(RandomForestRegressor {
|
||||||
parameters: parameters,
|
parameters: parameters,
|
||||||
trees: trees,
|
trees: trees,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict class for `x`
|
/// Predict class for `x`
|
||||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
let mut result = M::zeros(1, x.shape().0);
|
let mut result = M::zeros(1, x.shape().0);
|
||||||
|
|
||||||
let (n, _) = x.shape();
|
let (n, _) = x.shape();
|
||||||
@@ -151,7 +152,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
|||||||
result.set(0, i, self.predict_for_row(x, i));
|
result.set(0, i, self.predict_for_row(x, i));
|
||||||
}
|
}
|
||||||
|
|
||||||
result.to_row_vector()
|
Ok(result.to_row_vector())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
||||||
@@ -219,7 +220,8 @@ mod tests {
|
|||||||
m: Option::None,
|
m: Option::None,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.predict(&x);
|
.and_then(|rf| rf.predict(&x))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
|
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
|
||||||
}
|
}
|
||||||
@@ -249,7 +251,7 @@ mod tests {
|
|||||||
114.2, 115.7, 116.9,
|
114.2, 115.7, 116.9,
|
||||||
];
|
];
|
||||||
|
|
||||||
let forest = RandomForestRegressor::fit(&x, &y, Default::default());
|
let forest = RandomForestRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
let deserialized_forest: RandomForestRegressor<f64> =
|
let deserialized_forest: RandomForestRegressor<f64> =
|
||||||
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
||||||
|
|||||||
+81
-39
@@ -2,58 +2,100 @@
|
|||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
/// Error to be raised when model does not fits data.
|
use serde::{Deserialize, Serialize};
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct FitFailedError {
|
/// Generic error to be raised when something goes wrong.
|
||||||
details: String,
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Failed {
|
||||||
|
err: FailedError,
|
||||||
|
msg: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Error to be raised when model prediction cannot be calculated.
|
/// Type of error
|
||||||
#[derive(Debug)]
|
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
|
||||||
pub struct PredictFailedError {
|
pub enum FailedError {
|
||||||
details: String,
|
/// Can't fit algorithm to data
|
||||||
|
FitFailed = 1,
|
||||||
|
/// Can't predict new values
|
||||||
|
PredictFailed,
|
||||||
|
/// Can't transform data
|
||||||
|
TransformFailed,
|
||||||
|
/// Can't find an item
|
||||||
|
FindFailed,
|
||||||
|
/// Can't decompose a matrix
|
||||||
|
DecompositionFailed,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FitFailedError {
|
impl Failed {
|
||||||
/// Creates new instance of `FitFailedError`
|
///get type of error
|
||||||
/// * `msg` - description of the error
|
#[inline]
|
||||||
pub fn new(msg: &str) -> FitFailedError {
|
pub fn error(&self) -> FailedError {
|
||||||
FitFailedError {
|
self.err
|
||||||
details: msg.to_string(),
|
}
|
||||||
|
|
||||||
|
/// new instance of `FailedError::FitError`
|
||||||
|
pub fn fit(msg: &str) -> Self {
|
||||||
|
Failed {
|
||||||
|
err: FailedError::FitFailed,
|
||||||
|
msg: msg.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// new instance of `FailedError::PredictFailed`
|
||||||
|
pub fn predict(msg: &str) -> Self {
|
||||||
|
Failed {
|
||||||
|
err: FailedError::PredictFailed,
|
||||||
|
msg: msg.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// new instance of `FailedError::TransformFailed`
|
||||||
|
pub fn transform(msg: &str) -> Self {
|
||||||
|
Failed {
|
||||||
|
err: FailedError::TransformFailed,
|
||||||
|
msg: msg.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// new instance of `err`
|
||||||
|
pub fn because(err: FailedError, msg: &str) -> Self {
|
||||||
|
Failed {
|
||||||
|
err: err,
|
||||||
|
msg: msg.to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for FitFailedError {
|
impl PartialEq for FailedError {
|
||||||
|
#[inline(always)]
|
||||||
|
fn eq(&self, rhs: &Self) -> bool {
|
||||||
|
*self as u8 == *rhs as u8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq for Failed {
|
||||||
|
#[inline(always)]
|
||||||
|
fn eq(&self, rhs: &Self) -> bool {
|
||||||
|
self.err == rhs.err && self.msg == rhs.msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for FailedError {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
write!(f, "{}", self.details)
|
let failed_err_str = match self {
|
||||||
|
FailedError::FitFailed => "Fit failed",
|
||||||
|
FailedError::PredictFailed => "Predict failed",
|
||||||
|
FailedError::TransformFailed => "Transform failed",
|
||||||
|
FailedError::FindFailed => "Find failed",
|
||||||
|
FailedError::DecompositionFailed => "Decomposition failed",
|
||||||
|
};
|
||||||
|
write!(f, "{}", failed_err_str)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Error for FitFailedError {
|
impl fmt::Display for Failed {
|
||||||
fn description(&self) -> &str {
|
|
||||||
&self.details
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PredictFailedError {
|
|
||||||
/// Creates new instance of `PredictFailedError`
|
|
||||||
/// * `msg` - description of the error
|
|
||||||
pub fn new(msg: &str) -> PredictFailedError {
|
|
||||||
PredictFailedError {
|
|
||||||
details: msg.to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for PredictFailedError {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
write!(f, "{}", self.details)
|
write!(f, "{}: {}", self.err, self.msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Error for PredictFailedError {
|
impl Error for Failed {}
|
||||||
fn description(&self) -> &str {
|
|
||||||
&self.details
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
+2
-2
@@ -58,10 +58,10 @@
|
|||||||
//! let y = vec![2., 2., 2., 3., 3.];
|
//! let y = vec![2., 2., 2., 3., 3.];
|
||||||
//!
|
//!
|
||||||
//! // Train classifier
|
//! // Train classifier
|
||||||
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
|
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||||
//!
|
//!
|
||||||
//! // Predict classes
|
//! // Predict classes
|
||||||
//! let y_hat = knn.predict(&x);
|
//! let y_hat = knn.predict(&x).unwrap();
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
/// Various algorithms and helper methods that are used elsewhere in SmartCore
|
/// Various algorithms and helper methods that are used elsewhere in SmartCore
|
||||||
|
|||||||
+8
-7
@@ -21,7 +21,7 @@
|
|||||||
//! &[0.7000, 0.3000, 0.8000],
|
//! &[0.7000, 0.3000, 0.8000],
|
||||||
//! ]);
|
//! ]);
|
||||||
//!
|
//!
|
||||||
//! let evd = A.evd(true);
|
//! let evd = A.evd(true).unwrap();
|
||||||
//! let eigenvectors: DenseMatrix<f64> = evd.V;
|
//! let eigenvectors: DenseMatrix<f64> = evd.V;
|
||||||
//! let eigenvalues: Vec<f64> = evd.d;
|
//! let eigenvalues: Vec<f64> = evd.d;
|
||||||
//! ```
|
//! ```
|
||||||
@@ -34,6 +34,7 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
#![allow(non_snake_case)]
|
#![allow(non_snake_case)]
|
||||||
|
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::linalg::BaseMatrix;
|
use crate::linalg::BaseMatrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use num::complex::Complex;
|
use num::complex::Complex;
|
||||||
@@ -54,14 +55,14 @@ pub struct EVD<T: RealNumber, M: BaseMatrix<T>> {
|
|||||||
pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||||
/// Compute the eigen decomposition of a square matrix.
|
/// Compute the eigen decomposition of a square matrix.
|
||||||
/// * `symmetric` - whether the matrix is symmetric
|
/// * `symmetric` - whether the matrix is symmetric
|
||||||
fn evd(&self, symmetric: bool) -> EVD<T, Self> {
|
fn evd(&self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
|
||||||
self.clone().evd_mut(symmetric)
|
self.clone().evd_mut(symmetric)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compute the eigen decomposition of a square matrix. The input matrix
|
/// Compute the eigen decomposition of a square matrix. The input matrix
|
||||||
/// will be used for factorization.
|
/// will be used for factorization.
|
||||||
/// * `symmetric` - whether the matrix is symmetric
|
/// * `symmetric` - whether the matrix is symmetric
|
||||||
fn evd_mut(mut self, symmetric: bool) -> EVD<T, Self> {
|
fn evd_mut(mut self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
|
||||||
let (nrows, ncols) = self.shape();
|
let (nrows, ncols) = self.shape();
|
||||||
if ncols != nrows {
|
if ncols != nrows {
|
||||||
panic!("Matrix is not square: {} x {}", nrows, ncols);
|
panic!("Matrix is not square: {} x {}", nrows, ncols);
|
||||||
@@ -92,7 +93,7 @@ pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
|||||||
sort(&mut d, &mut e, &mut V);
|
sort(&mut d, &mut e, &mut V);
|
||||||
}
|
}
|
||||||
|
|
||||||
EVD { V: V, d: d, e: e }
|
Ok(EVD { V: V, d: d, e: e })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -845,7 +846,7 @@ mod tests {
|
|||||||
&[0.6240573, -0.44947578, -0.6391588],
|
&[0.6240573, -0.44947578, -0.6391588],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let evd = A.evd(true);
|
let evd = A.evd(true).unwrap();
|
||||||
|
|
||||||
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
|
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
|
||||||
for i in 0..eigen_values.len() {
|
for i in 0..eigen_values.len() {
|
||||||
@@ -872,7 +873,7 @@ mod tests {
|
|||||||
&[0.6952105, 0.43984484, -0.7036135],
|
&[0.6952105, 0.43984484, -0.7036135],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let evd = A.evd(false);
|
let evd = A.evd(false).unwrap();
|
||||||
|
|
||||||
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
|
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
|
||||||
for i in 0..eigen_values.len() {
|
for i in 0..eigen_values.len() {
|
||||||
@@ -902,7 +903,7 @@ mod tests {
|
|||||||
&[0.6707, 0.1059, 0.901, -0.6289],
|
&[0.6707, 0.1059, 0.901, -0.6289],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let evd = A.evd(false);
|
let evd = A.evd(false).unwrap();
|
||||||
|
|
||||||
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
|
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
|
||||||
for i in 0..eigen_values_d.len() {
|
for i in 0..eigen_values_d.len() {
|
||||||
|
|||||||
+13
-13
@@ -20,7 +20,7 @@
|
|||||||
//! &[5., 6., 0.]
|
//! &[5., 6., 0.]
|
||||||
//! ]);
|
//! ]);
|
||||||
//!
|
//!
|
||||||
//! let lu = A.lu();
|
//! let lu = A.lu().unwrap();
|
||||||
//! let lower: DenseMatrix<f64> = lu.L();
|
//! let lower: DenseMatrix<f64> = lu.L();
|
||||||
//! let upper: DenseMatrix<f64> = lu.U();
|
//! let upper: DenseMatrix<f64> = lu.U();
|
||||||
//! ```
|
//! ```
|
||||||
@@ -36,6 +36,7 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::linalg::BaseMatrix;
|
use crate::linalg::BaseMatrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
@@ -121,7 +122,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns matrix inverse
|
/// Returns matrix inverse
|
||||||
pub fn inverse(&self) -> M {
|
pub fn inverse(&self) -> Result<M, Failed> {
|
||||||
let (m, n) = self.LU.shape();
|
let (m, n) = self.LU.shape();
|
||||||
|
|
||||||
if m != n {
|
if m != n {
|
||||||
@@ -134,11 +135,10 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
|||||||
inv.set(i, i, T::one());
|
inv.set(i, i, T::one());
|
||||||
}
|
}
|
||||||
|
|
||||||
inv = self.solve(inv);
|
self.solve(inv)
|
||||||
return inv;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn solve(&self, mut b: M) -> M {
|
fn solve(&self, mut b: M) -> Result<M, Failed> {
|
||||||
let (m, n) = self.LU.shape();
|
let (m, n) = self.LU.shape();
|
||||||
let (b_m, b_n) = b.shape();
|
let (b_m, b_n) = b.shape();
|
||||||
|
|
||||||
@@ -187,20 +187,20 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b
|
Ok(b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait that implements LU decomposition routine for any matrix.
|
/// Trait that implements LU decomposition routine for any matrix.
|
||||||
pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||||
/// Compute the LU decomposition of a square matrix.
|
/// Compute the LU decomposition of a square matrix.
|
||||||
fn lu(&self) -> LU<T, Self> {
|
fn lu(&self) -> Result<LU<T, Self>, Failed> {
|
||||||
self.clone().lu_mut()
|
self.clone().lu_mut()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compute the LU decomposition of a square matrix. The input matrix
|
/// Compute the LU decomposition of a square matrix. The input matrix
|
||||||
/// will be used for factorization.
|
/// will be used for factorization.
|
||||||
fn lu_mut(mut self) -> LU<T, Self> {
|
fn lu_mut(mut self) -> Result<LU<T, Self>, Failed> {
|
||||||
let (m, n) = self.shape();
|
let (m, n) = self.shape();
|
||||||
|
|
||||||
let mut piv = vec![0; m];
|
let mut piv = vec![0; m];
|
||||||
@@ -252,12 +252,12 @@ pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LU::new(self, piv, pivsign)
|
Ok(LU::new(self, piv, pivsign))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Solves Ax = b
|
/// Solves Ax = b
|
||||||
fn lu_solve_mut(self, b: Self) -> Self {
|
fn lu_solve_mut(self, b: Self) -> Result<Self, Failed> {
|
||||||
self.lu_mut().solve(b)
|
self.lu_mut().and_then(|lu| lu.solve(b))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -275,7 +275,7 @@ mod tests {
|
|||||||
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
|
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
|
||||||
let expected_pivot =
|
let expected_pivot =
|
||||||
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
|
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
|
||||||
let lu = a.lu();
|
let lu = a.lu().unwrap();
|
||||||
assert!(lu.L().approximate_eq(&expected_L, 1e-4));
|
assert!(lu.L().approximate_eq(&expected_L, 1e-4));
|
||||||
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
|
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
|
||||||
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
|
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
|
||||||
@@ -286,7 +286,7 @@ mod tests {
|
|||||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||||
let expected =
|
let expected =
|
||||||
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
|
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
|
||||||
let a_inv = a.lu().inverse();
|
let a_inv = a.lu().and_then(|lu| lu.inverse()).unwrap();
|
||||||
println!("{}", a_inv);
|
println!("{}", a_inv);
|
||||||
assert!(a_inv.approximate_eq(&expected, 1e-4));
|
assert!(a_inv.approximate_eq(&expected, 1e-4));
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-1
@@ -26,7 +26,7 @@
|
|||||||
//! &[0.7000, 0.3000, 0.8000],
|
//! &[0.7000, 0.3000, 0.8000],
|
||||||
//! ]);
|
//! ]);
|
||||||
//!
|
//!
|
||||||
//! let svd = A.svd();
|
//! let svd = A.svd().unwrap();
|
||||||
//!
|
//!
|
||||||
//! let s: Vec<f64> = svd.s;
|
//! let s: Vec<f64> = svd.s;
|
||||||
//! let v: DenseMatrix<f64> = svd.V;
|
//! let v: DenseMatrix<f64> = svd.V;
|
||||||
|
|||||||
@@ -34,8 +34,8 @@
|
|||||||
//! 116.9,
|
//! 116.9,
|
||||||
//! ]);
|
//! ]);
|
||||||
//!
|
//!
|
||||||
//! let lr = LinearRegression::fit(&x, &y, Default::default());
|
//! let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();
|
||||||
//! let y_hat = lr.predict(&x);
|
//! let y_hat = lr.predict(&x).unwrap();
|
||||||
//! ```
|
//! ```
|
||||||
use std::iter::Sum;
|
use std::iter::Sum;
|
||||||
use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign};
|
use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign};
|
||||||
@@ -777,9 +777,12 @@ mod tests {
|
|||||||
solver: LinearRegressionSolverName::QR,
|
solver: LinearRegressionSolverName::QR,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.predict(&x);
|
.and_then(|lr| lr.predict(&x))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x);
|
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default())
|
||||||
|
.and_then(|lr| lr.predict(&x))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(y
|
assert!(y
|
||||||
.iter()
|
.iter()
|
||||||
|
|||||||
@@ -36,8 +36,8 @@
|
|||||||
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.
|
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.
|
||||||
//! ]);
|
//! ]);
|
||||||
//!
|
//!
|
||||||
//! let lr = LogisticRegression::fit(&x, &y);
|
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||||
//! let y_hat = lr.predict(&x);
|
//! let y_hat = lr.predict(&x).unwrap();
|
||||||
//! ```
|
//! ```
|
||||||
use std::iter::Sum;
|
use std::iter::Sum;
|
||||||
use std::ops::AddAssign;
|
use std::ops::AddAssign;
|
||||||
@@ -395,6 +395,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::ensemble::random_forest_regressor::*;
|
use crate::ensemble::random_forest_regressor::*;
|
||||||
use crate::linear::logistic_regression::*;
|
use crate::linear::logistic_regression::*;
|
||||||
|
use crate::metrics::mean_absolute_error;
|
||||||
use ndarray::{arr1, arr2, Array1, Array2};
|
use ndarray::{arr1, arr2, Array1, Array2};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -736,9 +737,9 @@ mod tests {
|
|||||||
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let lr = LogisticRegression::fit(&x, &y);
|
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||||
|
|
||||||
let y_hat = lr.predict(&x);
|
let y_hat = lr.predict(&x).unwrap();
|
||||||
|
|
||||||
let error: f64 = y
|
let error: f64 = y
|
||||||
.into_iter()
|
.into_iter()
|
||||||
@@ -774,10 +775,6 @@ mod tests {
|
|||||||
114.2, 115.7, 116.9,
|
114.2, 115.7, 116.9,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let expected_y: Vec<f64> = vec![
|
|
||||||
85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.,
|
|
||||||
];
|
|
||||||
|
|
||||||
let y_hat = RandomForestRegressor::fit(
|
let y_hat = RandomForestRegressor::fit(
|
||||||
&x,
|
&x,
|
||||||
&y,
|
&y,
|
||||||
@@ -789,10 +786,10 @@ mod tests {
|
|||||||
m: Option::None,
|
m: Option::None,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.predict(&x);
|
.unwrap()
|
||||||
|
.predict(&x)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
for i in 0..y_hat.len() {
|
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
|
||||||
assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+14
-14
@@ -15,9 +15,9 @@
|
|||||||
//! &[0.7, 0.3, 0.8]
|
//! &[0.7, 0.3, 0.8]
|
||||||
//! ]);
|
//! ]);
|
||||||
//!
|
//!
|
||||||
//! let lu = A.qr();
|
//! let qr = A.qr().unwrap();
|
||||||
//! let orthogonal: DenseMatrix<f64> = lu.Q();
|
//! let orthogonal: DenseMatrix<f64> = qr.Q();
|
||||||
//! let triangular: DenseMatrix<f64> = lu.R();
|
//! let triangular: DenseMatrix<f64> = qr.R();
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! ## References:
|
//! ## References:
|
||||||
@@ -28,10 +28,10 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
#![allow(non_snake_case)]
|
#![allow(non_snake_case)]
|
||||||
|
|
||||||
use std::fmt::Debug;
|
use crate::error::Failed;
|
||||||
|
|
||||||
use crate::linalg::BaseMatrix;
|
use crate::linalg::BaseMatrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
use std::fmt::Debug;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
/// Results of QR decomposition.
|
/// Results of QR decomposition.
|
||||||
@@ -99,7 +99,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
|||||||
return Q;
|
return Q;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn solve(&self, mut b: M) -> M {
|
fn solve(&self, mut b: M) -> Result<M, Failed> {
|
||||||
let (m, n) = self.QR.shape();
|
let (m, n) = self.QR.shape();
|
||||||
let (b_nrows, b_ncols) = b.shape();
|
let (b_nrows, b_ncols) = b.shape();
|
||||||
|
|
||||||
@@ -139,20 +139,20 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b
|
Ok(b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait that implements QR decomposition routine for any matrix.
|
/// Trait that implements QR decomposition routine for any matrix.
|
||||||
pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||||
/// Compute the QR decomposition of a matrix.
|
/// Compute the QR decomposition of a matrix.
|
||||||
fn qr(&self) -> QR<T, Self> {
|
fn qr(&self) -> Result<QR<T, Self>, Failed> {
|
||||||
self.clone().qr_mut()
|
self.clone().qr_mut()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compute the QR decomposition of a matrix. The input matrix
|
/// Compute the QR decomposition of a matrix. The input matrix
|
||||||
/// will be used for factorization.
|
/// will be used for factorization.
|
||||||
fn qr_mut(mut self) -> QR<T, Self> {
|
fn qr_mut(mut self) -> Result<QR<T, Self>, Failed> {
|
||||||
let (m, n) = self.shape();
|
let (m, n) = self.shape();
|
||||||
|
|
||||||
let mut r_diagonal: Vec<T> = vec![T::zero(); n];
|
let mut r_diagonal: Vec<T> = vec![T::zero(); n];
|
||||||
@@ -186,12 +186,12 @@ pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
|||||||
r_diagonal[k] = -nrm;
|
r_diagonal[k] = -nrm;
|
||||||
}
|
}
|
||||||
|
|
||||||
QR::new(self, r_diagonal)
|
Ok(QR::new(self, r_diagonal))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Solves Ax = b
|
/// Solves Ax = b
|
||||||
fn qr_solve_mut(self, b: Self) -> Self {
|
fn qr_solve_mut(self, b: Self) -> Result<Self, Failed> {
|
||||||
self.qr_mut().solve(b)
|
self.qr_mut().and_then(|qr| qr.solve(b))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,7 +213,7 @@ mod tests {
|
|||||||
&[0.0, -0.3064, 0.0682],
|
&[0.0, -0.3064, 0.0682],
|
||||||
&[0.0, 0.0, -0.1999],
|
&[0.0, 0.0, -0.1999],
|
||||||
]);
|
]);
|
||||||
let qr = a.qr();
|
let qr = a.qr().unwrap();
|
||||||
assert!(qr.Q().abs().approximate_eq(&q.abs(), 1e-4));
|
assert!(qr.Q().abs().approximate_eq(&q.abs(), 1e-4));
|
||||||
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
|
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
|
||||||
}
|
}
|
||||||
@@ -227,7 +227,7 @@ mod tests {
|
|||||||
&[0.8783784, 2.2297297],
|
&[0.8783784, 2.2297297],
|
||||||
&[0.4729730, 0.6621622],
|
&[0.4729730, 0.6621622],
|
||||||
]);
|
]);
|
||||||
let w = a.qr_solve_mut(b);
|
let w = a.qr_solve_mut(b).unwrap();
|
||||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+14
-13
@@ -19,7 +19,7 @@
|
|||||||
//! &[0.7, 0.3, 0.8]
|
//! &[0.7, 0.3, 0.8]
|
||||||
//! ]);
|
//! ]);
|
||||||
//!
|
//!
|
||||||
//! let svd = A.svd();
|
//! let svd = A.svd().unwrap();
|
||||||
//! let u: DenseMatrix<f64> = svd.U;
|
//! let u: DenseMatrix<f64> = svd.U;
|
||||||
//! let v: DenseMatrix<f64> = svd.V;
|
//! let v: DenseMatrix<f64> = svd.V;
|
||||||
//! let s: Vec<f64> = svd.s;
|
//! let s: Vec<f64> = svd.s;
|
||||||
@@ -33,6 +33,7 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
#![allow(non_snake_case)]
|
#![allow(non_snake_case)]
|
||||||
|
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::linalg::BaseMatrix;
|
use crate::linalg::BaseMatrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
@@ -55,23 +56,23 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
|
|||||||
/// Trait that implements SVD decomposition routine for any matrix.
|
/// Trait that implements SVD decomposition routine for any matrix.
|
||||||
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||||
/// Solves Ax = b. Overrides original matrix in the process.
|
/// Solves Ax = b. Overrides original matrix in the process.
|
||||||
fn svd_solve_mut(self, b: Self) -> Self {
|
fn svd_solve_mut(self, b: Self) -> Result<Self, Failed> {
|
||||||
self.svd_mut().solve(b)
|
self.svd_mut().and_then(|svd| svd.solve(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Solves Ax = b
|
/// Solves Ax = b
|
||||||
fn svd_solve(&self, b: Self) -> Self {
|
fn svd_solve(&self, b: Self) -> Result<Self, Failed> {
|
||||||
self.svd().solve(b)
|
self.svd().and_then(|svd| svd.solve(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compute the SVD decomposition of a matrix.
|
/// Compute the SVD decomposition of a matrix.
|
||||||
fn svd(&self) -> SVD<T, Self> {
|
fn svd(&self) -> Result<SVD<T, Self>, Failed> {
|
||||||
self.clone().svd_mut()
|
self.clone().svd_mut()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compute the SVD decomposition of a matrix. The input matrix
|
/// Compute the SVD decomposition of a matrix. The input matrix
|
||||||
/// will be used for factorization.
|
/// will be used for factorization.
|
||||||
fn svd_mut(self) -> SVD<T, Self> {
|
fn svd_mut(self) -> Result<SVD<T, Self>, Failed> {
|
||||||
let mut U = self;
|
let mut U = self;
|
||||||
|
|
||||||
let (m, n) = U.shape();
|
let (m, n) = U.shape();
|
||||||
@@ -406,7 +407,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SVD::new(U, v, w)
|
Ok(SVD::new(U, v, w))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -427,7 +428,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn solve(&self, mut b: M) -> M {
|
pub(crate) fn solve(&self, mut b: M) -> Result<M, Failed> {
|
||||||
let p = b.shape().1;
|
let p = b.shape().1;
|
||||||
|
|
||||||
if self.U.shape().0 != b.shape().0 {
|
if self.U.shape().0 != b.shape().0 {
|
||||||
@@ -460,7 +461,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b
|
Ok(b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -491,7 +492,7 @@ mod tests {
|
|||||||
&[0.6240573, -0.44947578, -0.6391588],
|
&[0.6240573, -0.44947578, -0.6391588],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let svd = A.svd();
|
let svd = A.svd().unwrap();
|
||||||
|
|
||||||
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
|
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
|
||||||
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
|
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
|
||||||
@@ -692,7 +693,7 @@ mod tests {
|
|||||||
],
|
],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let svd = A.svd();
|
let svd = A.svd().unwrap();
|
||||||
|
|
||||||
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
|
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
|
||||||
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
|
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
|
||||||
@@ -707,7 +708,7 @@ mod tests {
|
|||||||
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
|
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
|
||||||
let expected_w =
|
let expected_w =
|
||||||
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
|
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
|
||||||
let w = a.svd_solve_mut(b);
|
let w = a.svd_solve_mut(b).unwrap();
|
||||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,9 +47,9 @@
|
|||||||
//!
|
//!
|
||||||
//! let lr = LinearRegression::fit(&x, &y, LinearRegressionParameters {
|
//! let lr = LinearRegression::fit(&x, &y, LinearRegressionParameters {
|
||||||
//! solver: LinearRegressionSolverName::QR, // or SVD
|
//! solver: LinearRegressionSolverName::QR, // or SVD
|
||||||
//! });
|
//! }).unwrap();
|
||||||
//!
|
//!
|
||||||
//! let y_hat = lr.predict(&x);
|
//! let y_hat = lr.predict(&x).unwrap();
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! ## References:
|
//! ## References:
|
||||||
@@ -64,6 +64,7 @@ use std::fmt::Debug;
|
|||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
@@ -115,39 +116,41 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
|
|||||||
x: &M,
|
x: &M,
|
||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
parameters: LinearRegressionParameters,
|
parameters: LinearRegressionParameters,
|
||||||
) -> LinearRegression<T, M> {
|
) -> Result<LinearRegression<T, M>, Failed> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let b = y_m.transpose();
|
let b = y_m.transpose();
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let (y_nrows, _) = b.shape();
|
let (y_nrows, _) = b.shape();
|
||||||
|
|
||||||
if x_nrows != y_nrows {
|
if x_nrows != y_nrows {
|
||||||
panic!("Number of rows of X doesn't match number of rows of Y");
|
return Err(Failed::fit(&format!(
|
||||||
|
"Number of rows of X doesn't match number of rows of Y"
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let a = x.h_stack(&M::ones(x_nrows, 1));
|
let a = x.h_stack(&M::ones(x_nrows, 1));
|
||||||
|
|
||||||
let w = match parameters.solver {
|
let w = match parameters.solver {
|
||||||
LinearRegressionSolverName::QR => a.qr_solve_mut(b),
|
LinearRegressionSolverName::QR => a.qr_solve_mut(b)?,
|
||||||
LinearRegressionSolverName::SVD => a.svd_solve_mut(b),
|
LinearRegressionSolverName::SVD => a.svd_solve_mut(b)?,
|
||||||
};
|
};
|
||||||
|
|
||||||
let wights = w.slice(0..num_attributes, 0..1);
|
let wights = w.slice(0..num_attributes, 0..1);
|
||||||
|
|
||||||
LinearRegression {
|
Ok(LinearRegression {
|
||||||
intercept: w.get(num_attributes, 0),
|
intercept: w.get(num_attributes, 0),
|
||||||
coefficients: wights,
|
coefficients: wights,
|
||||||
solver: parameters.solver,
|
solver: parameters.solver,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict target values from `x`
|
/// Predict target values from `x`
|
||||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
pub fn predict(&self, x: &M) -> M::RowVector {
|
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
let (nrows, _) = x.shape();
|
let (nrows, _) = x.shape();
|
||||||
let mut y_hat = x.matmul(&self.coefficients);
|
let mut y_hat = x.matmul(&self.coefficients);
|
||||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||||
y_hat.transpose().to_row_vector()
|
Ok(y_hat.transpose().to_row_vector())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get estimates regression coefficients
|
/// Get estimates regression coefficients
|
||||||
@@ -199,9 +202,12 @@ mod tests {
|
|||||||
solver: LinearRegressionSolverName::QR,
|
solver: LinearRegressionSolverName::QR,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.predict(&x);
|
.and_then(|lr| lr.predict(&x))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x);
|
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default())
|
||||||
|
.and_then(|lr| lr.predict(&x))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(y
|
assert!(y
|
||||||
.iter()
|
.iter()
|
||||||
@@ -239,7 +245,7 @@ mod tests {
|
|||||||
114.2, 115.7, 116.9,
|
114.2, 115.7, 116.9,
|
||||||
];
|
];
|
||||||
|
|
||||||
let lr = LinearRegression::fit(&x, &y, Default::default());
|
let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> =
|
let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> =
|
||||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||||
|
|||||||
@@ -40,9 +40,9 @@
|
|||||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
//! ];
|
//! ];
|
||||||
//!
|
//!
|
||||||
//! let lr = LogisticRegression::fit(&x, &y);
|
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||||
//!
|
//!
|
||||||
//! let y_hat = lr.predict(&x);
|
//! let y_hat = lr.predict(&x).unwrap();
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! ## References:
|
//! ## References:
|
||||||
@@ -57,6 +57,7 @@ use std::marker::PhantomData;
|
|||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::optimization::first_order::lbfgs::LBFGS;
|
use crate::optimization::first_order::lbfgs::LBFGS;
|
||||||
@@ -208,13 +209,15 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
/// Fits Logistic Regression to your data.
|
/// Fits Logistic Regression to your data.
|
||||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||||
/// * `y` - target class values
|
/// * `y` - target class values
|
||||||
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M> {
|
pub fn fit(x: &M, y: &M::RowVector) -> Result<LogisticRegression<T, M>, Failed> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let (_, y_nrows) = y_m.shape();
|
let (_, y_nrows) = y_m.shape();
|
||||||
|
|
||||||
if x_nrows != y_nrows {
|
if x_nrows != y_nrows {
|
||||||
panic!("Number of rows of X doesn't match number of rows of Y");
|
return Err(Failed::fit(&format!(
|
||||||
|
"Number of rows of X doesn't match number of rows of Y"
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let classes = y_m.unique();
|
let classes = y_m.unique();
|
||||||
@@ -229,7 +232,10 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if k < 2 {
|
if k < 2 {
|
||||||
panic!("Incorrect number of classes: {}", k);
|
Err(Failed::fit(&format!(
|
||||||
|
"incorrect number of classes: {}. Should be >= 2.",
|
||||||
|
k
|
||||||
|
)))
|
||||||
} else if k == 2 {
|
} else if k == 2 {
|
||||||
let x0 = M::zeros(1, num_attributes + 1);
|
let x0 = M::zeros(1, num_attributes + 1);
|
||||||
|
|
||||||
@@ -241,12 +247,12 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
|
|
||||||
let result = LogisticRegression::minimize(x0, objective);
|
let result = LogisticRegression::minimize(x0, objective);
|
||||||
|
|
||||||
LogisticRegression {
|
Ok(LogisticRegression {
|
||||||
weights: result.x,
|
weights: result.x,
|
||||||
classes: classes,
|
classes: classes,
|
||||||
num_attributes: num_attributes,
|
num_attributes: num_attributes,
|
||||||
num_classes: k,
|
num_classes: k,
|
||||||
}
|
})
|
||||||
} else {
|
} else {
|
||||||
let x0 = M::zeros(1, (num_attributes + 1) * k);
|
let x0 = M::zeros(1, (num_attributes + 1) * k);
|
||||||
|
|
||||||
@@ -261,18 +267,18 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
|
|
||||||
let weights = result.x.reshape(k, num_attributes + 1);
|
let weights = result.x.reshape(k, num_attributes + 1);
|
||||||
|
|
||||||
LogisticRegression {
|
Ok(LogisticRegression {
|
||||||
weights: weights,
|
weights: weights,
|
||||||
classes: classes,
|
classes: classes,
|
||||||
num_attributes: num_attributes,
|
num_attributes: num_attributes,
|
||||||
num_classes: k,
|
num_classes: k,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict class labels for samples in `x`.
|
/// Predict class labels for samples in `x`.
|
||||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
pub fn predict(&self, x: &M) -> M::RowVector {
|
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
let n = x.shape().0;
|
let n = x.shape().0;
|
||||||
let mut result = M::zeros(1, n);
|
let mut result = M::zeros(1, n);
|
||||||
if self.num_classes == 2 {
|
if self.num_classes == 2 {
|
||||||
@@ -297,7 +303,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
result.set(0, i, self.classes[class_idxs[i]]);
|
result.set(0, i, self.classes[class_idxs[i]]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result.to_row_vector()
|
Ok(result.to_row_vector())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get estimates regression coefficients
|
/// Get estimates regression coefficients
|
||||||
@@ -444,7 +450,7 @@ mod tests {
|
|||||||
]);
|
]);
|
||||||
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||||
|
|
||||||
let lr = LogisticRegression::fit(&x, &y);
|
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||||
|
|
||||||
assert_eq!(lr.coefficients().shape(), (3, 2));
|
assert_eq!(lr.coefficients().shape(), (3, 2));
|
||||||
assert_eq!(lr.intercept().shape(), (3, 1));
|
assert_eq!(lr.intercept().shape(), (3, 1));
|
||||||
@@ -452,7 +458,7 @@ mod tests {
|
|||||||
assert!((lr.coefficients().get(0, 0) - 0.0435).abs() < 1e-4);
|
assert!((lr.coefficients().get(0, 0) - 0.0435).abs() < 1e-4);
|
||||||
assert!((lr.intercept().get(0, 0) - 0.1250).abs() < 1e-4);
|
assert!((lr.intercept().get(0, 0) - 0.1250).abs() < 1e-4);
|
||||||
|
|
||||||
let y_hat = lr.predict(&x);
|
let y_hat = lr.predict(&x).unwrap();
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
y_hat,
|
y_hat,
|
||||||
@@ -481,7 +487,7 @@ mod tests {
|
|||||||
]);
|
]);
|
||||||
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||||
|
|
||||||
let lr = LogisticRegression::fit(&x, &y);
|
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||||
|
|
||||||
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> =
|
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> =
|
||||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||||
@@ -517,9 +523,9 @@ mod tests {
|
|||||||
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
];
|
];
|
||||||
|
|
||||||
let lr = LogisticRegression::fit(&x, &y);
|
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||||
|
|
||||||
let y_hat = lr.predict(&x);
|
let y_hat = lr.predict(&x).unwrap();
|
||||||
|
|
||||||
let error: f64 = y
|
let error: f64 = y
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ impl<T: RealNumber, M: Matrix<T>> Mahalanobis<T, M> {
|
|||||||
/// * `data` - a matrix of _NxM_ where _N_ is number of observations and _M_ is number of attributes
|
/// * `data` - a matrix of _NxM_ where _N_ is number of observations and _M_ is number of attributes
|
||||||
pub fn new(data: &M) -> Mahalanobis<T, M> {
|
pub fn new(data: &M) -> Mahalanobis<T, M> {
|
||||||
let sigma = data.cov();
|
let sigma = data.cov();
|
||||||
let sigmaInv = sigma.lu().inverse();
|
let sigmaInv = sigma.lu().and_then(|lu| lu.inverse()).unwrap();
|
||||||
Mahalanobis {
|
Mahalanobis {
|
||||||
sigma: sigma,
|
sigma: sigma,
|
||||||
sigmaInv: sigmaInv,
|
sigmaInv: sigmaInv,
|
||||||
@@ -78,7 +78,7 @@ impl<T: RealNumber, M: Matrix<T>> Mahalanobis<T, M> {
|
|||||||
/// * `cov` - a covariance matrix
|
/// * `cov` - a covariance matrix
|
||||||
pub fn new_from_covariance(cov: &M) -> Mahalanobis<T, M> {
|
pub fn new_from_covariance(cov: &M) -> Mahalanobis<T, M> {
|
||||||
let sigma = cov.clone();
|
let sigma = cov.clone();
|
||||||
let sigmaInv = sigma.lu().inverse();
|
let sigmaInv = sigma.lu().and_then(|lu| lu.inverse()).unwrap();
|
||||||
Mahalanobis {
|
Mahalanobis {
|
||||||
sigma: sigma,
|
sigma: sigma,
|
||||||
sigmaInv: sigmaInv,
|
sigmaInv: sigmaInv,
|
||||||
|
|||||||
+2
-2
@@ -41,9 +41,9 @@
|
|||||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
//! ];
|
//! ];
|
||||||
//!
|
//!
|
||||||
//! let lr = LogisticRegression::fit(&x, &y);
|
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||||
//!
|
//!
|
||||||
//! let y_hat = lr.predict(&x);
|
//! let y_hat = lr.predict(&x).unwrap();
|
||||||
//!
|
//!
|
||||||
//! let acc = ClassificationMetrics::accuracy().get_score(&y, &y_hat);
|
//! let acc = ClassificationMetrics::accuracy().get_score(&y, &y_hat);
|
||||||
//! // or
|
//! // or
|
||||||
|
|||||||
@@ -25,8 +25,8 @@
|
|||||||
//! &[9., 10.]]);
|
//! &[9., 10.]]);
|
||||||
//! let y = vec![2., 2., 2., 3., 3.]; //your class labels
|
//! let y = vec![2., 2., 2., 3., 3.]; //your class labels
|
||||||
//!
|
//!
|
||||||
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
|
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||||
//! let y_hat = knn.predict(&x);
|
//! let y_hat = knn.predict(&x).unwrap();
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! variable `y_hat` will hold a vector with estimates of class labels
|
//! variable `y_hat` will hold a vector with estimates of class labels
|
||||||
@@ -34,6 +34,7 @@
|
|||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::linalg::{row_iter, Matrix};
|
use crate::linalg::{row_iter, Matrix};
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
@@ -106,7 +107,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
|||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
distance: D,
|
distance: D,
|
||||||
parameters: KNNClassifierParameters,
|
parameters: KNNClassifierParameters,
|
||||||
) -> KNNClassifier<T, D> {
|
) -> Result<KNNClassifier<T, D>, Failed> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
|
|
||||||
let (_, y_n) = y_m.shape();
|
let (_, y_n) = y_m.shape();
|
||||||
@@ -122,43 +123,44 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
|||||||
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
assert!(
|
if x_n != y_n {
|
||||||
x_n == y_n,
|
return Err(Failed::fit(&format!(
|
||||||
format!(
|
|
||||||
"Size of x should equal size of y; |x|=[{}], |y|=[{}]",
|
"Size of x should equal size of y; |x|=[{}], |y|=[{}]",
|
||||||
x_n, y_n
|
x_n, y_n
|
||||||
)
|
)));
|
||||||
);
|
}
|
||||||
|
|
||||||
assert!(
|
if parameters.k <= 1 {
|
||||||
parameters.k > 1,
|
return Err(Failed::fit(&format!(
|
||||||
format!("k should be > 1, k=[{}]", parameters.k)
|
"k should be > 1, k=[{}]",
|
||||||
);
|
parameters.k
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
KNNClassifier {
|
Ok(KNNClassifier {
|
||||||
classes: classes,
|
classes: classes,
|
||||||
y: yi,
|
y: yi,
|
||||||
k: parameters.k,
|
k: parameters.k,
|
||||||
knn_algorithm: parameters.algorithm.fit(data, distance),
|
knn_algorithm: parameters.algorithm.fit(data, distance)?,
|
||||||
weight: parameters.weight,
|
weight: parameters.weight,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Estimates the class labels for the provided data.
|
/// Estimates the class labels for the provided data.
|
||||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||||
/// Returns a vector of size N with class estimates.
|
/// Returns a vector of size N with class estimates.
|
||||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
let mut result = M::zeros(1, x.shape().0);
|
let mut result = M::zeros(1, x.shape().0);
|
||||||
|
|
||||||
row_iter(x)
|
for (i, x) in row_iter(x).enumerate() {
|
||||||
.enumerate()
|
result.set(0, i, self.classes[self.predict_for_row(x)?]);
|
||||||
.for_each(|(i, x)| result.set(0, i, self.classes[self.predict_for_row(x)]));
|
}
|
||||||
|
|
||||||
result.to_row_vector()
|
Ok(result.to_row_vector())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn predict_for_row(&self, x: Vec<T>) -> usize {
|
fn predict_for_row(&self, x: Vec<T>) -> Result<usize, Failed> {
|
||||||
let search_result = self.knn_algorithm.find(&x, self.k);
|
let search_result = self.knn_algorithm.find(&x, self.k)?;
|
||||||
|
|
||||||
let weights = self
|
let weights = self
|
||||||
.weight
|
.weight
|
||||||
@@ -176,7 +178,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
max_i
|
Ok(max_i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -191,8 +193,8 @@ mod tests {
|
|||||||
let x =
|
let x =
|
||||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
let y = vec![2., 2., 2., 3., 3.];
|
let y = vec![2., 2., 2., 3., 3.];
|
||||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
|
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||||
let y_hat = knn.predict(&x);
|
let y_hat = knn.predict(&x).unwrap();
|
||||||
assert_eq!(5, Vec::len(&y_hat));
|
assert_eq!(5, Vec::len(&y_hat));
|
||||||
assert_eq!(y.to_vec(), y_hat);
|
assert_eq!(y.to_vec(), y_hat);
|
||||||
}
|
}
|
||||||
@@ -210,8 +212,9 @@ mod tests {
|
|||||||
algorithm: KNNAlgorithmName::LinearSearch,
|
algorithm: KNNAlgorithmName::LinearSearch,
|
||||||
weight: KNNWeightFunction::Distance,
|
weight: KNNWeightFunction::Distance,
|
||||||
},
|
},
|
||||||
);
|
)
|
||||||
let y_hat = knn.predict(&DenseMatrix::from_2d_array(&[&[4.1]]));
|
.unwrap();
|
||||||
|
let y_hat = knn.predict(&DenseMatrix::from_2d_array(&[&[4.1]])).unwrap();
|
||||||
assert_eq!(vec![3.0], y_hat);
|
assert_eq!(vec![3.0], y_hat);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -221,7 +224,7 @@ mod tests {
|
|||||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
let y = vec![2., 2., 2., 3., 3.];
|
let y = vec![2., 2., 2., 3., 3.];
|
||||||
|
|
||||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
|
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||||
|
|
||||||
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||||
|
|
||||||
|
|||||||
@@ -27,8 +27,8 @@
|
|||||||
//! &[5., 5.]]);
|
//! &[5., 5.]]);
|
||||||
//! let y = vec![1., 2., 3., 4., 5.]; //your target values
|
//! let y = vec![1., 2., 3., 4., 5.]; //your target values
|
||||||
//!
|
//!
|
||||||
//! let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default());
|
//! let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||||
//! let y_hat = knn.predict(&x);
|
//! let y_hat = knn.predict(&x).unwrap();
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! variable `y_hat` will hold predicted value
|
//! variable `y_hat` will hold predicted value
|
||||||
@@ -36,6 +36,7 @@
|
|||||||
//!
|
//!
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::linalg::{row_iter, BaseVector, Matrix};
|
use crate::linalg::{row_iter, BaseVector, Matrix};
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
@@ -99,7 +100,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
|||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
distance: D,
|
distance: D,
|
||||||
parameters: KNNRegressorParameters,
|
parameters: KNNRegressorParameters,
|
||||||
) -> KNNRegressor<T, D> {
|
) -> Result<KNNRegressor<T, D>, Failed> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
|
|
||||||
let (_, y_n) = y_m.shape();
|
let (_, y_n) = y_m.shape();
|
||||||
@@ -107,42 +108,43 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
|||||||
|
|
||||||
let data = row_iter(x).collect();
|
let data = row_iter(x).collect();
|
||||||
|
|
||||||
assert!(
|
if x_n != y_n {
|
||||||
x_n == y_n,
|
return Err(Failed::fit(&format!(
|
||||||
format!(
|
|
||||||
"Size of x should equal size of y; |x|=[{}], |y|=[{}]",
|
"Size of x should equal size of y; |x|=[{}], |y|=[{}]",
|
||||||
x_n, y_n
|
x_n, y_n
|
||||||
)
|
)));
|
||||||
);
|
}
|
||||||
|
|
||||||
assert!(
|
if parameters.k <= 1 {
|
||||||
parameters.k > 1,
|
return Err(Failed::fit(&format!(
|
||||||
format!("k should be > 1, k=[{}]", parameters.k)
|
"k should be > 1, k=[{}]",
|
||||||
);
|
parameters.k
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
KNNRegressor {
|
Ok(KNNRegressor {
|
||||||
y: y.to_vec(),
|
y: y.to_vec(),
|
||||||
k: parameters.k,
|
k: parameters.k,
|
||||||
knn_algorithm: parameters.algorithm.fit(data, distance),
|
knn_algorithm: parameters.algorithm.fit(data, distance)?,
|
||||||
weight: parameters.weight,
|
weight: parameters.weight,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict the target for the provided data.
|
/// Predict the target for the provided data.
|
||||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||||
/// Returns a vector of size N with estimates.
|
/// Returns a vector of size N with estimates.
|
||||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
let mut result = M::zeros(1, x.shape().0);
|
let mut result = M::zeros(1, x.shape().0);
|
||||||
|
|
||||||
row_iter(x)
|
for (i, x) in row_iter(x).enumerate() {
|
||||||
.enumerate()
|
result.set(0, i, self.predict_for_row(x)?);
|
||||||
.for_each(|(i, x)| result.set(0, i, self.predict_for_row(x)));
|
}
|
||||||
|
|
||||||
result.to_row_vector()
|
Ok(result.to_row_vector())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn predict_for_row(&self, x: Vec<T>) -> T {
|
fn predict_for_row(&self, x: Vec<T>) -> Result<T, Failed> {
|
||||||
let search_result = self.knn_algorithm.find(&x, self.k);
|
let search_result = self.knn_algorithm.find(&x, self.k)?;
|
||||||
let mut result = T::zero();
|
let mut result = T::zero();
|
||||||
|
|
||||||
let weights = self
|
let weights = self
|
||||||
@@ -154,7 +156,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
|||||||
result = result + self.y[r.0] * (*w / w_sum);
|
result = result + self.y[r.0] * (*w / w_sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
result
|
Ok(result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -179,8 +181,9 @@ mod tests {
|
|||||||
algorithm: KNNAlgorithmName::LinearSearch,
|
algorithm: KNNAlgorithmName::LinearSearch,
|
||||||
weight: KNNWeightFunction::Distance,
|
weight: KNNWeightFunction::Distance,
|
||||||
},
|
},
|
||||||
);
|
)
|
||||||
let y_hat = knn.predict(&x);
|
.unwrap();
|
||||||
|
let y_hat = knn.predict(&x).unwrap();
|
||||||
assert_eq!(5, Vec::len(&y_hat));
|
assert_eq!(5, Vec::len(&y_hat));
|
||||||
for i in 0..y_hat.len() {
|
for i in 0..y_hat.len() {
|
||||||
assert!((y_hat[i] - y_exp[i]).abs() < std::f64::EPSILON);
|
assert!((y_hat[i] - y_exp[i]).abs() < std::f64::EPSILON);
|
||||||
@@ -193,8 +196,8 @@ mod tests {
|
|||||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
||||||
let y_exp = vec![2., 2., 3., 4., 4.];
|
let y_exp = vec![2., 2., 3., 4., 4.];
|
||||||
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default());
|
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||||
let y_hat = knn.predict(&x);
|
let y_hat = knn.predict(&x).unwrap();
|
||||||
assert_eq!(5, Vec::len(&y_hat));
|
assert_eq!(5, Vec::len(&y_hat));
|
||||||
for i in 0..y_hat.len() {
|
for i in 0..y_hat.len() {
|
||||||
assert!((y_hat[i] - y_exp[i]).abs() < 1e-7);
|
assert!((y_hat[i] - y_exp[i]).abs() < 1e-7);
|
||||||
@@ -207,7 +210,7 @@ mod tests {
|
|||||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
let y = vec![1., 2., 3., 4., 5.];
|
let y = vec![1., 2., 3., 4., 5.];
|
||||||
|
|
||||||
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default());
|
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||||
|
|
||||||
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,7 @@
|
|||||||
|
|
||||||
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
||||||
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@@ -93,18 +94,20 @@ impl KNNAlgorithmName {
|
|||||||
&self,
|
&self,
|
||||||
data: Vec<Vec<T>>,
|
data: Vec<Vec<T>>,
|
||||||
distance: D,
|
distance: D,
|
||||||
) -> KNNAlgorithm<T, D> {
|
) -> Result<KNNAlgorithm<T, D>, Failed> {
|
||||||
match *self {
|
match *self {
|
||||||
KNNAlgorithmName::LinearSearch => {
|
KNNAlgorithmName::LinearSearch => {
|
||||||
KNNAlgorithm::LinearSearch(LinearKNNSearch::new(data, distance))
|
LinearKNNSearch::new(data, distance).map(|a| KNNAlgorithm::LinearSearch(a))
|
||||||
|
}
|
||||||
|
KNNAlgorithmName::CoverTree => {
|
||||||
|
CoverTree::new(data, distance).map(|a| KNNAlgorithm::CoverTree(a))
|
||||||
}
|
}
|
||||||
KNNAlgorithmName::CoverTree => KNNAlgorithm::CoverTree(CoverTree::new(data, distance)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
|
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
|
||||||
fn find(&self, from: &Vec<T>, k: usize) -> Vec<(usize, T)> {
|
fn find(&self, from: &Vec<T>, k: usize) -> Result<Vec<(usize, T)>, Failed> {
|
||||||
match *self {
|
match *self {
|
||||||
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
|
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
|
||||||
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
|
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
|
||||||
|
|||||||
@@ -50,9 +50,9 @@
|
|||||||
//! let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0.,
|
//! let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||||
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
||||||
//!
|
//!
|
||||||
//! let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
|
//! let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||||
//!
|
//!
|
||||||
//! let y_hat = tree.predict(&x); // use the same data for prediction
|
//! let y_hat = tree.predict(&x).unwrap(); // use the same data for prediction
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//!
|
//!
|
||||||
@@ -71,6 +71,7 @@ use rand::seq::SliceRandom;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
@@ -276,7 +277,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
x: &M,
|
x: &M,
|
||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
parameters: DecisionTreeClassifierParameters,
|
parameters: DecisionTreeClassifierParameters,
|
||||||
) -> DecisionTreeClassifier<T> {
|
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let samples = vec![1; x_nrows];
|
let samples = vec![1; x_nrows];
|
||||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||||
@@ -288,14 +289,17 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
samples: Vec<usize>,
|
samples: Vec<usize>,
|
||||||
mtry: usize,
|
mtry: usize,
|
||||||
parameters: DecisionTreeClassifierParameters,
|
parameters: DecisionTreeClassifierParameters,
|
||||||
) -> DecisionTreeClassifier<T> {
|
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let (_, y_ncols) = y_m.shape();
|
let (_, y_ncols) = y_m.shape();
|
||||||
let (_, num_attributes) = x.shape();
|
let (_, num_attributes) = x.shape();
|
||||||
let classes = y_m.unique();
|
let classes = y_m.unique();
|
||||||
let k = classes.len();
|
let k = classes.len();
|
||||||
if k < 2 {
|
if k < 2 {
|
||||||
panic!("Incorrect number of classes: {}. Should be >= 2.", k);
|
return Err(Failed::fit(&format!(
|
||||||
|
"Incorrect number of classes: {}. Should be >= 2.",
|
||||||
|
k
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut yi: Vec<usize> = vec![0; y_ncols];
|
let mut yi: Vec<usize> = vec![0; y_ncols];
|
||||||
@@ -343,12 +347,12 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
tree
|
Ok(tree)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict class value for `x`.
|
/// Predict class value for `x`.
|
||||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
let mut result = M::zeros(1, x.shape().0);
|
let mut result = M::zeros(1, x.shape().0);
|
||||||
|
|
||||||
let (n, _) = x.shape();
|
let (n, _) = x.shape();
|
||||||
@@ -357,7 +361,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
result.set(0, i, self.classes[self.predict_for_row(x, i)]);
|
result.set(0, i, self.classes[self.predict_for_row(x, i)]);
|
||||||
}
|
}
|
||||||
|
|
||||||
result.to_row_vector()
|
Ok(result.to_row_vector())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||||
@@ -637,7 +641,9 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
y,
|
y,
|
||||||
DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x)
|
DecisionTreeClassifier::fit(&x, &y, Default::default())
|
||||||
|
.and_then(|t| t.predict(&x))
|
||||||
|
.unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -652,6 +658,7 @@ mod tests {
|
|||||||
min_samples_split: 2
|
min_samples_split: 2
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
.unwrap()
|
||||||
.depth
|
.depth
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -686,7 +693,9 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
y,
|
y,
|
||||||
DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x)
|
DecisionTreeClassifier::fit(&x, &y, Default::default())
|
||||||
|
.and_then(|t| t.predict(&x))
|
||||||
|
.unwrap()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -718,7 +727,7 @@ mod tests {
|
|||||||
1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
|
1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
|
||||||
];
|
];
|
||||||
|
|
||||||
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
|
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
let deserialized_tree: DecisionTreeClassifier<f64> =
|
let deserialized_tree: DecisionTreeClassifier<f64> =
|
||||||
bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
||||||
|
|||||||
@@ -45,9 +45,9 @@
|
|||||||
//! 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9,
|
//! 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9,
|
||||||
//! ];
|
//! ];
|
||||||
//!
|
//!
|
||||||
//! let tree = DecisionTreeRegressor::fit(&x, &y, Default::default());
|
//! let tree = DecisionTreeRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||||
//!
|
//!
|
||||||
//! let y_hat = tree.predict(&x); // use the same data for prediction
|
//! let y_hat = tree.predict(&x).unwrap(); // use the same data for prediction
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! ## References:
|
//! ## References:
|
||||||
@@ -66,6 +66,7 @@ use rand::seq::SliceRandom;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
@@ -196,7 +197,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
x: &M,
|
x: &M,
|
||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
parameters: DecisionTreeRegressorParameters,
|
parameters: DecisionTreeRegressorParameters,
|
||||||
) -> DecisionTreeRegressor<T> {
|
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let samples = vec![1; x_nrows];
|
let samples = vec![1; x_nrows];
|
||||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||||
@@ -208,16 +209,11 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
samples: Vec<usize>,
|
samples: Vec<usize>,
|
||||||
mtry: usize,
|
mtry: usize,
|
||||||
parameters: DecisionTreeRegressorParameters,
|
parameters: DecisionTreeRegressorParameters,
|
||||||
) -> DecisionTreeRegressor<T> {
|
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
|
|
||||||
let (_, y_ncols) = y_m.shape();
|
let (_, y_ncols) = y_m.shape();
|
||||||
let (_, num_attributes) = x.shape();
|
let (_, num_attributes) = x.shape();
|
||||||
let classes = y_m.unique();
|
|
||||||
let k = classes.len();
|
|
||||||
if k < 2 {
|
|
||||||
panic!("Incorrect number of classes: {}. Should be >= 2.", k);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut nodes: Vec<Node<T>> = Vec::new();
|
let mut nodes: Vec<Node<T>> = Vec::new();
|
||||||
|
|
||||||
@@ -257,12 +253,12 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
tree
|
Ok(tree)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict regression value for `x`.
|
/// Predict regression value for `x`.
|
||||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
let mut result = M::zeros(1, x.shape().0);
|
let mut result = M::zeros(1, x.shape().0);
|
||||||
|
|
||||||
let (n, _) = x.shape();
|
let (n, _) = x.shape();
|
||||||
@@ -271,7 +267,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
result.set(0, i, self.predict_for_row(x, i));
|
result.set(0, i, self.predict_for_row(x, i));
|
||||||
}
|
}
|
||||||
|
|
||||||
result.to_row_vector()
|
Ok(result.to_row_vector())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
||||||
@@ -498,7 +494,9 @@ mod tests {
|
|||||||
114.2, 115.7, 116.9,
|
114.2, 115.7, 116.9,
|
||||||
];
|
];
|
||||||
|
|
||||||
let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x);
|
let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default())
|
||||||
|
.and_then(|t| t.predict(&x))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
for i in 0..y_hat.len() {
|
for i in 0..y_hat.len() {
|
||||||
assert!((y_hat[i] - y[i]).abs() < 0.1);
|
assert!((y_hat[i] - y[i]).abs() < 0.1);
|
||||||
@@ -517,7 +515,8 @@ mod tests {
|
|||||||
min_samples_split: 6,
|
min_samples_split: 6,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.predict(&x);
|
.and_then(|t| t.predict(&x))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
for i in 0..y_hat.len() {
|
for i in 0..y_hat.len() {
|
||||||
assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
|
assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
|
||||||
@@ -536,7 +535,8 @@ mod tests {
|
|||||||
min_samples_split: 3,
|
min_samples_split: 3,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.predict(&x);
|
.and_then(|t| t.predict(&x))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
for i in 0..y_hat.len() {
|
for i in 0..y_hat.len() {
|
||||||
assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
|
assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
|
||||||
@@ -568,7 +568,7 @@ mod tests {
|
|||||||
114.2, 115.7, 116.9,
|
114.2, 115.7, 116.9,
|
||||||
];
|
];
|
||||||
|
|
||||||
let tree = DecisionTreeRegressor::fit(&x, &y, Default::default());
|
let tree = DecisionTreeRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
let deserialized_tree: DecisionTreeRegressor<f64> =
|
let deserialized_tree: DecisionTreeRegressor<f64> =
|
||||||
bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
||||||
|
|||||||
Reference in New Issue
Block a user