diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 46f09c9..c560b78 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -83,6 +83,21 @@ pub trait BaseVector: Clone + Debug { self.len() == 0 } + /// Create a new vector from a &[T] + /// ``` + /// use smartcore::linalg::naive::dense_matrix::*; + /// let a: [f64; 5] = [0., 0.5, 2., 3., 4.]; + /// let v: Vec = BaseVector::from_array(&a); + /// assert_eq!(v, vec![0., 0.5, 2., 3., 4.]); + /// ``` + fn from_array(f: &[T]) -> Self { + let mut v = Self::zeros(f.len()); + for (i, elem) in f.iter().enumerate() { + v.set(i, *elem); + } + v + } + /// Return a vector with the elements of the one-dimensional array. fn to_vec(&self) -> Vec; diff --git a/src/math/vector.rs b/src/math/vector.rs index accfed6..62cf63b 100644 --- a/src/math/vector.rs +++ b/src/math/vector.rs @@ -1,13 +1,14 @@ use crate::math::num::RealNumber; use std::collections::HashMap; +use crate::linalg::BaseVector; pub trait RealNumberVector { - fn unique(&self) -> (Vec, Vec); + fn unique_with_indices(&self) -> (Vec, Vec); } -impl RealNumberVector for Vec { - fn unique(&self) -> (Vec, Vec) { - let mut unique = self.clone(); +impl> RealNumberVector for V { + fn unique_with_indices(&self) -> (Vec, Vec) { + let mut unique = self.to_vec(); unique.sort_by(|a, b| a.partial_cmp(b).unwrap()); unique.dedup(); @@ -17,8 +18,8 @@ impl RealNumberVector for Vec { } let mut unique_index = Vec::with_capacity(self.len()); - for e in self { - unique_index.push(index[&e.to_i64().unwrap()]); + for idx in 0..self.len() { + unique_index.push(index[&self.get(idx).to_i64().unwrap()]); } (unique, unique_index) @@ -30,11 +31,11 @@ mod tests { use super::*; #[test] - fn unique() { + fn unique_with_indices() { let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0]; assert_eq!( (vec!(0.0, 1.0, 2.0, 4.0), vec!(0, 0, 1, 1, 2, 0, 3)), - v1.unique() + v1.unique_with_indices() ); } } diff --git a/src/metrics/cluster_helpers.rs b/src/metrics/cluster_helpers.rs index dd5bbb3..8d1e17e 100644 --- a/src/metrics/cluster_helpers.rs +++ b/src/metrics/cluster_helpers.rs @@ -7,8 +7,8 @@ pub fn contingency_matrix( labels_true: &Vec, labels_pred: &Vec, ) -> Vec> { - let (classes, class_idx) = labels_true.unique(); - let (clusters, cluster_idx) = labels_pred.unique(); + let (classes, class_idx) = labels_true.unique_with_indices(); + let (clusters, cluster_idx) = labels_pred.unique_with_indices(); let mut contingency_matrix = Vec::with_capacity(classes.len()); diff --git a/src/naive_bayes/mod.rs b/src/naive_bayes/mod.rs index e9ab792..f93d3bf 100644 --- a/src/naive_bayes/mod.rs +++ b/src/naive_bayes/mod.rs @@ -58,10 +58,7 @@ impl, D: NBDistribution> BaseNaiveBayes>(); - let mut y_hat = M::RowVector::zeros(rows); - for (i, prediction) in predictions.iter().enumerate().take(rows) { - y_hat.set(i, *prediction); - } + let y_hat = M::RowVector::from_array(&predictions); Ok(y_hat) } }