Another format.

This commit is contained in:
Malte Londschien
2021-10-20 17:04:24 +02:00
parent d239314967
commit 85b9fde9a7
+7 -3
View File
@@ -203,7 +203,11 @@ impl<T: RealNumber> RandomForestRegressor<T> {
trees.push(tree); trees.push(tree);
} }
Ok(RandomForestRegressor { parameters, trees, samples: maybe_all_samples }) Ok(RandomForestRegressor {
parameters,
trees,
samples: maybe_all_samples,
})
} }
/// Predict class for `x` /// Predict class for `x`
@@ -232,7 +236,6 @@ impl<T: RealNumber> RandomForestRegressor<T> {
result / T::from(n_trees).unwrap() result / T::from(n_trees).unwrap()
} }
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training. /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> { pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, _) = x.shape(); let (n, _) = x.shape();
@@ -370,7 +373,8 @@ mod tests {
m: Option::None, m: Option::None,
keep_samples: true, keep_samples: true,
}, },
).unwrap(); )
.unwrap();
let y_hat = regressor.predict(&x).unwrap(); let y_hat = regressor.predict(&x).unwrap();
let y_hat_oob = regressor.predict_oob(&x).unwrap(); let y_hat_oob = regressor.predict_oob(&x).unwrap();