feat: adds train/test split function; fixes bug in random forest

This commit is contained in:
Volodymyr Orlov
2020-09-13 16:23:30 -07:00
parent 1920f9cd0b
commit d28f13d849
9 changed files with 187 additions and 10 deletions
+25 -1
View File
@@ -40,7 +40,7 @@
use std::iter::Sum;
use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign};
use nalgebra::{DMatrix, Dynamic, Matrix, MatrixMN, Scalar, VecStorage, U1};
use nalgebra::{DMatrix, Dynamic, Matrix, MatrixMN, RowDVector, Scalar, VecStorage, U1};
use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::lu::LUDecomposableMatrix;
@@ -65,6 +65,20 @@ impl<T: RealNumber + 'static> BaseVector<T> for MatrixMN<T, U1, Dynamic> {
fn to_vec(&self) -> Vec<T> {
self.row(0).iter().map(|v| *v).collect()
}
fn zeros(len: usize) -> Self {
RowDVector::zeros(len)
}
fn ones(len: usize) -> Self {
BaseVector::fill(len, T::one())
}
fn fill(len: usize, value: T) -> Self {
let mut m = RowDVector::zeros(len);
m.fill(value);
m
}
}
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
@@ -446,6 +460,16 @@ mod tests {
assert_eq!(vec![1., 2., 3.], v.to_vec());
}
#[test]
fn vec_init() {
let zeros: RowDVector<f32> = BaseVector::zeros(3);
let ones: RowDVector<f32> = BaseVector::ones(3);
let twos: RowDVector<f32> = BaseVector::fill(3, 2.);
assert_eq!(zeros, RowDVector::from_vec(vec![0., 0., 0.]));
assert_eq!(ones, RowDVector::from_vec(vec![1., 1., 1.]));
assert_eq!(twos, RowDVector::from_vec(vec![2., 2., 2.]));
}
#[test]
fn get_set_dynamic() {
let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);