Another format.
This commit is contained in:
@@ -203,7 +203,11 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
trees.push(tree);
|
||||
}
|
||||
|
||||
Ok(RandomForestRegressor { parameters, trees, samples: maybe_all_samples })
|
||||
Ok(RandomForestRegressor {
|
||||
parameters,
|
||||
trees,
|
||||
samples: maybe_all_samples,
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict class for `x`
|
||||
@@ -232,7 +236,6 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
result / T::from(n_trees).unwrap()
|
||||
}
|
||||
|
||||
|
||||
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
@@ -370,7 +373,8 @@ mod tests {
|
||||
m: Option::None,
|
||||
keep_samples: true,
|
||||
},
|
||||
).unwrap();
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let y_hat = regressor.predict(&x).unwrap();
|
||||
let y_hat_oob = regressor.predict_oob(&x).unwrap();
|
||||
|
||||
Reference in New Issue
Block a user