diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 46f09c9..4fb259f 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -83,6 +83,21 @@ pub trait BaseVector: Clone + Debug { self.len() == 0 } + /// Create a new vector from a &[T] + /// ``` + /// use smartcore::linalg::naive::dense_matrix::*; + /// let slice: &[f64] = &[0., 0.5, 2., 3., 4.]; + /// let a: Vec = BaseVector::from_slice(slice); + /// assert_eq!(a, vec![0., 0.5, 2., 3., 4.]); + /// ``` + fn from_slice(f: &[T]) -> Self { + let mut v = Self::zeros(f.len()); + for (i, elem) in f.iter().enumerate() { + v.set(i, *elem); + } + v + } + /// Return a vector with the elements of the one-dimensional array. fn to_vec(&self) -> Vec; diff --git a/src/math/vector.rs b/src/math/vector.rs index accfed6..14e1925 100644 --- a/src/math/vector.rs +++ b/src/math/vector.rs @@ -1,13 +1,14 @@ use crate::math::num::RealNumber; use std::collections::HashMap; +use crate::linalg::BaseVector; pub trait RealNumberVector { fn unique(&self) -> (Vec, Vec); } -impl RealNumberVector for Vec { +impl> RealNumberVector for V { fn unique(&self) -> (Vec, Vec) { - let mut unique = self.clone(); + let mut unique = self.to_vec(); unique.sort_by(|a, b| a.partial_cmp(b).unwrap()); unique.dedup(); @@ -17,8 +18,8 @@ impl RealNumberVector for Vec { } let mut unique_index = Vec::with_capacity(self.len()); - for e in self { - unique_index.push(index[&e.to_i64().unwrap()]); + for idx in 0..self.len() { + unique_index.push(index[&self.get(idx).to_i64().unwrap()]); } (unique, unique_index) @@ -27,7 +28,7 @@ impl RealNumberVector for Vec { #[cfg(test)] mod tests { - use super::*; + use super::RealNumberVector; #[test] fn unique() { diff --git a/src/naive_bayes/mod.rs b/src/naive_bayes/mod.rs index e9ab792..ffc3e2e 100644 --- a/src/naive_bayes/mod.rs +++ b/src/naive_bayes/mod.rs @@ -58,10 +58,7 @@ impl, D: NBDistribution> BaseNaiveBayes>(); - let mut y_hat = M::RowVector::zeros(rows); - for (i, prediction) in predictions.iter().enumerate().take(rows) { - y_hat.set(i, *prediction); - } + let y_hat = M::RowVector::from_slice(&predictions); Ok(y_hat) } }