feat: new distance function parameter in KNN, extends KNN documentation

This commit is contained in:
Volodymyr Orlov
2020-08-28 15:30:52 -07:00
parent dcf636a5f1
commit 367ea62608
6 changed files with 172 additions and 33 deletions
+30 -1
View File
@@ -52,12 +52,41 @@ pub enum KNNAlgorithmName {
CoverTree,
}
/// Weighting scheme applied to the k nearest neighbors when an estimated value is determined.
///
/// `calc_weights` turns the selected variant into one concrete weight per neighbor.
#[derive(Serialize, Deserialize, Debug)]
pub enum KNNWeightFunction {
/// All k nearest points are weighted equally (every weight is `1`).
Uniform,
/// The k nearest points are weighted by the inverse of their distance (`1 / distance`), so closer
/// neighbors have a greater influence than neighbors which are further away. If any neighbor is
/// at distance zero, the zero-distance neighbors get weight `1` and all others get weight `0`.
Distance,
}
/// Nearest-neighbor search structure over `Vec<T>` points, parameterized by the distance type `D`.
///
/// Each variant wraps one concrete search implementation; `find` forwards the query to
/// whichever implementation is stored.
#[derive(Serialize, Deserialize, Debug)]
enum KNNAlgorithm<T: FloatExt, D: Distance<Vec<T>, T>> {
/// Search backed by a `LinearKNNSearch`.
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
/// Search backed by a `CoverTree`.
CoverTree(CoverTree<Vec<T>, T, D>),
}
impl KNNWeightFunction {
    /// Converts the distances to the nearest neighbors into one weight per neighbor.
    ///
    /// The result has the same length and order as `distances`.
    ///
    /// * `Uniform` - every neighbor receives weight `1`.
    /// * `Distance` - every neighbor receives weight `1 / distance`. When at least one distance
    ///   is exactly zero, the zero-distance neighbors receive weight `1` and every other
    ///   neighbor receives weight `0`, so no division by zero takes place.
    fn calc_weights<T: FloatExt>(&self, distances: Vec<T>) -> std::vec::Vec<T> {
        match self {
            KNNWeightFunction::Uniform => vec![T::one(); distances.len()],
            KNNWeightFunction::Distance => {
                let zero = T::zero();
                let one = T::one();
                // An exact hit on a training point overrides inverse-distance weighting.
                let exact_match = distances.contains(&zero);

                let mut weights = Vec::with_capacity(distances.len());
                for &dist in &distances {
                    let weight = match (exact_match, dist == zero) {
                        (true, true) => one,
                        (true, false) => zero,
                        (false, _) => one / dist,
                    };
                    weights.push(weight);
                }
                weights
            }
        }
    }
}
impl KNNAlgorithmName {
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(
&self,
@@ -74,7 +103,7 @@ impl KNNAlgorithmName {
}
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
fn find(&self, from: &Vec<T>, k: usize) -> Vec<usize> {
fn find(&self, from: &Vec<T>, k: usize) -> Vec<(usize, T)> {
match *self {
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),