feat: extends interface of Matrix to support for broad range of types

This commit is contained in:
Volodymyr Orlov
2020-03-26 15:28:26 -07:00
parent 84ffd331cd
commit 02b85415d9
27 changed files with 1021 additions and 868 deletions
+13 -13
View File
@@ -1,21 +1,21 @@
use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::{Matrix, row_iter};
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
use crate::algorithm::neighbour::cover_tree::CoverTree;
type F = dyn Fn(&Vec<f64>, &Vec<f64>) -> f64;
pub struct KNNClassifier<'a> {
classes: Vec<f64>,
pub struct KNNClassifier<'a, T: FloatExt> {
classes: Vec<T>,
y: Vec<usize>,
knn_algorithm: Box<dyn KNNAlgorithm<Vec<f64>> + 'a>,
knn_algorithm: Box<dyn KNNAlgorithm<Vec<T>> + 'a>,
k: usize,
}
impl<'a> KNNClassifier<'a> {
impl<'a, T: FloatExt + Debug> KNNClassifier<'a, T> {
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, k: usize, distance: &'a F, algorithm: KNNAlgorithmName) -> KNNClassifier<'a> {
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, k: usize, distance: &'a dyn Fn(&Vec<T>, &Vec<T>) -> T, algorithm: KNNAlgorithmName) -> KNNClassifier<'a, T> {
let y_m = M::from_row_vector(y.clone());
@@ -36,16 +36,16 @@ impl<'a> KNNClassifier<'a> {
assert!(k > 1, format!("k should be > 1, k=[{}]", k));
let knn_algorithm: Box<dyn KNNAlgorithm<Vec<f64>> + 'a> = match algorithm {
KNNAlgorithmName::CoverTree => Box::new(CoverTree::<Vec<f64>>::new(data, distance)),
KNNAlgorithmName::LinearSearch => Box::new(LinearKNNSearch::<Vec<f64>>::new(data, distance))
let knn_algorithm: Box<dyn KNNAlgorithm<Vec<T>> + 'a> = match algorithm {
KNNAlgorithmName::CoverTree => Box::new(CoverTree::<Vec<T>, T>::new(data, distance)),
KNNAlgorithmName::LinearSearch => Box::new(LinearKNNSearch::<Vec<T>, T>::new(data, distance))
};
KNNClassifier{classes:classes, y: yi, k: k, knn_algorithm: knn_algorithm}
}
pub fn predict<M: Matrix>(&self, x: &M) -> M::RowVector {
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let mut result = M::zeros(1, x.shape().0);
row_iter(x).enumerate().for_each(|(i, x)| result.set(0, i, self.classes[self.predict_for_row(x)]));
@@ -53,7 +53,7 @@ impl<'a> KNNClassifier<'a> {
result.to_row_vector()
}
fn predict_for_row(&self, x: Vec<f64>) -> usize {
fn predict_for_row(&self, x: Vec<T>) -> usize {
let idxs = self.knn_algorithm.find(&x, self.k);
let mut c = vec![0; self.classes.len()];