diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index f7723ac..ecf4a48 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -104,6 +104,7 @@ impl RandomForestRegressor { mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + use ndarray::{arr1, arr2}; #[test] fn fit_longley() { @@ -142,4 +143,45 @@ mod tests { } + #[test] + fn my_fit_longley1() { + + let x = arr2(&[ + [ 234.289, 235.6, 159., 107.608, 1947., 60.323], + [ 259.426, 232.5, 145.6, 108.632, 1948., 61.122], + [ 258.054, 368.2, 161.6, 109.773, 1949., 60.171], + [ 284.599, 335.1, 165., 110.929, 1950., 61.187], + [ 328.975, 209.9, 309.9, 112.075, 1951., 63.221], + [ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639], + [ 365.385, 187., 354.7, 115.094, 1953., 64.989], + [ 363.112, 357.8, 335., 116.219, 1954., 63.761], + [ 397.469, 290.4, 304.8, 117.388, 1955., 66.019], + [ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857], + [ 442.769, 293.6, 279.8, 120.445, 1957., 68.169], + [ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513], + [ 482.704, 381.3, 255.2, 123.366, 1959., 68.655], + [ 502.601, 393.1, 251.4, 125.368, 1960., 69.564], + [ 518.173, 480.6, 257.2, 127.852, 1961., 69.331], + [ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); + let y = arr1(&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]); + + println!("{:?}", y.shape()); + + let expected_y: Vec<f64> = vec![85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.]; + + let y_hat = RandomForestRegressor::fit(&x, &y, + RandomForestRegressorParameters{max_depth: None, + min_samples_leaf: 1, + min_samples_split: 2, + n_trees: 1000, + mtry: Option::None}).predict(&x); + + println!("{:?}", y_hat); + + for i in 0..y_hat.len() { + assert!((y_hat[i] - expected_y[i]).abs() < 1.0); + } + + } + } \ No newline at end of file diff --git a/src/linalg/naive/dense_matrix.rs 
b/src/linalg/naive/dense_matrix.rs index d54e386..400e5ea 100644 --- a/src/linalg/naive/dense_matrix.rs +++ b/src/linalg/naive/dense_matrix.rs @@ -137,6 +137,9 @@ impl BaseMatrix for DenseMatrix { } fn get(&self, row: usize, col: usize) -> T { + if row >= self.nrows || col >= self.ncols { + panic!("Invalid index ({},{}) for {}x{} matrix", row, col, self.nrows, self.ncols); + } self.values[col*self.nrows + row] } diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs index 20fa0b1..e9808cf 100644 --- a/src/tree/decision_tree_regressor.rs +++ b/src/tree/decision_tree_regressor.rs @@ -94,6 +94,7 @@ impl DecisionTreeRegressor { pub fn fit_weak_learner>(x: &M, y: &M::RowVector, samples: Vec, mtry: usize, parameters: DecisionTreeRegressorParameters) -> DecisionTreeRegressor { let y_m = M::from_row_vector(y.clone()); + // println!("{:?}", y_m); let (_, y_ncols) = y_m.shape(); let (_, num_attributes) = x.shape(); let classes = y_m.unique(); @@ -108,7 +109,7 @@ impl DecisionTreeRegressor { let mut sum = T::zero(); for i in 0..y_ncols { n += samples[i]; - sum = sum + T::from(samples[i]).unwrap() * y_m.get(i, 0); + sum = sum + T::from(samples[i]).unwrap() * y_m.get(0, i); } let root = Node::new(0, sum / T::from(n).unwrap()); @@ -221,7 +222,7 @@ impl DecisionTreeRegressor { if prevx.is_nan() || visitor.x.get(*i, j) == prevx { prevx = visitor.x.get(*i, j); true_count += visitor.samples[*i]; - true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(*i, 0); + true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i); continue; } @@ -230,7 +231,7 @@ impl DecisionTreeRegressor { if true_count < self.parameters.min_samples_leaf || false_count < self.parameters.min_samples_leaf { prevx = visitor.x.get(*i, j); true_count += visitor.samples[*i]; - true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(*i, 0); + true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, 
*i); continue; } @@ -248,7 +249,7 @@ impl DecisionTreeRegressor { } prevx = visitor.x.get(*i, j); - true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(*i, 0); + true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i); true_count += visitor.samples[*i]; } } @@ -354,6 +355,35 @@ mod tests { assert!((y_hat[i] - expected_y[i]).abs() < 0.1); } - } + } + + #[test] + fn fit_longley1() { + + let x = DenseMatrix::from_array(&[ + &[ 234.289, 235.6, 159., 107.608, 1947., 60.323], + &[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122], + &[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171], + &[ 284.599, 335.1, 165., 110.929, 1950., 61.187], + &[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221], + &[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639], + &[ 365.385, 187., 354.7, 115.094, 1953., 64.989], + &[ 363.112, 357.8, 335., 116.219, 1954., 63.761], + &[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019], + &[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857], + &[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169], + &[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513], + &[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655], + &[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564], + &[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331], + &[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); + let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]; + + let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x); + + for i in 0..y_hat.len() { + assert!((y_hat[i] - y[i]).abs() < 0.1); + } + } } \ No newline at end of file