diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 418f583..77352a3 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -222,6 +222,7 @@ impl RandomForestClassifier { mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + use crate::metrics::*; #[test] fn fit_predict_iris() { @@ -264,7 +265,8 @@ mod tests { }, ); - assert_eq!(y, classifier.predict(&x)); + assert!(accuracy(&y, &classifier.predict(&x)) > 0.9); + } #[test] diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs index 122673a..7a2a740 100644 --- a/src/linalg/ndarray_bindings.rs +++ b/src/linalg/ndarray_bindings.rs @@ -735,7 +735,7 @@ mod tests { min_samples_leaf: 1, min_samples_split: 2, n_trees: 1000, - mtry: Option::None, + m: Option::None, }, ) .predict(&x);