feat: adds serialization/deserialization methods
This commit is contained in:
@@ -2,7 +2,7 @@ use std::fmt::Debug;
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian;
|
||||
use crate::math::distance::euclidian::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BBDTree<T: FloatExt> {
|
||||
@@ -79,10 +79,10 @@ impl<T: FloatExt> BBDTree<T> {
|
||||
let d = centroids[0].len();
|
||||
|
||||
// Determine which mean the node mean is closest to
|
||||
let mut min_dist = euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[0]]);
|
||||
let mut min_dist = Euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[0]]);
|
||||
let mut closest = candidates[0];
|
||||
for i in 1..k {
|
||||
let dist = euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[i]]);
|
||||
let dist = Euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[i]]);
|
||||
if dist < min_dist {
|
||||
min_dist = dist;
|
||||
closest = candidates[i];
|
||||
|
||||
@@ -3,25 +3,26 @@ use std::iter::FromIterator;
|
||||
use std::fmt::Debug;
|
||||
use core::hash::{Hash, Hasher};
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::algorithm::neighbour::KNNAlgorithm;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||
|
||||
pub struct CoverTree<'a, T, F: FloatExt>
|
||||
where T: Debug
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct CoverTree<T, F: FloatExt, D: Distance<T, F>>
|
||||
{
|
||||
base: F,
|
||||
max_level: i8,
|
||||
min_level: i8,
|
||||
distance: &'a dyn Fn(&T, &T) -> F,
|
||||
distance: D,
|
||||
nodes: Vec<Node<T>>
|
||||
}
|
||||
|
||||
impl<'a, T, F: FloatExt> CoverTree<'a, T, F>
|
||||
where T: Debug
|
||||
impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
{
|
||||
|
||||
pub fn new(mut data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> F) -> CoverTree<T, F> {
|
||||
pub fn new(mut data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
|
||||
let mut tree = CoverTree {
|
||||
base: F::two(),
|
||||
max_level: 100,
|
||||
@@ -43,7 +44,7 @@ where T: Debug
|
||||
} else {
|
||||
let mut parent: Option<NodeId> = Option::None;
|
||||
let mut p_i = 0;
|
||||
let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data)));
|
||||
let mut qi_p_ds = vec!((self.root(), D::distance(&p, &self.root().data)));
|
||||
let mut i = self.max_level;
|
||||
loop {
|
||||
let i_d = self.base.powf(F::from(i).unwrap());
|
||||
@@ -82,6 +83,18 @@ where T: Debug
|
||||
node_id
|
||||
}
|
||||
|
||||
pub fn find(&self, p: &T, k: usize) -> Vec<usize>{
|
||||
let mut qi_p_ds = vec!((self.root(), D::distance(&p, &self.root().data)));
|
||||
for i in (self.min_level..self.max_level+1).rev() {
|
||||
let i_d = self.base.powf(F::from(i).unwrap());
|
||||
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
||||
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
|
||||
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &(d_p_q + i_d)).collect();
|
||||
}
|
||||
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
||||
qi_p_ds[..usize::min(qi_p_ds.len(), k)].iter().map(|(n, _)| n.index.index).collect()
|
||||
}
|
||||
|
||||
fn split(&self, p_id: NodeId, r: F, s1: &mut Vec<T>, s2: Option<&mut Vec<T>>) -> (Vec<T>, Vec<T>){
|
||||
|
||||
let mut my_near = (Vec::new(), Vec::new());
|
||||
@@ -102,7 +115,7 @@ where T: Debug
|
||||
let p = &self.nodes.get(p_id.index).unwrap().data;
|
||||
let mut i = 0;
|
||||
while i != s.len() {
|
||||
let d = (self.distance)(p, &s[i]);
|
||||
let d = D::distance(p, &s[i]);
|
||||
if d <= r {
|
||||
my_near.0.push(s.remove(i));
|
||||
} else if d > r && d <= F::two() * r{
|
||||
@@ -156,7 +169,7 @@ where T: Debug
|
||||
|
||||
let q: Vec<&Node<T>> = qi_p_ds.iter().flat_map(|(n, _)| self.get_child(n, i)).collect();
|
||||
|
||||
children.extend(q.into_iter().map(|n| (n, (self.distance)(&n.data, &p))));
|
||||
children.extend(q.into_iter().map(|n| (n, D::distance(&n.data, &p))));
|
||||
|
||||
children
|
||||
|
||||
@@ -180,7 +193,7 @@ where T: Debug
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn check_invariant(&self, invariant: fn(&CoverTree<T, F>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) {
|
||||
fn check_invariant(&self, invariant: fn(&CoverTree<T, F, D>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) {
|
||||
let mut current_nodes: Vec<&Node<T>> = Vec::new();
|
||||
current_nodes.push(self.root());
|
||||
for i in (self.min_level..self.max_level+1).rev() {
|
||||
@@ -193,7 +206,7 @@ where T: Debug
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn nesting_invariant(_: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) {
|
||||
fn nesting_invariant(_: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) {
|
||||
let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n));
|
||||
let next_nodes_set: HashSet<&Node<T>> = HashSet::from_iter(next_nodes.into_iter().map(|n| *n));
|
||||
for n in nodes_set.iter() {
|
||||
@@ -202,11 +215,11 @@ where T: Debug
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn covering_tree(tree: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) {
|
||||
fn covering_tree(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) {
|
||||
let mut p_selected: Vec<&Node<T>> = Vec::new();
|
||||
for p in next_nodes {
|
||||
for q in nodes {
|
||||
if (tree.distance)(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) {
|
||||
if D::distance(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) {
|
||||
p_selected.push(*p);
|
||||
}
|
||||
}
|
||||
@@ -216,11 +229,11 @@ where T: Debug
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn separation(tree: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
|
||||
fn separation(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
|
||||
for p in nodes {
|
||||
for q in nodes {
|
||||
if p != q {
|
||||
assert!((tree.distance)(&p.data, &q.data) > tree.base.powf(F::from(i).unwrap()));
|
||||
assert!(D::distance(&p.data, &q.data) > tree.base.powf(F::from(i).unwrap()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -228,28 +241,12 @@ where T: Debug
|
||||
|
||||
}
|
||||
|
||||
impl<'a, T, F: FloatExt> KNNAlgorithm<T> for CoverTree<'a, T, F>
|
||||
where T: Debug
|
||||
{
|
||||
fn find(&self, p: &T, k: usize) -> Vec<usize>{
|
||||
let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data)));
|
||||
for i in (self.min_level..self.max_level+1).rev() {
|
||||
let i_d = self.base.powf(F::from(i).unwrap());
|
||||
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
||||
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
|
||||
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &(d_p_q + i_d)).collect();
|
||||
}
|
||||
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
||||
qi_p_ds[..usize::min(qi_p_ds.len(), k)].iter().map(|(n, _)| n.index.index).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||
pub struct NodeId {
|
||||
index: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Node<T> {
|
||||
index: NodeId,
|
||||
data: T,
|
||||
@@ -280,13 +277,19 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
struct SimpleDistance{}
|
||||
|
||||
impl Distance<i32, f64> for SimpleDistance {
|
||||
fn distance(a: &i32, b: &i32) -> f64 {
|
||||
(a - b).abs() as f64
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cover_tree_test() {
|
||||
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
|
||||
let distance = |a: &i32, b: &i32| -> f64 {
|
||||
(a - b).abs() as f64
|
||||
};
|
||||
let mut tree = CoverTree::<i32, f64>::new(data, &distance);
|
||||
|
||||
let mut tree = CoverTree::new(data, SimpleDistance{});
|
||||
for d in vec!(10, 11, 12, 13, 14, 15, 16, 17, 18, 19) {
|
||||
tree.insert(d);
|
||||
}
|
||||
@@ -306,10 +309,8 @@ mod tests {
|
||||
#[test]
|
||||
fn test_invariants(){
|
||||
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
|
||||
let distance = |a: &i32, b: &i32| -> f64 {
|
||||
(a - b).abs() as f64
|
||||
};
|
||||
let tree = CoverTree::<i32, f64>::new(data, &distance);
|
||||
|
||||
let tree = CoverTree::new(data, SimpleDistance{});
|
||||
tree.check_invariant(CoverTree::nesting_invariant);
|
||||
tree.check_invariant(CoverTree::covering_tree);
|
||||
tree.check_invariant(CoverTree::separation);
|
||||
|
||||
@@ -1,16 +1,28 @@
|
||||
use crate::algorithm::neighbour::KNNAlgorithm;
|
||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||
use std::cmp::{Ordering, PartialOrd};
|
||||
use num_traits::Float;
|
||||
use std::marker::PhantomData;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
pub struct LinearKNNSearch<'a, T, F: Float> {
|
||||
distance: Box<dyn Fn(&T, &T) -> F + 'a>,
|
||||
data: Vec<T>
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct LinearKNNSearch<T, F: FloatExt, D: Distance<T, F>> {
|
||||
distance: D,
|
||||
data: Vec<T>,
|
||||
f: PhantomData<F>
|
||||
}
|
||||
|
||||
impl<'a, T, F: Float> KNNAlgorithm<T> for LinearKNNSearch<'a, T, F>
|
||||
{
|
||||
fn find(&self, from: &T, k: usize) -> Vec<usize> {
|
||||
impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
pub fn new(data: Vec<T>, distance: D) -> LinearKNNSearch<T, F, D>{
|
||||
LinearKNNSearch{
|
||||
data: data,
|
||||
distance: distance,
|
||||
f: PhantomData
|
||||
}
|
||||
}
|
||||
|
||||
pub fn find(&self, from: &T, k: usize) -> Vec<usize> {
|
||||
if k < 1 || k > self.data.len() {
|
||||
panic!("k should be >= 1 and <= length(data)");
|
||||
}
|
||||
@@ -19,14 +31,14 @@ impl<'a, T, F: Float> KNNAlgorithm<T> for LinearKNNSearch<'a, T, F>
|
||||
|
||||
for _ in 0..k {
|
||||
heap.add(KNNPoint{
|
||||
distance: Float::infinity(),
|
||||
distance: F::infinity(),
|
||||
index: None
|
||||
});
|
||||
}
|
||||
|
||||
for i in 0..self.data.len() {
|
||||
|
||||
let d = (self.distance)(&from, &self.data[i]);
|
||||
let d = D::distance(&from, &self.data[i]);
|
||||
let datum = heap.peek_mut();
|
||||
if d < datum.distance {
|
||||
datum.distance = d;
|
||||
@@ -41,43 +53,34 @@ impl<'a, T, F: Float> KNNAlgorithm<T> for LinearKNNSearch<'a, T, F>
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T, F: Float> LinearKNNSearch<'a, T, F> {
|
||||
pub fn new(data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> F) -> LinearKNNSearch<T, F>{
|
||||
LinearKNNSearch{
|
||||
data: data,
|
||||
distance: Box::new(distance)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct KNNPoint<F: Float> {
|
||||
struct KNNPoint<F: FloatExt> {
|
||||
distance: F,
|
||||
index: Option<usize>
|
||||
}
|
||||
|
||||
impl<F: Float> PartialOrd for KNNPoint<F> {
|
||||
impl<F: FloatExt> PartialOrd for KNNPoint<F> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.distance.partial_cmp(&other.distance)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Float> PartialEq for KNNPoint<F> {
|
||||
impl<F: FloatExt> PartialEq for KNNPoint<F> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.distance == other.distance
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Float> Eq for KNNPoint<F> {}
|
||||
impl<F: FloatExt> Eq for KNNPoint<F> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::math::distance::euclidian;
|
||||
use crate::math::distance::Distances;
|
||||
|
||||
struct SimpleDistance{}
|
||||
|
||||
impl SimpleDistance {
|
||||
impl Distance<i32, f64> for SimpleDistance {
|
||||
fn distance(a: &i32, b: &i32) -> f64 {
|
||||
(a - b).abs() as f64
|
||||
}
|
||||
@@ -87,13 +90,13 @@ mod tests {
|
||||
fn knn_find() {
|
||||
let data1 = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
||||
|
||||
let algorithm1 = LinearKNNSearch::new(data1, &SimpleDistance::distance);
|
||||
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance{});
|
||||
|
||||
assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));
|
||||
|
||||
let data2 = vec!(vec![1., 1.], vec![2., 2.], vec![3., 3.], vec![4., 4.], vec![5., 5.]);
|
||||
|
||||
let algorithm2 = LinearKNNSearch::new(data2, &euclidian::distance);
|
||||
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
|
||||
|
||||
assert_eq!(vec!(2, 3, 1), algorithm2.find(&vec![3., 3.], 3));
|
||||
}
|
||||
@@ -116,7 +119,7 @@ mod tests {
|
||||
};
|
||||
|
||||
let point_inf = KNNPoint{
|
||||
distance: Float::infinity(),
|
||||
distance: std::f64::INFINITY,
|
||||
index: Some(3)
|
||||
};
|
||||
|
||||
|
||||
@@ -1,12 +1,3 @@
|
||||
pub mod cover_tree;
|
||||
pub mod linear_search;
|
||||
pub mod bbd_tree;
|
||||
|
||||
pub enum KNNAlgorithmName {
|
||||
CoverTree,
|
||||
LinearSearch,
|
||||
}
|
||||
|
||||
pub trait KNNAlgorithm<T>{
|
||||
fn find(&self, from: &T, k: usize) -> Vec<usize>;
|
||||
}
|
||||
pub mod bbd_tree;
|
||||
@@ -8,7 +8,7 @@ use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian;
|
||||
use crate::math::distance::euclidian::*;
|
||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
@@ -130,7 +130,7 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
||||
let mut best_cluster = 0;
|
||||
|
||||
for j in 0..self.k {
|
||||
let dist = euclidian::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]);
|
||||
let dist = Euclidian::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]);
|
||||
if dist < min_dist {
|
||||
min_dist = dist;
|
||||
best_cluster = j;
|
||||
@@ -156,7 +156,7 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
||||
// the distance from each sample to its closest center in scores.
|
||||
for i in 0..n {
|
||||
// compute the distance between this sample and the current center
|
||||
let dist = euclidian::squared_distance(&data.get_row_as_vec(i), ¢roid);
|
||||
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), ¢roid);
|
||||
|
||||
if dist < d[i] {
|
||||
d[i] = dist;
|
||||
@@ -183,7 +183,7 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
||||
|
||||
for i in 0..n {
|
||||
// compute the distance between this sample and the current center
|
||||
let dist = euclidian::squared_distance(&data.get_row_as_vec(i), ¢roid);
|
||||
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), ¢roid);
|
||||
|
||||
if dist < d[i] {
|
||||
d[i] = dist;
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::{Matrix};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct PCA<T: FloatExt, M: Matrix<T>> {
|
||||
eigenvectors: M,
|
||||
eigenvalues: Vec<T>,
|
||||
@@ -11,6 +14,22 @@ pub struct PCA<T: FloatExt, M: Matrix<T>> {
|
||||
pmu: Vec<T>
|
||||
}
|
||||
|
||||
impl<T: FloatExt, M: Matrix<T>> PartialEq for PCA<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.eigenvectors != other.eigenvectors ||
|
||||
self.eigenvalues.len() != other.eigenvalues.len() {
|
||||
return false
|
||||
} else {
|
||||
for i in 0..self.eigenvalues.len() {
|
||||
if (self.eigenvalues[i] - other.eigenvalues[i]).abs() > T::epsilon() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PCAParameters {
|
||||
use_correlation_matrix: bool
|
||||
@@ -366,5 +385,37 @@ mod tests {
|
||||
assert!(us_arrests_t.abs().approximate_eq(&expected_projection.abs(), 1e-4));
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde() {
|
||||
let iris = DenseMatrix::from_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4]]);
|
||||
|
||||
let pca = PCA::new(&iris, 4, Default::default());
|
||||
|
||||
let deserialized_pca: PCA<f64, DenseMatrix<f64>> = serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(pca, deserialized_pca);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -4,12 +4,13 @@ use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::Rng;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct RandomForestClassifierParameters {
|
||||
pub criterion: SplitCriterion,
|
||||
pub max_depth: Option<u16>,
|
||||
@@ -19,13 +20,34 @@ pub struct RandomForestClassifierParameters {
|
||||
pub mtry: Option<usize>
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct RandomForestClassifier<T: FloatExt> {
|
||||
parameters: RandomForestClassifierParameters,
|
||||
trees: Vec<DecisionTreeClassifier<T>>,
|
||||
classes: Vec<T>
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for RandomForestClassifier<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.classes.len() != other.classes.len() ||
|
||||
self.trees.len() != other.trees.len() {
|
||||
return false
|
||||
} else {
|
||||
for i in 0..self.classes.len() {
|
||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for i in 0..self.trees.len() {
|
||||
if self.trees[i] != other.trees[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RandomForestClassifierParameters {
|
||||
fn default() -> Self {
|
||||
RandomForestClassifierParameters {
|
||||
@@ -171,4 +193,37 @@ mod tests {
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4]]);
|
||||
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
||||
|
||||
let forest = RandomForestClassifier::fit(&x, &y, Default::default());
|
||||
|
||||
let deserialized_forest: RandomForestClassifier<f64> = bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(forest, deserialized_forest);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -4,12 +4,13 @@ use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::Rng;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::tree::decision_tree_regressor::{DecisionTreeRegressor, DecisionTreeRegressorParameters};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct RandomForestRegressorParameters {
|
||||
pub max_depth: Option<u16>,
|
||||
pub min_samples_leaf: usize,
|
||||
@@ -18,7 +19,7 @@ pub struct RandomForestRegressorParameters {
|
||||
pub mtry: Option<usize>
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct RandomForestRegressor<T: FloatExt> {
|
||||
parameters: RandomForestRegressorParameters,
|
||||
trees: Vec<DecisionTreeRegressor<T>>
|
||||
@@ -36,6 +37,21 @@ impl Default for RandomForestRegressorParameters {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for RandomForestRegressor<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.trees.len() != other.trees.len() {
|
||||
return false
|
||||
} else {
|
||||
for i in 0..self.trees.len() {
|
||||
if self.trees[i] != other.trees[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt> RandomForestRegressor<T> {
|
||||
|
||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: RandomForestRegressorParameters) -> RandomForestRegressor<T> {
|
||||
@@ -180,4 +196,33 @@ mod tests {
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[ 234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||
&[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[ 284.599, 335.1, 165., 110.929, 1950., 61.187],
|
||||
&[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639],
|
||||
&[ 365.385, 187., 354.7, 115.094, 1953., 64.989],
|
||||
&[ 363.112, 357.8, 335., 116.219, 1954., 63.761],
|
||||
&[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513],
|
||||
&[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
||||
let y = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||
|
||||
let forest = RandomForestRegressor::fit(&x, &y, Default::default());
|
||||
|
||||
let deserialized_forest: RandomForestRegressor<f64> = bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(forest, deserialized_forest);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -21,7 +21,7 @@ pub struct LinearRegression<T: FloatExt, M: Matrix<T>> {
|
||||
impl<T: FloatExt, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.coefficients == other.coefficients &&
|
||||
self.intercept == other.intercept
|
||||
(self.intercept - other.intercept).abs() <= T::epsilon()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -42,10 +42,19 @@ struct BinaryObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
|
||||
impl<T: FloatExt, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
|
||||
self.num_classes == other.num_classes &&
|
||||
self.classes == other.classes &&
|
||||
self.num_attributes == other.num_attributes &&
|
||||
self.weights == other.weights
|
||||
if self.num_classes != other.num_classes ||
|
||||
self.num_attributes != other.num_attributes ||
|
||||
self.classes.len() != other.classes.len() {
|
||||
return false
|
||||
} else {
|
||||
for i in 0..self.classes.len() {
|
||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon(){
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return self.weights == other.weights
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +1,39 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
|
||||
pub fn distance<T: FloatExt>(x: &Vec<T>, y: &Vec<T>) -> T {
|
||||
return squared_distance(x, y).sqrt();
|
||||
use super::Distance;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Euclidian {
|
||||
}
|
||||
|
||||
pub fn squared_distance<T: FloatExt>(x: &Vec<T>,y: &Vec<T>) -> T {
|
||||
if x.len() != y.len() {
|
||||
panic!("Input vector sizes are different.");
|
||||
impl Euclidian {
|
||||
pub fn squared_distance<T: FloatExt>(x: &Vec<T>,y: &Vec<T>) -> T {
|
||||
if x.len() != y.len() {
|
||||
panic!("Input vector sizes are different.");
|
||||
}
|
||||
|
||||
let mut sum = T::zero();
|
||||
for i in 0..x.len() {
|
||||
sum = sum + (x[i] - y[i]).powf(T::two());
|
||||
}
|
||||
|
||||
sum
|
||||
}
|
||||
|
||||
let mut sum = T::zero();
|
||||
for i in 0..x.len() {
|
||||
sum = sum + (x[i] - y[i]).powf(T::two());
|
||||
pub fn distance<T: FloatExt>(x: &Vec<T>, y: &Vec<T>) -> T {
|
||||
Euclidian::squared_distance(x, y).sqrt()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
impl<T: FloatExt> Distance<Vec<T>, T> for Euclidian {
|
||||
|
||||
fn distance(x: &Vec<T>, y: &Vec<T>) -> T {
|
||||
Self::distance(x, y)
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
|
||||
@@ -27,7 +46,7 @@ mod tests {
|
||||
let a = vec![1., 2., 3.];
|
||||
let b = vec![4., 5., 6.];
|
||||
|
||||
let d_arr: f64 = distance(&a, &b);
|
||||
let d_arr: f64 = Euclidian::distance(&a, &b);
|
||||
|
||||
assert!((d_arr - 5.19615242).abs() < 1e-8);
|
||||
}
|
||||
|
||||
@@ -1 +1,16 @@
|
||||
pub mod euclidian;
|
||||
pub mod euclidian;
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
|
||||
pub trait Distance<T, F: FloatExt>{
|
||||
fn distance(a: &T, b: &T) -> F;
|
||||
}
|
||||
|
||||
pub struct Distances{
|
||||
}
|
||||
|
||||
impl Distances {
|
||||
pub fn euclidian() -> euclidian::Euclidian{
|
||||
euclidian::Euclidian {}
|
||||
}
|
||||
}
|
||||
+88
-18
@@ -1,26 +1,83 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::linalg::{Matrix, row_iter};
|
||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
||||
|
||||
pub struct KNNClassifier<'a, T: FloatExt> {
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct KNNClassifier<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||
classes: Vec<T>,
|
||||
y: Vec<usize>,
|
||||
knn_algorithm: Box<dyn KNNAlgorithm<Vec<T>> + 'a>,
|
||||
k: usize,
|
||||
knn_algorithm: KNNAlgorithmV<T, D>,
|
||||
k: usize
|
||||
}
|
||||
|
||||
impl<'a, T: FloatExt> KNNClassifier<'a, T> {
|
||||
pub enum KNNAlgorithmName {
|
||||
LinearSearch,
|
||||
CoverTree
|
||||
}
|
||||
|
||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, k: usize, distance: &'a dyn Fn(&Vec<T>, &Vec<T>) -> T, algorithm: KNNAlgorithmName) -> KNNClassifier<'a, T> {
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub enum KNNAlgorithmV<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||
CoverTree(CoverTree<Vec<T>, T, D>)
|
||||
}
|
||||
|
||||
impl KNNAlgorithmName {
|
||||
|
||||
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(&self, data: Vec<Vec<T>>, distance: D) -> KNNAlgorithmV<T, D> {
|
||||
match *self {
|
||||
KNNAlgorithmName::LinearSearch => KNNAlgorithmV::LinearSearch(LinearKNNSearch::new(data, distance)),
|
||||
KNNAlgorithmName::CoverTree => KNNAlgorithmV::CoverTree(CoverTree::new(data, distance)),
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithmV<T, D> {
|
||||
fn find(&self, from: &Vec<T>, k: usize) -> Vec<usize>{
|
||||
match *self {
|
||||
KNNAlgorithmV::LinearSearch(ref linear) => linear.find(from, k),
|
||||
KNNAlgorithmV::CoverTree(ref cover) => cover.find(from, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.classes.len() != other.classes.len() ||
|
||||
self.k != other.k ||
|
||||
self.y.len() != other.y.len() {
|
||||
return false
|
||||
} else {
|
||||
for i in 0..self.classes.len() {
|
||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for i in 0..self.y.len() {
|
||||
if self.y[i] != other.y[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
|
||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, k: usize, distance: D, algorithm: KNNAlgorithmName) -> KNNClassifier<T, D> {
|
||||
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
let (_, y_n) = y_m.shape();
|
||||
let (x_n, _) = x.shape();
|
||||
|
||||
let data = row_iter(x).collect();
|
||||
let data = row_iter(x).collect();
|
||||
|
||||
let mut yi: Vec<usize> = vec![0; y_n];
|
||||
let classes = y_m.unique();
|
||||
@@ -32,14 +89,9 @@ impl<'a, T: FloatExt> KNNClassifier<'a, T> {
|
||||
|
||||
assert!(x_n == y_n, format!("Size of x should equal size of y; |x|=[{}], |y|=[{}]", x_n, y_n));
|
||||
|
||||
assert!(k > 1, format!("k should be > 1, k=[{}]", k));
|
||||
assert!(k > 1, format!("k should be > 1, k=[{}]", k));
|
||||
|
||||
let knn_algorithm: Box<dyn KNNAlgorithm<Vec<T>> + 'a> = match algorithm {
|
||||
KNNAlgorithmName::CoverTree => Box::new(CoverTree::<Vec<T>, T>::new(data, distance)),
|
||||
KNNAlgorithmName::LinearSearch => Box::new(LinearKNNSearch::<Vec<T>, T>::new(data, distance))
|
||||
};
|
||||
|
||||
KNNClassifier{classes:classes, y: yi, k: k, knn_algorithm: knn_algorithm}
|
||||
KNNClassifier{classes:classes, y: yi, k: k, knn_algorithm: algorithm.fit(data, distance)}
|
||||
|
||||
}
|
||||
|
||||
@@ -74,8 +126,8 @@ impl<'a, T: FloatExt> KNNClassifier<'a, T> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::math::distance::euclidian;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::math::distance::Distances;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn knn_fit_predict() {
|
||||
@@ -85,10 +137,28 @@ mod tests {
|
||||
&[5., 6.],
|
||||
&[7., 8.],
|
||||
&[9., 10.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
let knn = KNNClassifier::fit(&x, &y, 3, &euclidian::distance, KNNAlgorithmName::LinearSearch);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
let knn = KNNClassifier::fit(&x, &y, 3, Distances::euclidian(), KNNAlgorithmName::LinearSearch);
|
||||
let r = knn.predict(&x);
|
||||
assert_eq!(5, Vec::len(&r));
|
||||
assert_eq!(y.to_vec(), r);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[1., 2.],
|
||||
&[3., 4.],
|
||||
&[5., 6.],
|
||||
&[7., 8.],
|
||||
&[9., 10.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
|
||||
let knn = KNNClassifier::fit(&x, &y, 3, Distances::euclidian(), KNNAlgorithmName::CoverTree);
|
||||
|
||||
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(knn, deserialized_knn);
|
||||
|
||||
}
|
||||
}
|
||||
@@ -3,11 +3,13 @@ use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
use std::collections::LinkedList;
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeClassifierParameters {
|
||||
pub criterion: SplitCriterion,
|
||||
pub max_depth: Option<u16>,
|
||||
@@ -15,7 +17,7 @@ pub struct DecisionTreeClassifierParameters {
|
||||
pub min_samples_split: usize
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeClassifier<T: FloatExt> {
|
||||
nodes: Vec<Node<T>>,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
@@ -24,24 +26,62 @@ pub struct DecisionTreeClassifier<T: FloatExt> {
|
||||
depth: u16
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub enum SplitCriterion {
|
||||
Gini,
|
||||
Entropy,
|
||||
ClassificationError
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Node<T: FloatExt> {
|
||||
index: usize,
|
||||
output: usize,
|
||||
split_feature: usize,
|
||||
split_value: T,
|
||||
split_score: T,
|
||||
split_value: Option<T>,
|
||||
split_score: Option<T>,
|
||||
true_child: Option<usize>,
|
||||
false_child: Option<usize>,
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for DecisionTreeClassifier<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.depth != other.depth ||
|
||||
self.num_classes != other.num_classes ||
|
||||
self.nodes.len() != other.nodes.len(){
|
||||
return false
|
||||
} else {
|
||||
for i in 0..self.classes.len() {
|
||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for i in 0..self.nodes.len() {
|
||||
if self.nodes[i] != other.nodes[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for Node<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.output == other.output &&
|
||||
self.split_feature == other.split_feature &&
|
||||
match (self.split_value, other.split_value) {
|
||||
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
} &&
|
||||
match (self.split_score, other.split_score) {
|
||||
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DecisionTreeClassifierParameters {
|
||||
fn default() -> Self {
|
||||
@@ -60,8 +100,8 @@ impl<T: FloatExt> Node<T> {
|
||||
index: index,
|
||||
output: output,
|
||||
split_feature: 0,
|
||||
split_value: T::nan(),
|
||||
split_score: T::nan(),
|
||||
split_value: Option::None,
|
||||
split_score: Option::None,
|
||||
true_child: Option::None,
|
||||
false_child: Option::None
|
||||
}
|
||||
@@ -238,7 +278,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
if node.true_child == None && node.false_child == None {
|
||||
result = node.output;
|
||||
} else {
|
||||
if x.get(row, node.split_feature) <= node.split_value {
|
||||
if x.get(row, node.split_feature) <= node.split_value.unwrap_or(T::nan()) {
|
||||
queue.push_back(node.true_child.unwrap());
|
||||
} else {
|
||||
queue.push_back(node.false_child.unwrap());
|
||||
@@ -299,7 +339,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]);
|
||||
}
|
||||
|
||||
!self.nodes[visitor.node].split_score.is_nan()
|
||||
self.nodes[visitor.node].split_score != Option::None
|
||||
|
||||
}
|
||||
|
||||
@@ -336,10 +376,10 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
let false_label = which_max(false_count);
|
||||
let gain = parent_impurity - T::from(tc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &true_count, tc) - T::from(fc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &false_count, fc);
|
||||
|
||||
if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score {
|
||||
if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() {
|
||||
self.nodes[visitor.node].split_feature = j;
|
||||
self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / T::two();
|
||||
self.nodes[visitor.node].split_score = gain;
|
||||
self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
|
||||
self.nodes[visitor.node].split_score = Option::Some(gain);
|
||||
visitor.true_child_output = true_label;
|
||||
visitor.false_child_output = false_label;
|
||||
}
|
||||
@@ -360,7 +400,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
|
||||
for i in 0..n {
|
||||
if visitor.samples[i] > 0 {
|
||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value {
|
||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) {
|
||||
true_samples[i] = visitor.samples[i];
|
||||
tc += true_samples[i];
|
||||
visitor.samples[i] = 0;
|
||||
@@ -372,8 +412,8 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
|
||||
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
|
||||
self.nodes[visitor.node].split_feature = 0;
|
||||
self.nodes[visitor.node].split_value = T::nan();
|
||||
self.nodes[visitor.node].split_score = T::nan();
|
||||
self.nodes[visitor.node].split_value = Option::None;
|
||||
self.nodes[visitor.node].split_score = Option::None;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -477,4 +517,37 @@ mod tests {
|
||||
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[1.,1.,1.,0.],
|
||||
&[1.,1.,1.,0.],
|
||||
&[1.,1.,1.,1.],
|
||||
&[1.,1.,0.,0.],
|
||||
&[1.,1.,0.,1.],
|
||||
&[1.,0.,1.,0.],
|
||||
&[1.,0.,1.,0.],
|
||||
&[1.,0.,1.,1.],
|
||||
&[1.,0.,0.,0.],
|
||||
&[1.,0.,0.,1.],
|
||||
&[0.,1.,1.,0.],
|
||||
&[0.,1.,1.,0.],
|
||||
&[0.,1.,1.,1.],
|
||||
&[0.,1.,0.,0.],
|
||||
&[0.,1.,0.,1.],
|
||||
&[0.,0.,1.,0.],
|
||||
&[0.,0.,1.,0.],
|
||||
&[0.,0.,1.,1.],
|
||||
&[0.,0.,0.,0.],
|
||||
&[0.,0.,0.,1.]]);
|
||||
let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
|
||||
|
||||
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
|
||||
|
||||
let deserialized_tree: DecisionTreeClassifier<f64> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(tree, deserialized_tree);
|
||||
|
||||
}
|
||||
}
|
||||
@@ -2,31 +2,33 @@ use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
use std::collections::LinkedList;
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeRegressorParameters {
|
||||
pub max_depth: Option<u16>,
|
||||
pub min_samples_leaf: usize,
|
||||
pub min_samples_split: usize
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeRegressor<T: FloatExt> {
|
||||
nodes: Vec<Node<T>>,
|
||||
parameters: DecisionTreeRegressorParameters,
|
||||
depth: u16
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Node<T: FloatExt> {
|
||||
index: usize,
|
||||
output: T,
|
||||
split_feature: usize,
|
||||
split_value: T,
|
||||
split_score: T,
|
||||
split_value: Option<T>,
|
||||
split_score: Option<T>,
|
||||
true_child: Option<usize>,
|
||||
false_child: Option<usize>,
|
||||
}
|
||||
@@ -48,14 +50,46 @@ impl<T: FloatExt> Node<T> {
|
||||
index: index,
|
||||
output: output,
|
||||
split_feature: 0,
|
||||
split_value: T::nan(),
|
||||
split_score: T::nan(),
|
||||
split_value: Option::None,
|
||||
split_score: Option::None,
|
||||
true_child: Option::None,
|
||||
false_child: Option::None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for Node<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
(self.output - other.output).abs() < T::epsilon() &&
|
||||
self.split_feature == other.split_feature &&
|
||||
match (self.split_value, other.split_value) {
|
||||
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
} &&
|
||||
match (self.split_score, other.split_score) {
|
||||
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for DecisionTreeRegressor<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.depth != other.depth || self.nodes.len() != other.nodes.len(){
|
||||
return false
|
||||
} else {
|
||||
for i in 0..self.nodes.len() {
|
||||
if self.nodes[i] != other.nodes[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
|
||||
x: &'a M,
|
||||
y: &'a M,
|
||||
@@ -169,7 +203,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
||||
if node.true_child == None && node.false_child == None {
|
||||
result = node.output;
|
||||
} else {
|
||||
if x.get(row, node.split_feature) <= node.split_value {
|
||||
if x.get(row, node.split_feature) <= node.split_value.unwrap_or(T::nan()) {
|
||||
queue.push_back(node.true_child.unwrap());
|
||||
} else {
|
||||
queue.push_back(node.false_child.unwrap());
|
||||
@@ -207,7 +241,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
||||
self.find_best_split(visitor, n, sum, parent_gain, variables[j]);
|
||||
}
|
||||
|
||||
!self.nodes[visitor.node].split_score.is_nan()
|
||||
self.nodes[visitor.node].split_score != Option::None
|
||||
|
||||
}
|
||||
|
||||
@@ -240,10 +274,10 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
||||
|
||||
let gain = (T::from(true_count).unwrap() * true_mean * true_mean + T::from(false_count).unwrap() * false_mean * false_mean) - parent_gain;
|
||||
|
||||
if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score {
|
||||
if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() {
|
||||
self.nodes[visitor.node].split_feature = j;
|
||||
self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / T::two();
|
||||
self.nodes[visitor.node].split_score = gain;
|
||||
self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
|
||||
self.nodes[visitor.node].split_score = Option::Some(gain);
|
||||
visitor.true_child_output = true_mean;
|
||||
visitor.false_child_output = false_mean;
|
||||
}
|
||||
@@ -264,7 +298,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
||||
|
||||
for i in 0..n {
|
||||
if visitor.samples[i] > 0 {
|
||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value {
|
||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) {
|
||||
true_samples[i] = visitor.samples[i];
|
||||
tc += true_samples[i];
|
||||
visitor.samples[i] = 0;
|
||||
@@ -276,8 +310,8 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
||||
|
||||
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
|
||||
self.nodes[visitor.node].split_feature = 0;
|
||||
self.nodes[visitor.node].split_value = T::nan();
|
||||
self.nodes[visitor.node].split_score = T::nan();
|
||||
self.nodes[visitor.node].split_value = Option::None;
|
||||
self.nodes[visitor.node].split_score = Option::None;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -357,4 +391,33 @@ mod tests {
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[ 234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||
&[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[ 284.599, 335.1, 165., 110.929, 1950., 61.187],
|
||||
&[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639],
|
||||
&[ 365.385, 187., 354.7, 115.094, 1953., 64.989],
|
||||
&[ 363.112, 357.8, 335., 116.219, 1954., 63.761],
|
||||
&[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513],
|
||||
&[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
||||
let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||
|
||||
let tree = DecisionTreeRegressor::fit(&x, &y, Default::default());
|
||||
|
||||
let deserialized_tree: DecisionTreeRegressor<f64> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(tree, deserialized_tree);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user