feat: + cross_validate, trait Predictor, refactoring

This commit is contained in:
Volodymyr Orlov
2020-12-22 15:41:53 -08:00
parent 40dfca702e
commit a2be9e117f
34 changed files with 977 additions and 369 deletions
+72
View File
@@ -274,6 +274,19 @@ pub trait BaseVector<T: RealNumber>: Clone + Debug {
/// Copies content of `other` vector.
fn copy_from(&mut self, other: &Self);
/// Take elements from an array.
fn take(&self, index: &[usize]) -> Self {
let n = index.len();
let mut result = Self::zeros(n);
for i in 0..n {
result.set(i, self.get(index[i]));
}
result
}
}
/// Generic matrix type.
@@ -611,6 +624,32 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
/// Calculates the covariance matrix
fn cov(&self) -> Self;
/// Take elements from an array along an axis.
fn take(&self, index: &[usize], axis: u8) -> Self {
let (n, p) = self.shape();
let k = match axis {
0 => p,
_ => n,
};
let mut result = match axis {
0 => Self::zeros(index.len(), p),
_ => Self::zeros(n, index.len()),
};
for i in 0..index.len() {
for j in 0..k {
match axis {
0 => result.set(i, j, self.get(index[i], j)),
_ => result.set(j, i, self.get(j, index[i])),
};
}
}
result
}
}
/// Generic matrix with additional mixins like various factorization methods.
@@ -662,6 +701,8 @@ impl<'a, T: RealNumber, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
#[cfg(test)]
mod tests {
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::BaseMatrix;
use crate::linalg::BaseVector;
#[test]
@@ -684,4 +725,35 @@ mod tests {
assert!((m.var() - 1.25f64).abs() < std::f64::EPSILON);
}
#[test]
fn vec_take() {
let m = vec![1., 2., 3., 4., 5.];
assert_eq!(m.take(&vec!(0, 0, 4, 4)), vec![1., 1., 5., 5.]);
}
#[test]
fn take() {
let m = DenseMatrix::from_2d_array(&[
&[1.0, 2.0],
&[3.0, 4.0],
&[5.0, 6.0],
&[7.0, 8.0],
&[9.0, 10.0],
]);
let expected_0 = DenseMatrix::from_2d_array(&[&[3.0, 4.0], &[3.0, 4.0], &[7.0, 8.0]]);
let expected_1 = DenseMatrix::from_2d_array(&[
&[2.0, 1.0],
&[4.0, 3.0],
&[6.0, 5.0],
&[8.0, 7.0],
&[10.0, 9.0],
]);
assert_eq!(m.take(&vec!(1, 1, 3), 0), expected_0);
assert_eq!(m.take(&vec!(1, 0), 1), expected_1);
}
}
+2 -2
View File
@@ -36,7 +36,7 @@
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.
//! ]);
//!
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
//! let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
//! let y_hat = lr.predict(&x).unwrap();
//! ```
use std::iter::Sum;
@@ -917,7 +917,7 @@ mod tests {
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
]);
let lr = LogisticRegression::fit(&x, &y).unwrap();
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
let y_hat = lr.predict(&x).unwrap();