Another format.
This commit is contained in:
@@ -203,7 +203,11 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
|||||||
trees.push(tree);
|
trees.push(tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(RandomForestRegressor { parameters, trees, samples: maybe_all_samples })
|
Ok(RandomForestRegressor {
|
||||||
|
parameters,
|
||||||
|
trees,
|
||||||
|
samples: maybe_all_samples,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict class for `x`
|
/// Predict class for `x`
|
||||||
@@ -232,7 +236,6 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
|||||||
result / T::from(n_trees).unwrap()
|
result / T::from(n_trees).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||||
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
let (n, _) = x.shape();
|
let (n, _) = x.shape();
|
||||||
@@ -370,7 +373,8 @@ mod tests {
|
|||||||
m: Option::None,
|
m: Option::None,
|
||||||
keep_samples: true,
|
keep_samples: true,
|
||||||
},
|
},
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let y_hat = regressor.predict(&x).unwrap();
|
let y_hat = regressor.predict(&x).unwrap();
|
||||||
let y_hat_oob = regressor.predict_oob(&x).unwrap();
|
let y_hat_oob = regressor.predict_oob(&x).unwrap();
|
||||||
|
|||||||
Reference in New Issue
Block a user