feat: adds serialization/deserialization methods
This commit is contained in:
+2
-1
@@ -4,11 +4,12 @@ extern crate smartcore;
|
|||||||
|
|
||||||
use criterion::Criterion;
|
use criterion::Criterion;
|
||||||
use criterion::black_box;
|
use criterion::black_box;
|
||||||
|
use smartcore::math::distance::euclidian::*;
|
||||||
|
|
||||||
fn criterion_benchmark(c: &mut Criterion) {
|
fn criterion_benchmark(c: &mut Criterion) {
|
||||||
let a = vec![1., 2., 3.];
|
let a = vec![1., 2., 3.];
|
||||||
|
|
||||||
c.bench_function("Euclidean Distance", move |b| b.iter(|| smartcore::math::distance::euclidian::distance(black_box(&a), black_box(&a))));
|
c.bench_function("Euclidean Distance", move |b| b.iter(|| Euclidian::distance(black_box(&a), black_box(&a))));
|
||||||
}
|
}
|
||||||
|
|
||||||
criterion_group!(benches, criterion_benchmark);
|
criterion_group!(benches, criterion_benchmark);
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ use std::fmt::Debug;
|
|||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::distance::euclidian;
|
use crate::math::distance::euclidian::*;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct BBDTree<T: FloatExt> {
|
pub struct BBDTree<T: FloatExt> {
|
||||||
@@ -79,10 +79,10 @@ impl<T: FloatExt> BBDTree<T> {
|
|||||||
let d = centroids[0].len();
|
let d = centroids[0].len();
|
||||||
|
|
||||||
// Determine which mean the node mean is closest to
|
// Determine which mean the node mean is closest to
|
||||||
let mut min_dist = euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[0]]);
|
let mut min_dist = Euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[0]]);
|
||||||
let mut closest = candidates[0];
|
let mut closest = candidates[0];
|
||||||
for i in 1..k {
|
for i in 1..k {
|
||||||
let dist = euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[i]]);
|
let dist = Euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[i]]);
|
||||||
if dist < min_dist {
|
if dist < min_dist {
|
||||||
min_dist = dist;
|
min_dist = dist;
|
||||||
closest = candidates[i];
|
closest = candidates[i];
|
||||||
|
|||||||
@@ -3,25 +3,26 @@ use std::iter::FromIterator;
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use core::hash::{Hash, Hasher};
|
use core::hash::{Hash, Hasher};
|
||||||
|
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use crate::algorithm::neighbour::KNNAlgorithm;
|
use crate::math::distance::Distance;
|
||||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||||
|
|
||||||
pub struct CoverTree<'a, T, F: FloatExt>
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
where T: Debug
|
pub struct CoverTree<T, F: FloatExt, D: Distance<T, F>>
|
||||||
{
|
{
|
||||||
base: F,
|
base: F,
|
||||||
max_level: i8,
|
max_level: i8,
|
||||||
min_level: i8,
|
min_level: i8,
|
||||||
distance: &'a dyn Fn(&T, &T) -> F,
|
distance: D,
|
||||||
nodes: Vec<Node<T>>
|
nodes: Vec<Node<T>>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T, F: FloatExt> CoverTree<'a, T, F>
|
impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
||||||
where T: Debug
|
|
||||||
{
|
{
|
||||||
|
|
||||||
pub fn new(mut data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> F) -> CoverTree<T, F> {
|
pub fn new(mut data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
|
||||||
let mut tree = CoverTree {
|
let mut tree = CoverTree {
|
||||||
base: F::two(),
|
base: F::two(),
|
||||||
max_level: 100,
|
max_level: 100,
|
||||||
@@ -43,7 +44,7 @@ where T: Debug
|
|||||||
} else {
|
} else {
|
||||||
let mut parent: Option<NodeId> = Option::None;
|
let mut parent: Option<NodeId> = Option::None;
|
||||||
let mut p_i = 0;
|
let mut p_i = 0;
|
||||||
let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data)));
|
let mut qi_p_ds = vec!((self.root(), D::distance(&p, &self.root().data)));
|
||||||
let mut i = self.max_level;
|
let mut i = self.max_level;
|
||||||
loop {
|
loop {
|
||||||
let i_d = self.base.powf(F::from(i).unwrap());
|
let i_d = self.base.powf(F::from(i).unwrap());
|
||||||
@@ -82,6 +83,18 @@ where T: Debug
|
|||||||
node_id
|
node_id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn find(&self, p: &T, k: usize) -> Vec<usize>{
|
||||||
|
let mut qi_p_ds = vec!((self.root(), D::distance(&p, &self.root().data)));
|
||||||
|
for i in (self.min_level..self.max_level+1).rev() {
|
||||||
|
let i_d = self.base.powf(F::from(i).unwrap());
|
||||||
|
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
||||||
|
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
|
||||||
|
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &(d_p_q + i_d)).collect();
|
||||||
|
}
|
||||||
|
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
||||||
|
qi_p_ds[..usize::min(qi_p_ds.len(), k)].iter().map(|(n, _)| n.index.index).collect()
|
||||||
|
}
|
||||||
|
|
||||||
fn split(&self, p_id: NodeId, r: F, s1: &mut Vec<T>, s2: Option<&mut Vec<T>>) -> (Vec<T>, Vec<T>){
|
fn split(&self, p_id: NodeId, r: F, s1: &mut Vec<T>, s2: Option<&mut Vec<T>>) -> (Vec<T>, Vec<T>){
|
||||||
|
|
||||||
let mut my_near = (Vec::new(), Vec::new());
|
let mut my_near = (Vec::new(), Vec::new());
|
||||||
@@ -102,7 +115,7 @@ where T: Debug
|
|||||||
let p = &self.nodes.get(p_id.index).unwrap().data;
|
let p = &self.nodes.get(p_id.index).unwrap().data;
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
while i != s.len() {
|
while i != s.len() {
|
||||||
let d = (self.distance)(p, &s[i]);
|
let d = D::distance(p, &s[i]);
|
||||||
if d <= r {
|
if d <= r {
|
||||||
my_near.0.push(s.remove(i));
|
my_near.0.push(s.remove(i));
|
||||||
} else if d > r && d <= F::two() * r{
|
} else if d > r && d <= F::two() * r{
|
||||||
@@ -156,7 +169,7 @@ where T: Debug
|
|||||||
|
|
||||||
let q: Vec<&Node<T>> = qi_p_ds.iter().flat_map(|(n, _)| self.get_child(n, i)).collect();
|
let q: Vec<&Node<T>> = qi_p_ds.iter().flat_map(|(n, _)| self.get_child(n, i)).collect();
|
||||||
|
|
||||||
children.extend(q.into_iter().map(|n| (n, (self.distance)(&n.data, &p))));
|
children.extend(q.into_iter().map(|n| (n, D::distance(&n.data, &p))));
|
||||||
|
|
||||||
children
|
children
|
||||||
|
|
||||||
@@ -180,7 +193,7 @@ where T: Debug
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn check_invariant(&self, invariant: fn(&CoverTree<T, F>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) {
|
fn check_invariant(&self, invariant: fn(&CoverTree<T, F, D>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) {
|
||||||
let mut current_nodes: Vec<&Node<T>> = Vec::new();
|
let mut current_nodes: Vec<&Node<T>> = Vec::new();
|
||||||
current_nodes.push(self.root());
|
current_nodes.push(self.root());
|
||||||
for i in (self.min_level..self.max_level+1).rev() {
|
for i in (self.min_level..self.max_level+1).rev() {
|
||||||
@@ -193,7 +206,7 @@ where T: Debug
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn nesting_invariant(_: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) {
|
fn nesting_invariant(_: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) {
|
||||||
let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n));
|
let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n));
|
||||||
let next_nodes_set: HashSet<&Node<T>> = HashSet::from_iter(next_nodes.into_iter().map(|n| *n));
|
let next_nodes_set: HashSet<&Node<T>> = HashSet::from_iter(next_nodes.into_iter().map(|n| *n));
|
||||||
for n in nodes_set.iter() {
|
for n in nodes_set.iter() {
|
||||||
@@ -202,11 +215,11 @@ where T: Debug
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn covering_tree(tree: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) {
|
fn covering_tree(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) {
|
||||||
let mut p_selected: Vec<&Node<T>> = Vec::new();
|
let mut p_selected: Vec<&Node<T>> = Vec::new();
|
||||||
for p in next_nodes {
|
for p in next_nodes {
|
||||||
for q in nodes {
|
for q in nodes {
|
||||||
if (tree.distance)(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) {
|
if D::distance(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) {
|
||||||
p_selected.push(*p);
|
p_selected.push(*p);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -216,11 +229,11 @@ where T: Debug
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn separation(tree: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
|
fn separation(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
|
||||||
for p in nodes {
|
for p in nodes {
|
||||||
for q in nodes {
|
for q in nodes {
|
||||||
if p != q {
|
if p != q {
|
||||||
assert!((tree.distance)(&p.data, &q.data) > tree.base.powf(F::from(i).unwrap()));
|
assert!(D::distance(&p.data, &q.data) > tree.base.powf(F::from(i).unwrap()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -228,28 +241,12 @@ where T: Debug
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T, F: FloatExt> KNNAlgorithm<T> for CoverTree<'a, T, F>
|
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||||
where T: Debug
|
|
||||||
{
|
|
||||||
fn find(&self, p: &T, k: usize) -> Vec<usize>{
|
|
||||||
let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data)));
|
|
||||||
for i in (self.min_level..self.max_level+1).rev() {
|
|
||||||
let i_d = self.base.powf(F::from(i).unwrap());
|
|
||||||
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
|
||||||
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
|
|
||||||
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &(d_p_q + i_d)).collect();
|
|
||||||
}
|
|
||||||
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
|
||||||
qi_p_ds[..usize::min(qi_p_ds.len(), k)].iter().map(|(n, _)| n.index.index).collect()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
|
||||||
pub struct NodeId {
|
pub struct NodeId {
|
||||||
index: usize,
|
index: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct Node<T> {
|
struct Node<T> {
|
||||||
index: NodeId,
|
index: NodeId,
|
||||||
data: T,
|
data: T,
|
||||||
@@ -280,13 +277,19 @@ mod tests {
|
|||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
struct SimpleDistance{}
|
||||||
|
|
||||||
|
impl Distance<i32, f64> for SimpleDistance {
|
||||||
|
fn distance(a: &i32, b: &i32) -> f64 {
|
||||||
|
(a - b).abs() as f64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cover_tree_test() {
|
fn cover_tree_test() {
|
||||||
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
|
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
|
||||||
let distance = |a: &i32, b: &i32| -> f64 {
|
|
||||||
(a - b).abs() as f64
|
let mut tree = CoverTree::new(data, SimpleDistance{});
|
||||||
};
|
|
||||||
let mut tree = CoverTree::<i32, f64>::new(data, &distance);
|
|
||||||
for d in vec!(10, 11, 12, 13, 14, 15, 16, 17, 18, 19) {
|
for d in vec!(10, 11, 12, 13, 14, 15, 16, 17, 18, 19) {
|
||||||
tree.insert(d);
|
tree.insert(d);
|
||||||
}
|
}
|
||||||
@@ -306,10 +309,8 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_invariants(){
|
fn test_invariants(){
|
||||||
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
|
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
|
||||||
let distance = |a: &i32, b: &i32| -> f64 {
|
|
||||||
(a - b).abs() as f64
|
let tree = CoverTree::new(data, SimpleDistance{});
|
||||||
};
|
|
||||||
let tree = CoverTree::<i32, f64>::new(data, &distance);
|
|
||||||
tree.check_invariant(CoverTree::nesting_invariant);
|
tree.check_invariant(CoverTree::nesting_invariant);
|
||||||
tree.check_invariant(CoverTree::covering_tree);
|
tree.check_invariant(CoverTree::covering_tree);
|
||||||
tree.check_invariant(CoverTree::separation);
|
tree.check_invariant(CoverTree::separation);
|
||||||
|
|||||||
@@ -1,16 +1,28 @@
|
|||||||
use crate::algorithm::neighbour::KNNAlgorithm;
|
|
||||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
|
||||||
use std::cmp::{Ordering, PartialOrd};
|
use std::cmp::{Ordering, PartialOrd};
|
||||||
use num_traits::Float;
|
use std::marker::PhantomData;
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
pub struct LinearKNNSearch<'a, T, F: Float> {
|
use crate::math::num::FloatExt;
|
||||||
distance: Box<dyn Fn(&T, &T) -> F + 'a>,
|
use crate::math::distance::Distance;
|
||||||
data: Vec<T>
|
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct LinearKNNSearch<T, F: FloatExt, D: Distance<T, F>> {
|
||||||
|
distance: D,
|
||||||
|
data: Vec<T>,
|
||||||
|
f: PhantomData<F>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T, F: Float> KNNAlgorithm<T> for LinearKNNSearch<'a, T, F>
|
impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||||
{
|
pub fn new(data: Vec<T>, distance: D) -> LinearKNNSearch<T, F, D>{
|
||||||
fn find(&self, from: &T, k: usize) -> Vec<usize> {
|
LinearKNNSearch{
|
||||||
|
data: data,
|
||||||
|
distance: distance,
|
||||||
|
f: PhantomData
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn find(&self, from: &T, k: usize) -> Vec<usize> {
|
||||||
if k < 1 || k > self.data.len() {
|
if k < 1 || k > self.data.len() {
|
||||||
panic!("k should be >= 1 and <= length(data)");
|
panic!("k should be >= 1 and <= length(data)");
|
||||||
}
|
}
|
||||||
@@ -19,14 +31,14 @@ impl<'a, T, F: Float> KNNAlgorithm<T> for LinearKNNSearch<'a, T, F>
|
|||||||
|
|
||||||
for _ in 0..k {
|
for _ in 0..k {
|
||||||
heap.add(KNNPoint{
|
heap.add(KNNPoint{
|
||||||
distance: Float::infinity(),
|
distance: F::infinity(),
|
||||||
index: None
|
index: None
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
for i in 0..self.data.len() {
|
for i in 0..self.data.len() {
|
||||||
|
|
||||||
let d = (self.distance)(&from, &self.data[i]);
|
let d = D::distance(&from, &self.data[i]);
|
||||||
let datum = heap.peek_mut();
|
let datum = heap.peek_mut();
|
||||||
if d < datum.distance {
|
if d < datum.distance {
|
||||||
datum.distance = d;
|
datum.distance = d;
|
||||||
@@ -41,43 +53,34 @@ impl<'a, T, F: Float> KNNAlgorithm<T> for LinearKNNSearch<'a, T, F>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T, F: Float> LinearKNNSearch<'a, T, F> {
|
|
||||||
pub fn new(data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> F) -> LinearKNNSearch<T, F>{
|
|
||||||
LinearKNNSearch{
|
|
||||||
data: data,
|
|
||||||
distance: Box::new(distance)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct KNNPoint<F: Float> {
|
struct KNNPoint<F: FloatExt> {
|
||||||
distance: F,
|
distance: F,
|
||||||
index: Option<usize>
|
index: Option<usize>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F: Float> PartialOrd for KNNPoint<F> {
|
impl<F: FloatExt> PartialOrd for KNNPoint<F> {
|
||||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||||
self.distance.partial_cmp(&other.distance)
|
self.distance.partial_cmp(&other.distance)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F: Float> PartialEq for KNNPoint<F> {
|
impl<F: FloatExt> PartialEq for KNNPoint<F> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.distance == other.distance
|
self.distance == other.distance
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F: Float> Eq for KNNPoint<F> {}
|
impl<F: FloatExt> Eq for KNNPoint<F> {}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::math::distance::euclidian;
|
use crate::math::distance::Distances;
|
||||||
|
|
||||||
struct SimpleDistance{}
|
struct SimpleDistance{}
|
||||||
|
|
||||||
impl SimpleDistance {
|
impl Distance<i32, f64> for SimpleDistance {
|
||||||
fn distance(a: &i32, b: &i32) -> f64 {
|
fn distance(a: &i32, b: &i32) -> f64 {
|
||||||
(a - b).abs() as f64
|
(a - b).abs() as f64
|
||||||
}
|
}
|
||||||
@@ -87,13 +90,13 @@ mod tests {
|
|||||||
fn knn_find() {
|
fn knn_find() {
|
||||||
let data1 = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
let data1 = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
||||||
|
|
||||||
let algorithm1 = LinearKNNSearch::new(data1, &SimpleDistance::distance);
|
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance{});
|
||||||
|
|
||||||
assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));
|
assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));
|
||||||
|
|
||||||
let data2 = vec!(vec![1., 1.], vec![2., 2.], vec![3., 3.], vec![4., 4.], vec![5., 5.]);
|
let data2 = vec!(vec![1., 1.], vec![2., 2.], vec![3., 3.], vec![4., 4.], vec![5., 5.]);
|
||||||
|
|
||||||
let algorithm2 = LinearKNNSearch::new(data2, &euclidian::distance);
|
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
|
||||||
|
|
||||||
assert_eq!(vec!(2, 3, 1), algorithm2.find(&vec![3., 3.], 3));
|
assert_eq!(vec!(2, 3, 1), algorithm2.find(&vec![3., 3.], 3));
|
||||||
}
|
}
|
||||||
@@ -116,7 +119,7 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let point_inf = KNNPoint{
|
let point_inf = KNNPoint{
|
||||||
distance: Float::infinity(),
|
distance: std::f64::INFINITY,
|
||||||
index: Some(3)
|
index: Some(3)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,3 @@
|
|||||||
pub mod cover_tree;
|
pub mod cover_tree;
|
||||||
pub mod linear_search;
|
pub mod linear_search;
|
||||||
pub mod bbd_tree;
|
pub mod bbd_tree;
|
||||||
|
|
||||||
pub enum KNNAlgorithmName {
|
|
||||||
CoverTree,
|
|
||||||
LinearSearch,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait KNNAlgorithm<T>{
|
|
||||||
fn find(&self, from: &T, k: usize) -> Vec<usize>;
|
|
||||||
}
|
|
||||||
@@ -8,7 +8,7 @@ use serde::{Serialize, Deserialize};
|
|||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::distance::euclidian;
|
use crate::math::distance::euclidian::*;
|
||||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
@@ -130,7 +130,7 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
|||||||
let mut best_cluster = 0;
|
let mut best_cluster = 0;
|
||||||
|
|
||||||
for j in 0..self.k {
|
for j in 0..self.k {
|
||||||
let dist = euclidian::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]);
|
let dist = Euclidian::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]);
|
||||||
if dist < min_dist {
|
if dist < min_dist {
|
||||||
min_dist = dist;
|
min_dist = dist;
|
||||||
best_cluster = j;
|
best_cluster = j;
|
||||||
@@ -156,7 +156,7 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
|||||||
// the distance from each sample to its closest center in scores.
|
// the distance from each sample to its closest center in scores.
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
// compute the distance between this sample and the current center
|
// compute the distance between this sample and the current center
|
||||||
let dist = euclidian::squared_distance(&data.get_row_as_vec(i), ¢roid);
|
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), ¢roid);
|
||||||
|
|
||||||
if dist < d[i] {
|
if dist < d[i] {
|
||||||
d[i] = dist;
|
d[i] = dist;
|
||||||
@@ -183,7 +183,7 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
|||||||
|
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
// compute the distance between this sample and the current center
|
// compute the distance between this sample and the current center
|
||||||
let dist = euclidian::squared_distance(&data.get_row_as_vec(i), ¢roid);
|
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), ¢roid);
|
||||||
|
|
||||||
if dist < d[i] {
|
if dist < d[i] {
|
||||||
d[i] = dist;
|
d[i] = dist;
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use crate::linalg::{Matrix};
|
use crate::linalg::{Matrix};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct PCA<T: FloatExt, M: Matrix<T>> {
|
pub struct PCA<T: FloatExt, M: Matrix<T>> {
|
||||||
eigenvectors: M,
|
eigenvectors: M,
|
||||||
eigenvalues: Vec<T>,
|
eigenvalues: Vec<T>,
|
||||||
@@ -11,6 +14,22 @@ pub struct PCA<T: FloatExt, M: Matrix<T>> {
|
|||||||
pmu: Vec<T>
|
pmu: Vec<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt, M: Matrix<T>> PartialEq for PCA<T, M> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
if self.eigenvectors != other.eigenvectors ||
|
||||||
|
self.eigenvalues.len() != other.eigenvalues.len() {
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
for i in 0..self.eigenvalues.len() {
|
||||||
|
if (self.eigenvalues[i] - other.eigenvalues[i]).abs() > T::epsilon() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct PCAParameters {
|
pub struct PCAParameters {
|
||||||
use_correlation_matrix: bool
|
use_correlation_matrix: bool
|
||||||
@@ -367,4 +386,36 @@ mod tests {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn serde() {
|
||||||
|
let iris = DenseMatrix::from_array(&[
|
||||||
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
|
&[4.7, 3.2, 1.3, 0.2],
|
||||||
|
&[4.6, 3.1, 1.5, 0.2],
|
||||||
|
&[5.0, 3.6, 1.4, 0.2],
|
||||||
|
&[5.4, 3.9, 1.7, 0.4],
|
||||||
|
&[4.6, 3.4, 1.4, 0.3],
|
||||||
|
&[5.0, 3.4, 1.5, 0.2],
|
||||||
|
&[4.4, 2.9, 1.4, 0.2],
|
||||||
|
&[4.9, 3.1, 1.5, 0.1],
|
||||||
|
&[7.0, 3.2, 4.7, 1.4],
|
||||||
|
&[6.4, 3.2, 4.5, 1.5],
|
||||||
|
&[6.9, 3.1, 4.9, 1.5],
|
||||||
|
&[5.5, 2.3, 4.0, 1.3],
|
||||||
|
&[6.5, 2.8, 4.6, 1.5],
|
||||||
|
&[5.7, 2.8, 4.5, 1.3],
|
||||||
|
&[6.3, 3.3, 4.7, 1.6],
|
||||||
|
&[4.9, 2.4, 3.3, 1.0],
|
||||||
|
&[6.6, 2.9, 4.6, 1.3],
|
||||||
|
&[5.2, 2.7, 3.9, 1.4]]);
|
||||||
|
|
||||||
|
let pca = PCA::new(&iris, 4, Default::default());
|
||||||
|
|
||||||
|
let deserialized_pca: PCA<f64, DenseMatrix<f64>> = serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(pca, deserialized_pca);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -4,12 +4,13 @@ use std::default::Default;
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max};
|
use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub struct RandomForestClassifierParameters {
|
pub struct RandomForestClassifierParameters {
|
||||||
pub criterion: SplitCriterion,
|
pub criterion: SplitCriterion,
|
||||||
pub max_depth: Option<u16>,
|
pub max_depth: Option<u16>,
|
||||||
@@ -19,13 +20,34 @@ pub struct RandomForestClassifierParameters {
|
|||||||
pub mtry: Option<usize>
|
pub mtry: Option<usize>
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct RandomForestClassifier<T: FloatExt> {
|
pub struct RandomForestClassifier<T: FloatExt> {
|
||||||
parameters: RandomForestClassifierParameters,
|
parameters: RandomForestClassifierParameters,
|
||||||
trees: Vec<DecisionTreeClassifier<T>>,
|
trees: Vec<DecisionTreeClassifier<T>>,
|
||||||
classes: Vec<T>
|
classes: Vec<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt> PartialEq for RandomForestClassifier<T> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
if self.classes.len() != other.classes.len() ||
|
||||||
|
self.trees.len() != other.trees.len() {
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
for i in 0..self.classes.len() {
|
||||||
|
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i in 0..self.trees.len() {
|
||||||
|
if self.trees[i] != other.trees[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Default for RandomForestClassifierParameters {
|
impl Default for RandomForestClassifierParameters {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
RandomForestClassifierParameters {
|
RandomForestClassifierParameters {
|
||||||
@@ -171,4 +193,37 @@ mod tests {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn serde() {
|
||||||
|
let x = DenseMatrix::from_array(&[
|
||||||
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
|
&[4.7, 3.2, 1.3, 0.2],
|
||||||
|
&[4.6, 3.1, 1.5, 0.2],
|
||||||
|
&[5.0, 3.6, 1.4, 0.2],
|
||||||
|
&[5.4, 3.9, 1.7, 0.4],
|
||||||
|
&[4.6, 3.4, 1.4, 0.3],
|
||||||
|
&[5.0, 3.4, 1.5, 0.2],
|
||||||
|
&[4.4, 2.9, 1.4, 0.2],
|
||||||
|
&[4.9, 3.1, 1.5, 0.1],
|
||||||
|
&[7.0, 3.2, 4.7, 1.4],
|
||||||
|
&[6.4, 3.2, 4.5, 1.5],
|
||||||
|
&[6.9, 3.1, 4.9, 1.5],
|
||||||
|
&[5.5, 2.3, 4.0, 1.3],
|
||||||
|
&[6.5, 2.8, 4.6, 1.5],
|
||||||
|
&[5.7, 2.8, 4.5, 1.3],
|
||||||
|
&[6.3, 3.3, 4.7, 1.6],
|
||||||
|
&[4.9, 2.4, 3.3, 1.0],
|
||||||
|
&[6.6, 2.9, 4.6, 1.3],
|
||||||
|
&[5.2, 2.7, 3.9, 1.4]]);
|
||||||
|
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
||||||
|
|
||||||
|
let forest = RandomForestClassifier::fit(&x, &y, Default::default());
|
||||||
|
|
||||||
|
let deserialized_forest: RandomForestClassifier<f64> = bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(forest, deserialized_forest);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -4,12 +4,13 @@ use std::default::Default;
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::tree::decision_tree_regressor::{DecisionTreeRegressor, DecisionTreeRegressorParameters};
|
use crate::tree::decision_tree_regressor::{DecisionTreeRegressor, DecisionTreeRegressorParameters};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub struct RandomForestRegressorParameters {
|
pub struct RandomForestRegressorParameters {
|
||||||
pub max_depth: Option<u16>,
|
pub max_depth: Option<u16>,
|
||||||
pub min_samples_leaf: usize,
|
pub min_samples_leaf: usize,
|
||||||
@@ -18,7 +19,7 @@ pub struct RandomForestRegressorParameters {
|
|||||||
pub mtry: Option<usize>
|
pub mtry: Option<usize>
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct RandomForestRegressor<T: FloatExt> {
|
pub struct RandomForestRegressor<T: FloatExt> {
|
||||||
parameters: RandomForestRegressorParameters,
|
parameters: RandomForestRegressorParameters,
|
||||||
trees: Vec<DecisionTreeRegressor<T>>
|
trees: Vec<DecisionTreeRegressor<T>>
|
||||||
@@ -36,6 +37,21 @@ impl Default for RandomForestRegressorParameters {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt> PartialEq for RandomForestRegressor<T> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
if self.trees.len() != other.trees.len() {
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
for i in 0..self.trees.len() {
|
||||||
|
if self.trees[i] != other.trees[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> RandomForestRegressor<T> {
|
impl<T: FloatExt> RandomForestRegressor<T> {
|
||||||
|
|
||||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: RandomForestRegressorParameters) -> RandomForestRegressor<T> {
|
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: RandomForestRegressorParameters) -> RandomForestRegressor<T> {
|
||||||
@@ -180,4 +196,33 @@ mod tests {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn serde() {
|
||||||
|
let x = DenseMatrix::from_array(&[
|
||||||
|
&[ 234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||||
|
&[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
|
&[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||||
|
&[ 284.599, 335.1, 165., 110.929, 1950., 61.187],
|
||||||
|
&[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||||
|
&[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639],
|
||||||
|
&[ 365.385, 187., 354.7, 115.094, 1953., 64.989],
|
||||||
|
&[ 363.112, 357.8, 335., 116.219, 1954., 63.761],
|
||||||
|
&[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||||
|
&[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857],
|
||||||
|
&[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||||
|
&[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513],
|
||||||
|
&[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||||
|
&[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||||
|
&[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||||
|
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
||||||
|
let y = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||||
|
|
||||||
|
let forest = RandomForestRegressor::fit(&x, &y, Default::default());
|
||||||
|
|
||||||
|
let deserialized_forest: RandomForestRegressor<f64> = bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(forest, deserialized_forest);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -21,7 +21,7 @@ pub struct LinearRegression<T: FloatExt, M: Matrix<T>> {
|
|||||||
impl<T: FloatExt, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
impl<T: FloatExt, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.coefficients == other.coefficients &&
|
self.coefficients == other.coefficients &&
|
||||||
self.intercept == other.intercept
|
(self.intercept - other.intercept).abs() <= T::epsilon()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -42,10 +42,19 @@ struct BinaryObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
|
|||||||
impl<T: FloatExt, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
impl<T: FloatExt, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
|
||||||
self.num_classes == other.num_classes &&
|
if self.num_classes != other.num_classes ||
|
||||||
self.classes == other.classes &&
|
self.num_attributes != other.num_attributes ||
|
||||||
self.num_attributes == other.num_attributes &&
|
self.classes.len() != other.classes.len() {
|
||||||
self.weights == other.weights
|
return false
|
||||||
|
} else {
|
||||||
|
for i in 0..self.classes.len() {
|
||||||
|
if (self.classes[i] - other.classes[i]).abs() > T::epsilon(){
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return self.weights == other.weights
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,20 +1,39 @@
|
|||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
pub fn distance<T: FloatExt>(x: &Vec<T>, y: &Vec<T>) -> T {
|
use super::Distance;
|
||||||
return squared_distance(x, y).sqrt();
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct Euclidian {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn squared_distance<T: FloatExt>(x: &Vec<T>,y: &Vec<T>) -> T {
|
impl Euclidian {
|
||||||
if x.len() != y.len() {
|
pub fn squared_distance<T: FloatExt>(x: &Vec<T>,y: &Vec<T>) -> T {
|
||||||
panic!("Input vector sizes are different.");
|
if x.len() != y.len() {
|
||||||
|
panic!("Input vector sizes are different.");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut sum = T::zero();
|
||||||
|
for i in 0..x.len() {
|
||||||
|
sum = sum + (x[i] - y[i]).powf(T::two());
|
||||||
|
}
|
||||||
|
|
||||||
|
sum
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut sum = T::zero();
|
pub fn distance<T: FloatExt>(x: &Vec<T>, y: &Vec<T>) -> T {
|
||||||
for i in 0..x.len() {
|
Euclidian::squared_distance(x, y).sqrt()
|
||||||
sum = sum + (x[i] - y[i]).powf(T::two());
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt> Distance<Vec<T>, T> for Euclidian {
|
||||||
|
|
||||||
|
fn distance(x: &Vec<T>, y: &Vec<T>) -> T {
|
||||||
|
Self::distance(x, y)
|
||||||
}
|
}
|
||||||
|
|
||||||
return sum;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -27,7 +46,7 @@ mod tests {
|
|||||||
let a = vec![1., 2., 3.];
|
let a = vec![1., 2., 3.];
|
||||||
let b = vec![4., 5., 6.];
|
let b = vec![4., 5., 6.];
|
||||||
|
|
||||||
let d_arr: f64 = distance(&a, &b);
|
let d_arr: f64 = Euclidian::distance(&a, &b);
|
||||||
|
|
||||||
assert!((d_arr - 5.19615242).abs() < 1e-8);
|
assert!((d_arr - 5.19615242).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1 +1,16 @@
|
|||||||
pub mod euclidian;
|
pub mod euclidian;
|
||||||
|
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
|
pub trait Distance<T, F: FloatExt>{
|
||||||
|
fn distance(a: &T, b: &T) -> F;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Distances{
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Distances {
|
||||||
|
pub fn euclidian() -> euclidian::Euclidian{
|
||||||
|
euclidian::Euclidian {}
|
||||||
|
}
|
||||||
|
}
|
||||||
+84
-14
@@ -1,19 +1,76 @@
|
|||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
|
use crate::math::distance::Distance;
|
||||||
use crate::linalg::{Matrix, row_iter};
|
use crate::linalg::{Matrix, row_iter};
|
||||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
|
||||||
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||||
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
||||||
|
|
||||||
pub struct KNNClassifier<'a, T: FloatExt> {
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct KNNClassifier<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||||
classes: Vec<T>,
|
classes: Vec<T>,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
knn_algorithm: Box<dyn KNNAlgorithm<Vec<T>> + 'a>,
|
knn_algorithm: KNNAlgorithmV<T, D>,
|
||||||
k: usize,
|
k: usize
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: FloatExt> KNNClassifier<'a, T> {
|
pub enum KNNAlgorithmName {
|
||||||
|
LinearSearch,
|
||||||
|
CoverTree
|
||||||
|
}
|
||||||
|
|
||||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, k: usize, distance: &'a dyn Fn(&Vec<T>, &Vec<T>) -> T, algorithm: KNNAlgorithmName) -> KNNClassifier<'a, T> {
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub enum KNNAlgorithmV<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||||
|
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||||
|
CoverTree(CoverTree<Vec<T>, T, D>)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl KNNAlgorithmName {
|
||||||
|
|
||||||
|
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(&self, data: Vec<Vec<T>>, distance: D) -> KNNAlgorithmV<T, D> {
|
||||||
|
match *self {
|
||||||
|
KNNAlgorithmName::LinearSearch => KNNAlgorithmV::LinearSearch(LinearKNNSearch::new(data, distance)),
|
||||||
|
KNNAlgorithmName::CoverTree => KNNAlgorithmV::CoverTree(CoverTree::new(data, distance)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithmV<T, D> {
|
||||||
|
fn find(&self, from: &Vec<T>, k: usize) -> Vec<usize>{
|
||||||
|
match *self {
|
||||||
|
KNNAlgorithmV::LinearSearch(ref linear) => linear.find(from, k),
|
||||||
|
KNNAlgorithmV::CoverTree(ref cover) => cover.find(from, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
if self.classes.len() != other.classes.len() ||
|
||||||
|
self.k != other.k ||
|
||||||
|
self.y.len() != other.y.len() {
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
for i in 0..self.classes.len() {
|
||||||
|
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i in 0..self.y.len() {
|
||||||
|
if self.y[i] != other.y[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||||
|
|
||||||
|
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, k: usize, distance: D, algorithm: KNNAlgorithmName) -> KNNClassifier<T, D> {
|
||||||
|
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
|
|
||||||
@@ -34,12 +91,7 @@ impl<'a, T: FloatExt> KNNClassifier<'a, T> {
|
|||||||
|
|
||||||
assert!(k > 1, format!("k should be > 1, k=[{}]", k));
|
assert!(k > 1, format!("k should be > 1, k=[{}]", k));
|
||||||
|
|
||||||
let knn_algorithm: Box<dyn KNNAlgorithm<Vec<T>> + 'a> = match algorithm {
|
KNNClassifier{classes:classes, y: yi, k: k, knn_algorithm: algorithm.fit(data, distance)}
|
||||||
KNNAlgorithmName::CoverTree => Box::new(CoverTree::<Vec<T>, T>::new(data, distance)),
|
|
||||||
KNNAlgorithmName::LinearSearch => Box::new(LinearKNNSearch::<Vec<T>, T>::new(data, distance))
|
|
||||||
};
|
|
||||||
|
|
||||||
KNNClassifier{classes:classes, y: yi, k: k, knn_algorithm: knn_algorithm}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,7 +126,7 @@ impl<'a, T: FloatExt> KNNClassifier<'a, T> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::math::distance::euclidian;
|
use crate::math::distance::Distances;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -86,9 +138,27 @@ mod tests {
|
|||||||
&[7., 8.],
|
&[7., 8.],
|
||||||
&[9., 10.]]);
|
&[9., 10.]]);
|
||||||
let y = vec![2., 2., 2., 3., 3.];
|
let y = vec![2., 2., 2., 3., 3.];
|
||||||
let knn = KNNClassifier::fit(&x, &y, 3, &euclidian::distance, KNNAlgorithmName::LinearSearch);
|
let knn = KNNClassifier::fit(&x, &y, 3, Distances::euclidian(), KNNAlgorithmName::LinearSearch);
|
||||||
let r = knn.predict(&x);
|
let r = knn.predict(&x);
|
||||||
assert_eq!(5, Vec::len(&r));
|
assert_eq!(5, Vec::len(&r));
|
||||||
assert_eq!(y.to_vec(), r);
|
assert_eq!(y.to_vec(), r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn serde() {
|
||||||
|
let x = DenseMatrix::from_array(&[
|
||||||
|
&[1., 2.],
|
||||||
|
&[3., 4.],
|
||||||
|
&[5., 6.],
|
||||||
|
&[7., 8.],
|
||||||
|
&[9., 10.]]);
|
||||||
|
let y = vec![2., 2., 2., 3., 3.];
|
||||||
|
|
||||||
|
let knn = KNNClassifier::fit(&x, &y, 3, Distances::euclidian(), KNNAlgorithmName::CoverTree);
|
||||||
|
|
||||||
|
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(knn, deserialized_knn);
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -3,11 +3,13 @@ use std::fmt::Debug;
|
|||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use std::collections::LinkedList;
|
use std::collections::LinkedList;
|
||||||
|
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct DecisionTreeClassifierParameters {
|
pub struct DecisionTreeClassifierParameters {
|
||||||
pub criterion: SplitCriterion,
|
pub criterion: SplitCriterion,
|
||||||
pub max_depth: Option<u16>,
|
pub max_depth: Option<u16>,
|
||||||
@@ -15,7 +17,7 @@ pub struct DecisionTreeClassifierParameters {
|
|||||||
pub min_samples_split: usize
|
pub min_samples_split: usize
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct DecisionTreeClassifier<T: FloatExt> {
|
pub struct DecisionTreeClassifier<T: FloatExt> {
|
||||||
nodes: Vec<Node<T>>,
|
nodes: Vec<Node<T>>,
|
||||||
parameters: DecisionTreeClassifierParameters,
|
parameters: DecisionTreeClassifierParameters,
|
||||||
@@ -24,24 +26,62 @@ pub struct DecisionTreeClassifier<T: FloatExt> {
|
|||||||
depth: u16
|
depth: u16
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub enum SplitCriterion {
|
pub enum SplitCriterion {
|
||||||
Gini,
|
Gini,
|
||||||
Entropy,
|
Entropy,
|
||||||
ClassificationError
|
ClassificationError
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct Node<T: FloatExt> {
|
pub struct Node<T: FloatExt> {
|
||||||
index: usize,
|
index: usize,
|
||||||
output: usize,
|
output: usize,
|
||||||
split_feature: usize,
|
split_feature: usize,
|
||||||
split_value: T,
|
split_value: Option<T>,
|
||||||
split_score: T,
|
split_score: Option<T>,
|
||||||
true_child: Option<usize>,
|
true_child: Option<usize>,
|
||||||
false_child: Option<usize>,
|
false_child: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt> PartialEq for DecisionTreeClassifier<T> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
if self.depth != other.depth ||
|
||||||
|
self.num_classes != other.num_classes ||
|
||||||
|
self.nodes.len() != other.nodes.len(){
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
for i in 0..self.classes.len() {
|
||||||
|
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i in 0..self.nodes.len() {
|
||||||
|
if self.nodes[i] != other.nodes[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt> PartialEq for Node<T> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.output == other.output &&
|
||||||
|
self.split_feature == other.split_feature &&
|
||||||
|
match (self.split_value, other.split_value) {
|
||||||
|
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
||||||
|
(None, None) => true,
|
||||||
|
_ => false,
|
||||||
|
} &&
|
||||||
|
match (self.split_score, other.split_score) {
|
||||||
|
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
||||||
|
(None, None) => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Default for DecisionTreeClassifierParameters {
|
impl Default for DecisionTreeClassifierParameters {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
@@ -60,8 +100,8 @@ impl<T: FloatExt> Node<T> {
|
|||||||
index: index,
|
index: index,
|
||||||
output: output,
|
output: output,
|
||||||
split_feature: 0,
|
split_feature: 0,
|
||||||
split_value: T::nan(),
|
split_value: Option::None,
|
||||||
split_score: T::nan(),
|
split_score: Option::None,
|
||||||
true_child: Option::None,
|
true_child: Option::None,
|
||||||
false_child: Option::None
|
false_child: Option::None
|
||||||
}
|
}
|
||||||
@@ -238,7 +278,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
|||||||
if node.true_child == None && node.false_child == None {
|
if node.true_child == None && node.false_child == None {
|
||||||
result = node.output;
|
result = node.output;
|
||||||
} else {
|
} else {
|
||||||
if x.get(row, node.split_feature) <= node.split_value {
|
if x.get(row, node.split_feature) <= node.split_value.unwrap_or(T::nan()) {
|
||||||
queue.push_back(node.true_child.unwrap());
|
queue.push_back(node.true_child.unwrap());
|
||||||
} else {
|
} else {
|
||||||
queue.push_back(node.false_child.unwrap());
|
queue.push_back(node.false_child.unwrap());
|
||||||
@@ -299,7 +339,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
|||||||
self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]);
|
self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]);
|
||||||
}
|
}
|
||||||
|
|
||||||
!self.nodes[visitor.node].split_score.is_nan()
|
self.nodes[visitor.node].split_score != Option::None
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -336,10 +376,10 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
|||||||
let false_label = which_max(false_count);
|
let false_label = which_max(false_count);
|
||||||
let gain = parent_impurity - T::from(tc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &true_count, tc) - T::from(fc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &false_count, fc);
|
let gain = parent_impurity - T::from(tc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &true_count, tc) - T::from(fc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &false_count, fc);
|
||||||
|
|
||||||
if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score {
|
if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() {
|
||||||
self.nodes[visitor.node].split_feature = j;
|
self.nodes[visitor.node].split_feature = j;
|
||||||
self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / T::two();
|
self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
|
||||||
self.nodes[visitor.node].split_score = gain;
|
self.nodes[visitor.node].split_score = Option::Some(gain);
|
||||||
visitor.true_child_output = true_label;
|
visitor.true_child_output = true_label;
|
||||||
visitor.false_child_output = false_label;
|
visitor.false_child_output = false_label;
|
||||||
}
|
}
|
||||||
@@ -360,7 +400,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
|||||||
|
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
if visitor.samples[i] > 0 {
|
if visitor.samples[i] > 0 {
|
||||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value {
|
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) {
|
||||||
true_samples[i] = visitor.samples[i];
|
true_samples[i] = visitor.samples[i];
|
||||||
tc += true_samples[i];
|
tc += true_samples[i];
|
||||||
visitor.samples[i] = 0;
|
visitor.samples[i] = 0;
|
||||||
@@ -372,8 +412,8 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
|||||||
|
|
||||||
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
|
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
|
||||||
self.nodes[visitor.node].split_feature = 0;
|
self.nodes[visitor.node].split_feature = 0;
|
||||||
self.nodes[visitor.node].split_value = T::nan();
|
self.nodes[visitor.node].split_value = Option::None;
|
||||||
self.nodes[visitor.node].split_score = T::nan();
|
self.nodes[visitor.node].split_score = Option::None;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -477,4 +517,37 @@ mod tests {
|
|||||||
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
|
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn serde() {
|
||||||
|
let x = DenseMatrix::from_array(&[
|
||||||
|
&[1.,1.,1.,0.],
|
||||||
|
&[1.,1.,1.,0.],
|
||||||
|
&[1.,1.,1.,1.],
|
||||||
|
&[1.,1.,0.,0.],
|
||||||
|
&[1.,1.,0.,1.],
|
||||||
|
&[1.,0.,1.,0.],
|
||||||
|
&[1.,0.,1.,0.],
|
||||||
|
&[1.,0.,1.,1.],
|
||||||
|
&[1.,0.,0.,0.],
|
||||||
|
&[1.,0.,0.,1.],
|
||||||
|
&[0.,1.,1.,0.],
|
||||||
|
&[0.,1.,1.,0.],
|
||||||
|
&[0.,1.,1.,1.],
|
||||||
|
&[0.,1.,0.,0.],
|
||||||
|
&[0.,1.,0.,1.],
|
||||||
|
&[0.,0.,1.,0.],
|
||||||
|
&[0.,0.,1.,0.],
|
||||||
|
&[0.,0.,1.,1.],
|
||||||
|
&[0.,0.,0.,0.],
|
||||||
|
&[0.,0.,0.,1.]]);
|
||||||
|
let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
|
||||||
|
|
||||||
|
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
|
||||||
|
|
||||||
|
let deserialized_tree: DecisionTreeClassifier<f64> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(tree, deserialized_tree);
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -2,31 +2,33 @@ use std::default::Default;
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::collections::LinkedList;
|
use std::collections::LinkedList;
|
||||||
|
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct DecisionTreeRegressorParameters {
|
pub struct DecisionTreeRegressorParameters {
|
||||||
pub max_depth: Option<u16>,
|
pub max_depth: Option<u16>,
|
||||||
pub min_samples_leaf: usize,
|
pub min_samples_leaf: usize,
|
||||||
pub min_samples_split: usize
|
pub min_samples_split: usize
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct DecisionTreeRegressor<T: FloatExt> {
|
pub struct DecisionTreeRegressor<T: FloatExt> {
|
||||||
nodes: Vec<Node<T>>,
|
nodes: Vec<Node<T>>,
|
||||||
parameters: DecisionTreeRegressorParameters,
|
parameters: DecisionTreeRegressorParameters,
|
||||||
depth: u16
|
depth: u16
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct Node<T: FloatExt> {
|
pub struct Node<T: FloatExt> {
|
||||||
index: usize,
|
index: usize,
|
||||||
output: T,
|
output: T,
|
||||||
split_feature: usize,
|
split_feature: usize,
|
||||||
split_value: T,
|
split_value: Option<T>,
|
||||||
split_score: T,
|
split_score: Option<T>,
|
||||||
true_child: Option<usize>,
|
true_child: Option<usize>,
|
||||||
false_child: Option<usize>,
|
false_child: Option<usize>,
|
||||||
}
|
}
|
||||||
@@ -48,14 +50,46 @@ impl<T: FloatExt> Node<T> {
|
|||||||
index: index,
|
index: index,
|
||||||
output: output,
|
output: output,
|
||||||
split_feature: 0,
|
split_feature: 0,
|
||||||
split_value: T::nan(),
|
split_value: Option::None,
|
||||||
split_score: T::nan(),
|
split_score: Option::None,
|
||||||
true_child: Option::None,
|
true_child: Option::None,
|
||||||
false_child: Option::None
|
false_child: Option::None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt> PartialEq for Node<T> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
(self.output - other.output).abs() < T::epsilon() &&
|
||||||
|
self.split_feature == other.split_feature &&
|
||||||
|
match (self.split_value, other.split_value) {
|
||||||
|
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
||||||
|
(None, None) => true,
|
||||||
|
_ => false,
|
||||||
|
} &&
|
||||||
|
match (self.split_score, other.split_score) {
|
||||||
|
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
||||||
|
(None, None) => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt> PartialEq for DecisionTreeRegressor<T> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
if self.depth != other.depth || self.nodes.len() != other.nodes.len(){
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
for i in 0..self.nodes.len() {
|
||||||
|
if self.nodes[i] != other.nodes[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
|
struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
|
||||||
x: &'a M,
|
x: &'a M,
|
||||||
y: &'a M,
|
y: &'a M,
|
||||||
@@ -169,7 +203,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
|||||||
if node.true_child == None && node.false_child == None {
|
if node.true_child == None && node.false_child == None {
|
||||||
result = node.output;
|
result = node.output;
|
||||||
} else {
|
} else {
|
||||||
if x.get(row, node.split_feature) <= node.split_value {
|
if x.get(row, node.split_feature) <= node.split_value.unwrap_or(T::nan()) {
|
||||||
queue.push_back(node.true_child.unwrap());
|
queue.push_back(node.true_child.unwrap());
|
||||||
} else {
|
} else {
|
||||||
queue.push_back(node.false_child.unwrap());
|
queue.push_back(node.false_child.unwrap());
|
||||||
@@ -207,7 +241,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
|||||||
self.find_best_split(visitor, n, sum, parent_gain, variables[j]);
|
self.find_best_split(visitor, n, sum, parent_gain, variables[j]);
|
||||||
}
|
}
|
||||||
|
|
||||||
!self.nodes[visitor.node].split_score.is_nan()
|
self.nodes[visitor.node].split_score != Option::None
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,10 +274,10 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
|||||||
|
|
||||||
let gain = (T::from(true_count).unwrap() * true_mean * true_mean + T::from(false_count).unwrap() * false_mean * false_mean) - parent_gain;
|
let gain = (T::from(true_count).unwrap() * true_mean * true_mean + T::from(false_count).unwrap() * false_mean * false_mean) - parent_gain;
|
||||||
|
|
||||||
if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score {
|
if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() {
|
||||||
self.nodes[visitor.node].split_feature = j;
|
self.nodes[visitor.node].split_feature = j;
|
||||||
self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / T::two();
|
self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
|
||||||
self.nodes[visitor.node].split_score = gain;
|
self.nodes[visitor.node].split_score = Option::Some(gain);
|
||||||
visitor.true_child_output = true_mean;
|
visitor.true_child_output = true_mean;
|
||||||
visitor.false_child_output = false_mean;
|
visitor.false_child_output = false_mean;
|
||||||
}
|
}
|
||||||
@@ -264,7 +298,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
|||||||
|
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
if visitor.samples[i] > 0 {
|
if visitor.samples[i] > 0 {
|
||||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value {
|
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) {
|
||||||
true_samples[i] = visitor.samples[i];
|
true_samples[i] = visitor.samples[i];
|
||||||
tc += true_samples[i];
|
tc += true_samples[i];
|
||||||
visitor.samples[i] = 0;
|
visitor.samples[i] = 0;
|
||||||
@@ -276,8 +310,8 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
|||||||
|
|
||||||
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
|
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
|
||||||
self.nodes[visitor.node].split_feature = 0;
|
self.nodes[visitor.node].split_feature = 0;
|
||||||
self.nodes[visitor.node].split_value = T::nan();
|
self.nodes[visitor.node].split_value = Option::None;
|
||||||
self.nodes[visitor.node].split_score = T::nan();
|
self.nodes[visitor.node].split_score = Option::None;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -357,4 +391,33 @@ mod tests {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn serde() {
|
||||||
|
let x = DenseMatrix::from_array(&[
|
||||||
|
&[ 234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||||
|
&[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
|
&[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||||
|
&[ 284.599, 335.1, 165., 110.929, 1950., 61.187],
|
||||||
|
&[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||||
|
&[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639],
|
||||||
|
&[ 365.385, 187., 354.7, 115.094, 1953., 64.989],
|
||||||
|
&[ 363.112, 357.8, 335., 116.219, 1954., 63.761],
|
||||||
|
&[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||||
|
&[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857],
|
||||||
|
&[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||||
|
&[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513],
|
||||||
|
&[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||||
|
&[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||||
|
&[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||||
|
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
||||||
|
let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||||
|
|
||||||
|
let tree = DecisionTreeRegressor::fit(&x, &y, Default::default());
|
||||||
|
|
||||||
|
let deserialized_tree: DecisionTreeRegressor<f64> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(tree, deserialized_tree);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user