fix: fixes a bug in decision_tree_regressor

Volodymyr Orlov
2020-03-26 16:21:20 -07:00
parent 02b85415d9
commit 4d967858a5
3 changed files with 80 additions and 5 deletions
+42
@@ -104,6 +104,7 @@ impl<T: FloatExt + Debug> RandomForestRegressor<T> {
 mod tests {
     use super::*;
     use crate::linalg::naive::dense_matrix::DenseMatrix;
+    use ndarray::{arr1, arr2};
 
     #[test]
     fn fit_longley() {
@@ -142,4 +143,45 @@ mod tests {
     }
 
+    #[test]
+    fn my_fit_longley1() {
+        let x = arr2(&[
+            [234.289, 235.6, 159., 107.608, 1947., 60.323],
+            [259.426, 232.5, 145.6, 108.632, 1948., 61.122],
+            [258.054, 368.2, 161.6, 109.773, 1949., 60.171],
+            [284.599, 335.1, 165., 110.929, 1950., 61.187],
+            [328.975, 209.9, 309.9, 112.075, 1951., 63.221],
+            [346.999, 193.2, 359.4, 113.27, 1952., 63.639],
+            [365.385, 187., 354.7, 115.094, 1953., 64.989],
+            [363.112, 357.8, 335., 116.219, 1954., 63.761],
+            [397.469, 290.4, 304.8, 117.388, 1955., 66.019],
+            [419.18, 282.2, 285.7, 118.734, 1956., 67.857],
+            [442.769, 293.6, 279.8, 120.445, 1957., 68.169],
+            [444.546, 468.1, 263.7, 121.95, 1958., 66.513],
+            [482.704, 381.3, 255.2, 123.366, 1959., 68.655],
+            [502.601, 393.1, 251.4, 125.368, 1960., 69.564],
+            [518.173, 480.6, 257.2, 127.852, 1961., 69.331],
+            [554.894, 400.7, 282.7, 130.081, 1962., 70.551],
+        ]);
+        let y = arr1(&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]);
+        println!("{:?}", y.shape());
+        let expected_y: Vec<f64> = vec![85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.];
+        let y_hat = RandomForestRegressor::fit(
+            &x,
+            &y,
+            RandomForestRegressorParameters {
+                max_depth: None,
+                min_samples_leaf: 1,
+                min_samples_split: 2,
+                n_trees: 1000,
+                mtry: Option::None,
+            },
+        )
+        .predict(&x);
+        println!("{:?}", y_hat);
+        for i in 0..y_hat.len() {
+            assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
+        }
+    }
 }
+3
@@ -137,6 +137,9 @@ impl<T: FloatExt + Debug> BaseMatrix<T> for DenseMatrix<T> {
     }
 
     fn get(&self, row: usize, col: usize) -> T {
+        if row >= self.nrows || col >= self.ncols {
+            panic!("Invalid index ({},{}) for {}x{} matrix", row, col, self.nrows, self.ncols);
+        }
         self.values[col*self.nrows + row]
     }
 
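The dense_matrix change above is the guard that makes the indexing bug fixed below visible: DenseMatrix stores its values column-major (values[col*self.nrows + row]), so a (row, col) pair outside the logical shape can still land inside the flat buffer and silently return the wrong element; after this commit it panics instead. A minimal standalone sketch of that layout and check, using a hypothetical ColMajor type rather than the real smartcore DenseMatrix:

struct ColMajor {
    nrows: usize,
    ncols: usize,
    values: Vec<f64>,
}

impl ColMajor {
    fn get(&self, row: usize, col: usize) -> f64 {
        // The check this commit adds: reject indices outside the logical
        // shape even when they would land inside the flat buffer.
        if row >= self.nrows || col >= self.ncols {
            panic!("Invalid index ({},{}) for {}x{} matrix", row, col, self.nrows, self.ncols);
        }
        // Column-major layout: element (row, col) sits at col * nrows + row.
        self.values[col * self.nrows + row]
    }
}

fn main() {
    // The 2x2 matrix [[1, 3], [2, 4]], stored one column after another.
    let m = ColMajor { nrows: 2, ncols: 2, values: vec![1.0, 2.0, 3.0, 4.0] };
    assert_eq!(m.get(0, 1), 3.0); // in bounds: flat index 1 * 2 + 0 = 2
    // m.get(0, 2) would previously have read past the buffer, and m.get(2, 0)
    // would have silently read values[2] (the wrong element); both now panic.
}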
+34 -4
View File
@@ -94,6 +94,7 @@ impl<T: FloatExt + Debug> DecisionTreeRegressor<T> {
     pub fn fit_weak_learner<M: Matrix<T>>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeRegressorParameters) -> DecisionTreeRegressor<T> {
         let y_m = M::from_row_vector(y.clone());
+        // println!("{:?}", y_m);
         let (_, y_ncols) = y_m.shape();
         let (_, num_attributes) = x.shape();
         let classes = y_m.unique();
 
@@ -108,7 +109,7 @@ impl<T: FloatExt + Debug> DecisionTreeRegressor<T> {
         let mut sum = T::zero();
         for i in 0..y_ncols {
             n += samples[i];
-            sum = sum + T::from(samples[i]).unwrap() * y_m.get(i, 0);
+            sum = sum + T::from(samples[i]).unwrap() * y_m.get(0, i);
         }
 
         let root = Node::new(0, sum / T::from(n).unwrap());
@@ -221,7 +222,7 @@ impl<T: FloatExt + Debug> DecisionTreeRegressor<T> {
             if prevx.is_nan() || visitor.x.get(*i, j) == prevx {
                 prevx = visitor.x.get(*i, j);
                 true_count += visitor.samples[*i];
-                true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(*i, 0);
+                true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
                 continue;
             }
 
@@ -230,7 +231,7 @@ impl<T: FloatExt + Debug> DecisionTreeRegressor<T> {
             if true_count < self.parameters.min_samples_leaf || false_count < self.parameters.min_samples_leaf {
                 prevx = visitor.x.get(*i, j);
                 true_count += visitor.samples[*i];
-                true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(*i, 0);
+                true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
                 continue;
             }
 
@@ -248,7 +249,7 @@ impl<T: FloatExt + Debug> DecisionTreeRegressor<T> {
             }
 
             prevx = visitor.x.get(*i, j);
-            true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(*i, 0);
+            true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
             true_count += visitor.samples[*i];
         }
     }
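These three get(*i, 0) -> get(0, *i) edits, together with the y_m.get(i, 0) -> y_m.get(0, i) change above, are the actual regressor fix: M::from_row_vector(y) produces a 1 x n matrix, so the i-th target value lives at (0, i). Because the storage is column-major and nrows == 1, both argument orders happen to collapse to the same flat offset (col * nrows + row equals i either way), so the old calls returned the right numbers by accident; once get() validates its arguments, get(i, 0) panics for any i > 0, which is presumably why the bounds check and these index fixes ship in one commit. Continuing the hypothetical ColMajor sketch above:

fn main() {
    // A 1x4 row-vector matrix, as from_row_vector would produce for y.
    let y = ColMajor { nrows: 1, ncols: 4, values: vec![83.0, 88.5, 88.2, 89.5] };
    assert_eq!(y.get(0, 2), 88.2); // fixed call shape: element i is at (0, i)
    // y.get(2, 0) hits the same flat offset (0 * 1 + 2 == 2 * 1 + 0), but
    // row 2 >= nrows 1, so the new bounds check turns it into a panic.
}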
@@ -356,4 +357,33 @@ mod tests {
     }
 
+    #[test]
+    fn fit_longley1() {
+        let x = DenseMatrix::from_array(&[
+            &[234.289, 235.6, 159., 107.608, 1947., 60.323],
+            &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
+            &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
+            &[284.599, 335.1, 165., 110.929, 1950., 61.187],
+            &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
+            &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
+            &[365.385, 187., 354.7, 115.094, 1953., 64.989],
+            &[363.112, 357.8, 335., 116.219, 1954., 63.761],
+            &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
+            &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
+            &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
+            &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
+            &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
+            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
+            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
+            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
+        ]);
+        let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
+        let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x);
+        for i in 0..y_hat.len() {
+            assert!((y_hat[i] - y[i]).abs() < 0.1);
+        }
+    }
 }
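Both new tests can be run with cargo test fit_longley1 (cargo's test filter is a substring match, so this also picks up my_fit_longley1). Note that the random-forest test feeds ndarray's arr2/arr1 values directly to fit and predict, so it assumes the crate accepts ndarray types for its Matrix and RowVector traits; its looser 1.0 tolerance, versus 0.1 for the single decision tree, presumably allows for the averaging across 1000 trees.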