feat: adds KNN Regressor

This commit is contained in:
Volodymyr Orlov
2020-08-27 14:17:18 -07:00
parent f73b349f57
commit e5b412451f
4 changed files with 189 additions and 41 deletions
+47 -1
View File
@@ -1 +1,47 @@
pub mod knn;
//! # Nearest Neighbors
use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::cover_tree::CoverTree;
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
use crate::math::distance::Distance;
use crate::math::num::FloatExt;
///
pub mod knn_classifier;
pub mod knn_regressor;
#[derive(Serialize, Deserialize, Debug)]
pub enum KNNAlgorithmName {
LinearSearch,
CoverTree,
}
#[derive(Serialize, Deserialize, Debug)]
enum KNNAlgorithm<T: FloatExt, D: Distance<Vec<T>, T>> {
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
CoverTree(CoverTree<Vec<T>, T, D>),
}
impl KNNAlgorithmName {
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(
&self,
data: Vec<Vec<T>>,
distance: D,
) -> KNNAlgorithm<T, D> {
match *self {
KNNAlgorithmName::LinearSearch => {
KNNAlgorithm::LinearSearch(LinearKNNSearch::new(data, distance))
}
KNNAlgorithmName::CoverTree => KNNAlgorithm::CoverTree(CoverTree::new(data, distance)),
}
}
}
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
fn find(&self, from: &Vec<T>, k: usize) -> Vec<usize> {
match *self {
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
}
}
}