Another format.

This commit is contained in:
Malte Londschien
2021-10-20 17:04:24 +02:00
parent d239314967
commit 85b9fde9a7
+7 -3
View File
@@ -203,7 +203,11 @@ impl<T: RealNumber> RandomForestRegressor<T> {
trees.push(tree);
}
Ok(RandomForestRegressor { parameters, trees, samples: maybe_all_samples })
Ok(RandomForestRegressor {
parameters,
trees,
samples: maybe_all_samples,
})
}
/// Predict class for `x`
@@ -232,7 +236,6 @@ impl<T: RealNumber> RandomForestRegressor<T> {
result / T::from(n_trees).unwrap()
}
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, _) = x.shape();
@@ -370,7 +373,8 @@ mod tests {
m: Option::None,
keep_samples: true,
},
).unwrap();
)
.unwrap();
let y_hat = regressor.predict(&x).unwrap();
let y_hat_oob = regressor.predict_oob(&x).unwrap();