feat: new distance function parameter in KNN, extends KNN documentation
This commit is contained in:
+30
-1
@@ -52,12 +52,41 @@ pub enum KNNAlgorithmName {
|
||||
CoverTree,
|
||||
}
|
||||
|
||||
/// Weight function that is used to determine estimated value.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub enum KNNWeightFunction {
|
||||
/// All k nearest points are weighted equally
|
||||
Uniform,
|
||||
/// k nearest points are weighted by the inverse of their distance. Closer neighbors will have a greater influence than neighbors which are further away.
|
||||
Distance,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
enum KNNAlgorithm<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||
CoverTree(CoverTree<Vec<T>, T, D>),
|
||||
}
|
||||
|
||||
impl KNNWeightFunction {
|
||||
fn calc_weights<T: FloatExt>(&self, distances: Vec<T>) -> std::vec::Vec<T> {
|
||||
match *self {
|
||||
KNNWeightFunction::Distance => {
|
||||
// if there are any points that has zero distance from one or more training points,
|
||||
// those training points are weighted as 1.0 and the other points as 0.0
|
||||
if distances.iter().any(|&e| e == T::zero()) {
|
||||
distances
|
||||
.iter()
|
||||
.map(|e| if *e == T::zero() { T::one() } else { T::zero() })
|
||||
.collect()
|
||||
} else {
|
||||
distances.iter().map(|e| T::one() / *e).collect()
|
||||
}
|
||||
}
|
||||
KNNWeightFunction::Uniform => vec![T::one(); distances.len()],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl KNNAlgorithmName {
|
||||
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(
|
||||
&self,
|
||||
@@ -74,7 +103,7 @@ impl KNNAlgorithmName {
|
||||
}
|
||||
|
||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
|
||||
fn find(&self, from: &Vec<T>, k: usize) -> Vec<usize> {
|
||||
fn find(&self, from: &Vec<T>, k: usize) -> Vec<(usize, T)> {
|
||||
match *self {
|
||||
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
|
||||
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
|
||||
|
||||
Reference in New Issue
Block a user