fix: cargo fmt
This commit is contained in:
+4
-2
@@ -2,14 +2,16 @@
|
|||||||
extern crate criterion;
|
extern crate criterion;
|
||||||
extern crate smartcore;
|
extern crate smartcore;
|
||||||
|
|
||||||
use criterion::Criterion;
|
|
||||||
use criterion::black_box;
|
use criterion::black_box;
|
||||||
|
use criterion::Criterion;
|
||||||
use smartcore::math::distance::*;
|
use smartcore::math::distance::*;
|
||||||
|
|
||||||
fn criterion_benchmark(c: &mut Criterion) {
|
fn criterion_benchmark(c: &mut Criterion) {
|
||||||
let a = vec![1., 2., 3.];
|
let a = vec![1., 2., 3.];
|
||||||
|
|
||||||
c.bench_function("Euclidean Distance", move |b| b.iter(|| Distances::euclidian().distance(black_box(&a), black_box(&a))));
|
c.bench_function("Euclidean Distance", move |b| {
|
||||||
|
b.iter(|| Distances::euclidian().distance(black_box(&a), black_box(&a)))
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
criterion_group!(benches, criterion_benchmark);
|
criterion_group!(benches, criterion_benchmark);
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
pub mod sort;
|
|
||||||
pub mod neighbour;
|
pub mod neighbour;
|
||||||
|
pub mod sort;
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::distance::euclidian::*;
|
use crate::math::distance::euclidian::*;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct BBDTree<T: FloatExt> {
|
pub struct BBDTree<T: FloatExt> {
|
||||||
nodes: Vec<BBDTreeNode<T>>,
|
nodes: Vec<BBDTreeNode<T>>,
|
||||||
index: Vec<usize>,
|
index: Vec<usize>,
|
||||||
root: usize
|
root: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -20,7 +20,7 @@ struct BBDTreeNode<T: FloatExt> {
|
|||||||
sum: Vec<T>,
|
sum: Vec<T>,
|
||||||
cost: T,
|
cost: T,
|
||||||
lower: Option<usize>,
|
lower: Option<usize>,
|
||||||
upper: Option<usize>
|
upper: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> BBDTreeNode<T> {
|
impl<T: FloatExt> BBDTreeNode<T> {
|
||||||
@@ -33,7 +33,7 @@ impl<T: FloatExt> BBDTreeNode<T> {
|
|||||||
sum: vec![T::zero(); d],
|
sum: vec![T::zero(); d],
|
||||||
cost: T::zero(),
|
cost: T::zero(),
|
||||||
lower: Option::None,
|
lower: Option::None,
|
||||||
upper: Option::None
|
upper: Option::None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -52,7 +52,7 @@ impl<T: FloatExt> BBDTree<T> {
|
|||||||
let mut tree = BBDTree {
|
let mut tree = BBDTree {
|
||||||
nodes: nodes,
|
nodes: nodes,
|
||||||
index: index,
|
index: index,
|
||||||
root: 0
|
root: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let root = tree.build_node(data, 0, n);
|
let root = tree.build_node(data, 0, n);
|
||||||
@@ -62,7 +62,13 @@ impl<T: FloatExt> BBDTree<T> {
|
|||||||
tree
|
tree
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(in crate) fn clustering(&self, centroids: &Vec<Vec<T>>, sums: &mut Vec<Vec<T>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> T {
|
pub(in crate) fn clustering(
|
||||||
|
&self,
|
||||||
|
centroids: &Vec<Vec<T>>,
|
||||||
|
sums: &mut Vec<Vec<T>>,
|
||||||
|
counts: &mut Vec<usize>,
|
||||||
|
membership: &mut Vec<usize>,
|
||||||
|
) -> T {
|
||||||
let k = centroids.len();
|
let k = centroids.len();
|
||||||
|
|
||||||
counts.iter_mut().for_each(|x| *x = 0);
|
counts.iter_mut().for_each(|x| *x = 0);
|
||||||
@@ -72,17 +78,36 @@ impl<T: FloatExt> BBDTree<T> {
|
|||||||
sums[i].iter_mut().for_each(|x| *x = T::zero());
|
sums[i].iter_mut().for_each(|x| *x = T::zero());
|
||||||
}
|
}
|
||||||
|
|
||||||
self.filter(self.root, centroids, &candidates, k, sums, counts, membership)
|
self.filter(
|
||||||
|
self.root,
|
||||||
|
centroids,
|
||||||
|
&candidates,
|
||||||
|
k,
|
||||||
|
sums,
|
||||||
|
counts,
|
||||||
|
membership,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn filter(&self, node: usize, centroids: &Vec<Vec<T>>, candidates: &Vec<usize>, k: usize, sums: &mut Vec<Vec<T>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> T{
|
fn filter(
|
||||||
|
&self,
|
||||||
|
node: usize,
|
||||||
|
centroids: &Vec<Vec<T>>,
|
||||||
|
candidates: &Vec<usize>,
|
||||||
|
k: usize,
|
||||||
|
sums: &mut Vec<Vec<T>>,
|
||||||
|
counts: &mut Vec<usize>,
|
||||||
|
membership: &mut Vec<usize>,
|
||||||
|
) -> T {
|
||||||
let d = centroids[0].len();
|
let d = centroids[0].len();
|
||||||
|
|
||||||
// Determine which mean the node mean is closest to
|
// Determine which mean the node mean is closest to
|
||||||
let mut min_dist = Euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[0]]);
|
let mut min_dist =
|
||||||
|
Euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[0]]);
|
||||||
let mut closest = candidates[0];
|
let mut closest = candidates[0];
|
||||||
for i in 1..k {
|
for i in 1..k {
|
||||||
let dist = Euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[i]]);
|
let dist =
|
||||||
|
Euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[i]]);
|
||||||
if dist < min_dist {
|
if dist < min_dist {
|
||||||
min_dist = dist;
|
min_dist = dist;
|
||||||
closest = candidates[i];
|
closest = candidates[i];
|
||||||
@@ -96,7 +121,13 @@ impl<T: FloatExt> BBDTree<T> {
|
|||||||
let mut newk = 0;
|
let mut newk = 0;
|
||||||
|
|
||||||
for i in 0..k {
|
for i in 0..k {
|
||||||
if !BBDTree::prune(&self.nodes[node].center, &self.nodes[node].radius, ¢roids, closest, candidates[i]) {
|
if !BBDTree::prune(
|
||||||
|
&self.nodes[node].center,
|
||||||
|
&self.nodes[node].radius,
|
||||||
|
¢roids,
|
||||||
|
closest,
|
||||||
|
candidates[i],
|
||||||
|
) {
|
||||||
new_candidates[newk] = candidates[i];
|
new_candidates[newk] = candidates[i];
|
||||||
newk += 1;
|
newk += 1;
|
||||||
}
|
}
|
||||||
@@ -104,8 +135,23 @@ impl<T: FloatExt> BBDTree<T> {
|
|||||||
|
|
||||||
// Recurse if there's at least two
|
// Recurse if there's at least two
|
||||||
if newk > 1 {
|
if newk > 1 {
|
||||||
let result = self.filter(self.nodes[node].lower.unwrap(), centroids, &mut new_candidates, newk, sums, counts, membership) +
|
let result = self.filter(
|
||||||
self.filter(self.nodes[node].upper.unwrap(), centroids, &mut new_candidates, newk, sums, counts, membership);
|
self.nodes[node].lower.unwrap(),
|
||||||
|
centroids,
|
||||||
|
&mut new_candidates,
|
||||||
|
newk,
|
||||||
|
sums,
|
||||||
|
counts,
|
||||||
|
membership,
|
||||||
|
) + self.filter(
|
||||||
|
self.nodes[node].upper.unwrap(),
|
||||||
|
centroids,
|
||||||
|
&mut new_candidates,
|
||||||
|
newk,
|
||||||
|
sums,
|
||||||
|
counts,
|
||||||
|
membership,
|
||||||
|
);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -123,10 +169,15 @@ impl<T: FloatExt> BBDTree<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
BBDTree::node_cost(&self.nodes[node], ¢roids[closest])
|
BBDTree::node_cost(&self.nodes[node], ¢roids[closest])
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prune(center: &Vec<T>, radius: &Vec<T>, centroids: &Vec<Vec<T>>, best_index: usize, test_index: usize) -> bool {
|
fn prune(
|
||||||
|
center: &Vec<T>,
|
||||||
|
radius: &Vec<T>,
|
||||||
|
centroids: &Vec<Vec<T>>,
|
||||||
|
best_index: usize,
|
||||||
|
test_index: usize,
|
||||||
|
) -> bool {
|
||||||
if best_index == test_index {
|
if best_index == test_index {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -247,7 +298,8 @@ impl<T: FloatExt> BBDTree<T> {
|
|||||||
|
|
||||||
// Calculate the new sum and opt cost
|
// Calculate the new sum and opt cost
|
||||||
for i in 0..d {
|
for i in 0..d {
|
||||||
node.sum[i] = self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i];
|
node.sum[i] =
|
||||||
|
self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut mean = vec![T::zero(); d];
|
let mut mean = vec![T::zero(); d];
|
||||||
@@ -255,7 +307,8 @@ impl<T: FloatExt> BBDTree<T> {
|
|||||||
mean[i] = node.sum[i] / T::from(node.count).unwrap();
|
mean[i] = node.sum[i] / T::from(node.count).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean) + BBDTree::node_cost(&self.nodes[node.upper.unwrap()], &mean);
|
node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean)
|
||||||
|
+ BBDTree::node_cost(&self.nodes[node.upper.unwrap()], &mean);
|
||||||
|
|
||||||
self.add_node(node)
|
self.add_node(node)
|
||||||
}
|
}
|
||||||
@@ -284,7 +337,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_iris() {
|
fn fit_predict_iris() {
|
||||||
|
|
||||||
let data = DenseMatrix::from_array(&[
|
let data = DenseMatrix::from_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
&[4.9, 3.0, 1.4, 0.2],
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
@@ -305,19 +357,14 @@ mod tests {
|
|||||||
&[6.3, 3.3, 4.7, 1.6],
|
&[6.3, 3.3, 4.7, 1.6],
|
||||||
&[4.9, 2.4, 3.3, 1.0],
|
&[4.9, 2.4, 3.3, 1.0],
|
||||||
&[6.6, 2.9, 4.6, 1.3],
|
&[6.6, 2.9, 4.6, 1.3],
|
||||||
&[5.2, 2.7, 3.9, 1.4]]);
|
&[5.2, 2.7, 3.9, 1.4],
|
||||||
|
]);
|
||||||
|
|
||||||
let tree = BBDTree::new(&data);
|
let tree = BBDTree::new(&data);
|
||||||
|
|
||||||
let centroids = vec![
|
let centroids = vec![vec![4.86, 3.22, 1.61, 0.29], vec![6.23, 2.92, 4.48, 1.42]];
|
||||||
vec![4.86, 3.22, 1.61, 0.29],
|
|
||||||
vec![6.23, 2.92, 4.48, 1.42]
|
|
||||||
];
|
|
||||||
|
|
||||||
let mut sums = vec![
|
let mut sums = vec![vec![0f64; 4], vec![0f64; 4]];
|
||||||
vec![0f64; 4],
|
|
||||||
vec![0f64; 4]
|
|
||||||
];
|
|
||||||
|
|
||||||
let mut counts = vec![11, 9];
|
let mut counts = vec![11, 9];
|
||||||
|
|
||||||
@@ -328,7 +375,5 @@ mod tests {
|
|||||||
assert!((sums[0][0] - 48.6).abs() < 1e-2);
|
assert!((sums[0][0] - 48.6).abs() < 1e-2);
|
||||||
assert!((sums[1][3] - 13.8).abs() < 1e-2);
|
assert!((sums[1][3] - 13.8).abs() < 1e-2);
|
||||||
assert_eq!(membership[17], 1);
|
assert_eq!(membership[17], 1);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,41 +1,37 @@
|
|||||||
use std::collections::{HashMap, HashSet};
|
|
||||||
use std::iter::FromIterator;
|
|
||||||
use std::fmt::Debug;
|
|
||||||
use core::hash::{Hash, Hasher};
|
use core::hash::{Hash, Hasher};
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::fmt::Debug;
|
||||||
|
use std::iter::FromIterator;
|
||||||
|
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::math::distance::Distance;
|
|
||||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||||
|
use crate::math::distance::Distance;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct CoverTree<T, F: FloatExt, D: Distance<T, F>>
|
pub struct CoverTree<T, F: FloatExt, D: Distance<T, F>> {
|
||||||
{
|
|
||||||
base: F,
|
base: F,
|
||||||
max_level: i8,
|
max_level: i8,
|
||||||
min_level: i8,
|
min_level: i8,
|
||||||
distance: D,
|
distance: D,
|
||||||
nodes: Vec<Node<T>>
|
nodes: Vec<Node<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D> {
|
||||||
{
|
|
||||||
|
|
||||||
pub fn new(mut data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
|
pub fn new(mut data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
|
||||||
let mut tree = CoverTree {
|
let mut tree = CoverTree {
|
||||||
base: F::two(),
|
base: F::two(),
|
||||||
max_level: 100,
|
max_level: 100,
|
||||||
min_level: 100,
|
min_level: 100,
|
||||||
distance: distance,
|
distance: distance,
|
||||||
nodes: Vec::new()
|
nodes: Vec::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let p = tree.new_node(None, data.remove(0));
|
let p = tree.new_node(None, data.remove(0));
|
||||||
tree.construct(p, data, Vec::new(), 10);
|
tree.construct(p, data, Vec::new(), 10);
|
||||||
|
|
||||||
tree
|
tree
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn insert(&mut self, p: T) {
|
pub fn insert(&mut self, p: T) {
|
||||||
@@ -44,14 +40,14 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
} else {
|
} else {
|
||||||
let mut parent: Option<NodeId> = Option::None;
|
let mut parent: Option<NodeId> = Option::None;
|
||||||
let mut p_i = 0;
|
let mut p_i = 0;
|
||||||
let mut qi_p_ds = vec!((self.root(), self.distance.distance(&p, &self.root().data)));
|
let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))];
|
||||||
let mut i = self.max_level;
|
let mut i = self.max_level;
|
||||||
loop {
|
loop {
|
||||||
let i_d = self.base.powf(F::from(i).unwrap());
|
let i_d = self.base.powf(F::from(i).unwrap());
|
||||||
let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
||||||
let d_p_q = self.min_by_distance(&q_p_ds);
|
let d_p_q = self.min_by_distance(&q_p_ds);
|
||||||
if d_p_q < F::epsilon() {
|
if d_p_q < F::epsilon() {
|
||||||
return
|
return;
|
||||||
} else if d_p_q > i_d {
|
} else if d_p_q > i_d {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -73,30 +69,40 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
pub fn new_node(&mut self, parent: Option<NodeId>, data: T) -> NodeId {
|
pub fn new_node(&mut self, parent: Option<NodeId>, data: T) -> NodeId {
|
||||||
let next_index = self.nodes.len();
|
let next_index = self.nodes.len();
|
||||||
let node_id = NodeId { index: next_index };
|
let node_id = NodeId { index: next_index };
|
||||||
self.nodes.push(
|
self.nodes.push(Node {
|
||||||
Node {
|
|
||||||
index: node_id,
|
index: node_id,
|
||||||
data: data,
|
data: data,
|
||||||
parent: parent,
|
parent: parent,
|
||||||
children: HashMap::new()
|
children: HashMap::new(),
|
||||||
});
|
});
|
||||||
node_id
|
node_id
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn find(&self, p: &T, k: usize) -> Vec<usize> {
|
pub fn find(&self, p: &T, k: usize) -> Vec<usize> {
|
||||||
let mut qi_p_ds = vec!((self.root(), self.distance.distance(&p, &self.root().data)));
|
let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))];
|
||||||
for i in (self.min_level..self.max_level + 1).rev() {
|
for i in (self.min_level..self.max_level + 1).rev() {
|
||||||
let i_d = self.base.powf(F::from(i).unwrap());
|
let i_d = self.base.powf(F::from(i).unwrap());
|
||||||
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
||||||
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
|
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
|
||||||
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &(d_p_q + i_d)).collect();
|
qi_p_ds = q_p_ds
|
||||||
|
.into_iter()
|
||||||
|
.filter(|(_, d)| d <= &(d_p_q + i_d))
|
||||||
|
.collect();
|
||||||
}
|
}
|
||||||
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
||||||
qi_p_ds[..usize::min(qi_p_ds.len(), k)].iter().map(|(n, _)| n.index.index).collect()
|
qi_p_ds[..usize::min(qi_p_ds.len(), k)]
|
||||||
|
.iter()
|
||||||
|
.map(|(n, _)| n.index.index)
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn split(&self, p_id: NodeId, r: F, s1: &mut Vec<T>, s2: Option<&mut Vec<T>>) -> (Vec<T>, Vec<T>){
|
fn split(
|
||||||
|
&self,
|
||||||
|
p_id: NodeId,
|
||||||
|
r: F,
|
||||||
|
s1: &mut Vec<T>,
|
||||||
|
s2: Option<&mut Vec<T>>,
|
||||||
|
) -> (Vec<T>, Vec<T>) {
|
||||||
let mut my_near = (Vec::new(), Vec::new());
|
let mut my_near = (Vec::new(), Vec::new());
|
||||||
|
|
||||||
my_near = self.split_remove_s(p_id, r, s1, my_near);
|
my_near = self.split_remove_s(p_id, r, s1, my_near);
|
||||||
@@ -105,12 +111,16 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
my_near = self.split_remove_s(p_id, r, s, my_near);
|
my_near = self.split_remove_s(p_id, r, s, my_near);
|
||||||
}
|
}
|
||||||
|
|
||||||
return my_near
|
return my_near;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn split_remove_s(&self, p_id: NodeId, r: F, s: &mut Vec<T>, mut my_near: (Vec<T>, Vec<T>)) -> (Vec<T>, Vec<T>){
|
fn split_remove_s(
|
||||||
|
&self,
|
||||||
|
p_id: NodeId,
|
||||||
|
r: F,
|
||||||
|
s: &mut Vec<T>,
|
||||||
|
mut my_near: (Vec<T>, Vec<T>),
|
||||||
|
) -> (Vec<T>, Vec<T>) {
|
||||||
if s.len() > 0 {
|
if s.len() > 0 {
|
||||||
let p = &self.nodes.get(p_id.index).unwrap().data;
|
let p = &self.nodes.get(p_id.index).unwrap().data;
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
@@ -126,11 +136,16 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return my_near
|
return my_near;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn construct<'b>(&mut self, p: NodeId, mut near: Vec<T>, mut far: Vec<T>, i: i8) -> (NodeId, Vec<T>) {
|
fn construct<'b>(
|
||||||
|
&mut self,
|
||||||
|
p: NodeId,
|
||||||
|
mut near: Vec<T>,
|
||||||
|
mut far: Vec<T>,
|
||||||
|
i: i8,
|
||||||
|
) -> (NodeId, Vec<T>) {
|
||||||
if near.len() < 1 {
|
if near.len() < 1 {
|
||||||
self.min_level = std::cmp::min(self.min_level, i);
|
self.min_level = std::cmp::min(self.min_level, i);
|
||||||
return (p, far);
|
return (p, far);
|
||||||
@@ -140,39 +155,57 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
while near.len() > 0 {
|
while near.len() > 0 {
|
||||||
let q_data = near.remove(0);
|
let q_data = near.remove(0);
|
||||||
let nn = self.new_node(Some(p), q_data);
|
let nn = self.new_node(Some(p), q_data);
|
||||||
let (my, n) = self.split(nn, self.base.powf(F::from(i-1).unwrap()), &mut near, Some(&mut far));
|
let (my, n) = self.split(
|
||||||
|
nn,
|
||||||
|
self.base.powf(F::from(i - 1).unwrap()),
|
||||||
|
&mut near,
|
||||||
|
Some(&mut far),
|
||||||
|
);
|
||||||
let (child, mut unused) = self.construct(nn, my, n, i - 1);
|
let (child, mut unused) = self.construct(nn, my, n, i - 1);
|
||||||
self.add_child(pi, child, i);
|
self.add_child(pi, child, i);
|
||||||
let new_near_far = self.split(p, self.base.powf(F::from(i).unwrap()), &mut unused, None);
|
let new_near_far =
|
||||||
|
self.split(p, self.base.powf(F::from(i).unwrap()), &mut unused, None);
|
||||||
near.extend(new_near_far.0);
|
near.extend(new_near_far.0);
|
||||||
far.extend(new_near_far.1);
|
far.extend(new_near_far.1);
|
||||||
}
|
}
|
||||||
self.min_level = std::cmp::min(self.min_level, i);
|
self.min_level = std::cmp::min(self.min_level, i);
|
||||||
return (pi, far);
|
return (pi, far);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_child(&mut self, parent: NodeId, node: NodeId, i: i8) {
|
fn add_child(&mut self, parent: NodeId, node: NodeId, i: i8) {
|
||||||
self.nodes.get_mut(parent.index).unwrap().children.insert(i, node);
|
self.nodes
|
||||||
|
.get_mut(parent.index)
|
||||||
|
.unwrap()
|
||||||
|
.children
|
||||||
|
.insert(i, node);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn root(&self) -> &Node<T> {
|
fn root(&self) -> &Node<T> {
|
||||||
self.nodes.first().unwrap()
|
self.nodes.first().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_children_dist<'b>(&'b self, p: &T, qi_p_ds: &Vec<(&'b Node<T>, F)>, i: i8) -> Vec<(&'b Node<T>, F)> {
|
fn get_children_dist<'b>(
|
||||||
|
&'b self,
|
||||||
|
p: &T,
|
||||||
|
qi_p_ds: &Vec<(&'b Node<T>, F)>,
|
||||||
|
i: i8,
|
||||||
|
) -> Vec<(&'b Node<T>, F)> {
|
||||||
let mut children = Vec::<(&'b Node<T>, F)>::new();
|
let mut children = Vec::<(&'b Node<T>, F)>::new();
|
||||||
|
|
||||||
children.extend(qi_p_ds.iter().cloned());
|
children.extend(qi_p_ds.iter().cloned());
|
||||||
|
|
||||||
let q: Vec<&Node<T>> = qi_p_ds.iter().flat_map(|(n, _)| self.get_child(n, i)).collect();
|
let q: Vec<&Node<T>> = qi_p_ds
|
||||||
|
.iter()
|
||||||
|
.flat_map(|(n, _)| self.get_child(n, i))
|
||||||
|
.collect();
|
||||||
|
|
||||||
children.extend(q.into_iter().map(|n| (n, self.distance.distance(&n.data, &p))));
|
children.extend(
|
||||||
|
q.into_iter()
|
||||||
|
.map(|n| (n, self.distance.distance(&n.data, &p))),
|
||||||
|
);
|
||||||
|
|
||||||
children
|
children
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn min_k_by_distance(&self, q_p_ds: &mut Vec<(&Node<T>, F)>, k: usize) -> F {
|
fn min_k_by_distance(&self, q_p_ds: &mut Vec<(&Node<T>, F)>, k: usize) -> F {
|
||||||
@@ -185,15 +218,24 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn min_by_distance(&self, q_p_ds: &Vec<(&Node<T>, F)>) -> F {
|
fn min_by_distance(&self, q_p_ds: &Vec<(&Node<T>, F)>) -> F {
|
||||||
q_p_ds.into_iter().min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap()).unwrap().1
|
q_p_ds
|
||||||
|
.into_iter()
|
||||||
|
.min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap())
|
||||||
|
.unwrap()
|
||||||
|
.1
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_child(&self, node: &Node<T>, i: i8) -> Option<&Node<T>> {
|
fn get_child(&self, node: &Node<T>, i: i8) -> Option<&Node<T>> {
|
||||||
node.children.get(&i).and_then(|n_id| self.nodes.get(n_id.index))
|
node.children
|
||||||
|
.get(&i)
|
||||||
|
.and_then(|n_id| self.nodes.get(n_id.index))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn check_invariant(&self, invariant: fn(&CoverTree<T, F, D>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) {
|
fn check_invariant(
|
||||||
|
&self,
|
||||||
|
invariant: fn(&CoverTree<T, F, D>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> (),
|
||||||
|
) {
|
||||||
let mut current_nodes: Vec<&Node<T>> = Vec::new();
|
let mut current_nodes: Vec<&Node<T>> = Vec::new();
|
||||||
current_nodes.push(self.root());
|
current_nodes.push(self.root());
|
||||||
for i in (self.min_level..self.max_level + 1).rev() {
|
for i in (self.min_level..self.max_level + 1).rev() {
|
||||||
@@ -206,16 +248,27 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn nesting_invariant(_: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) {
|
fn nesting_invariant(
|
||||||
|
_: &CoverTree<T, F, D>,
|
||||||
|
nodes: &Vec<&Node<T>>,
|
||||||
|
next_nodes: &Vec<&Node<T>>,
|
||||||
|
_: i8,
|
||||||
|
) {
|
||||||
let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n));
|
let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n));
|
||||||
let next_nodes_set: HashSet<&Node<T>> = HashSet::from_iter(next_nodes.into_iter().map(|n| *n));
|
let next_nodes_set: HashSet<&Node<T>> =
|
||||||
|
HashSet::from_iter(next_nodes.into_iter().map(|n| *n));
|
||||||
for n in nodes_set.iter() {
|
for n in nodes_set.iter() {
|
||||||
assert!(next_nodes_set.contains(n), "Nesting invariant of the cover tree is not satisfied. Set of nodes [{:?}] is not a subset of [{:?}]", nodes_set, next_nodes_set);
|
assert!(next_nodes_set.contains(n), "Nesting invariant of the cover tree is not satisfied. Set of nodes [{:?}] is not a subset of [{:?}]", nodes_set, next_nodes_set);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn covering_tree(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) {
|
fn covering_tree(
|
||||||
|
tree: &CoverTree<T, F, D>,
|
||||||
|
nodes: &Vec<&Node<T>>,
|
||||||
|
next_nodes: &Vec<&Node<T>>,
|
||||||
|
i: i8,
|
||||||
|
) {
|
||||||
let mut p_selected: Vec<&Node<T>> = Vec::new();
|
let mut p_selected: Vec<&Node<T>> = Vec::new();
|
||||||
for p in next_nodes {
|
for p in next_nodes {
|
||||||
for q in nodes {
|
for q in nodes {
|
||||||
@@ -223,7 +276,10 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
p_selected.push(*p);
|
p_selected.push(*p);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let c = p_selected.iter().filter(|q| p.parent.map(|p| q.index == p).unwrap_or(false)).count();
|
let c = p_selected
|
||||||
|
.iter()
|
||||||
|
.filter(|q| p.parent.map(|p| q.index == p).unwrap_or(false))
|
||||||
|
.count();
|
||||||
assert!(c <= 1);
|
assert!(c <= 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -233,12 +289,14 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
for p in nodes {
|
for p in nodes {
|
||||||
for q in nodes {
|
for q in nodes {
|
||||||
if p != q {
|
if p != q {
|
||||||
assert!(tree.distance.distance(&p.data, &q.data) > tree.base.powf(F::from(i).unwrap()));
|
assert!(
|
||||||
|
tree.distance.distance(&p.data, &q.data)
|
||||||
|
> tree.base.powf(F::from(i).unwrap())
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||||
@@ -251,7 +309,7 @@ struct Node<T> {
|
|||||||
index: NodeId,
|
index: NodeId,
|
||||||
data: T,
|
data: T,
|
||||||
children: HashMap<i8, NodeId>,
|
children: HashMap<i8, NodeId>,
|
||||||
parent: Option<NodeId>
|
parent: Option<NodeId>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> PartialEq for Node<T> {
|
impl<T> PartialEq for Node<T> {
|
||||||
@@ -287,10 +345,10 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cover_tree_test() {
|
fn cover_tree_test() {
|
||||||
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
|
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||||
|
|
||||||
let mut tree = CoverTree::new(data, SimpleDistance {});
|
let mut tree = CoverTree::new(data, SimpleDistance {});
|
||||||
for d in vec!(10, 11, 12, 13, 14, 15, 16, 17, 18, 19) {
|
for d in vec![10, 11, 12, 13, 14, 15, 16, 17, 18, 19] {
|
||||||
tree.insert(d);
|
tree.insert(d);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -308,12 +366,11 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_invariants() {
|
fn test_invariants() {
|
||||||
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
|
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||||
|
|
||||||
let tree = CoverTree::new(data, SimpleDistance {});
|
let tree = CoverTree::new(data, SimpleDistance {});
|
||||||
tree.check_invariant(CoverTree::nesting_invariant);
|
tree.check_invariant(CoverTree::nesting_invariant);
|
||||||
tree.check_invariant(CoverTree::covering_tree);
|
tree.check_invariant(CoverTree::covering_tree);
|
||||||
tree.check_invariant(CoverTree::separation);
|
tree.check_invariant(CoverTree::separation);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,16 +1,16 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::cmp::{Ordering, PartialOrd};
|
use std::cmp::{Ordering, PartialOrd};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use serde::{Serialize, Deserialize};
|
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::math::distance::Distance;
|
|
||||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||||
|
use crate::math::distance::Distance;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct LinearKNNSearch<T, F: FloatExt, D: Distance<T, F>> {
|
pub struct LinearKNNSearch<T, F: FloatExt, D: Distance<T, F>> {
|
||||||
distance: D,
|
distance: D,
|
||||||
data: Vec<T>,
|
data: Vec<T>,
|
||||||
f: PhantomData<F>
|
f: PhantomData<F>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||||
@@ -18,7 +18,7 @@ impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
|||||||
LinearKNNSearch {
|
LinearKNNSearch {
|
||||||
data: data,
|
data: data,
|
||||||
distance: distance,
|
distance: distance,
|
||||||
f: PhantomData
|
f: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -32,12 +32,11 @@ impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
|||||||
for _ in 0..k {
|
for _ in 0..k {
|
||||||
heap.add(KNNPoint {
|
heap.add(KNNPoint {
|
||||||
distance: F::infinity(),
|
distance: F::infinity(),
|
||||||
index: None
|
index: None,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
for i in 0..self.data.len() {
|
for i in 0..self.data.len() {
|
||||||
|
|
||||||
let d = self.distance.distance(&from, &self.data[i]);
|
let d = self.distance.distance(&from, &self.data[i]);
|
||||||
let datum = heap.peek_mut();
|
let datum = heap.peek_mut();
|
||||||
if d < datum.distance {
|
if d < datum.distance {
|
||||||
@@ -56,7 +55,7 @@ impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct KNNPoint<F: FloatExt> {
|
struct KNNPoint<F: FloatExt> {
|
||||||
distance: F,
|
distance: F,
|
||||||
index: Option<usize>
|
index: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F: FloatExt> PartialOrd for KNNPoint<F> {
|
impl<F: FloatExt> PartialOrd for KNNPoint<F> {
|
||||||
@@ -88,13 +87,19 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn knn_find() {
|
fn knn_find() {
|
||||||
let data1 = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
|
||||||
|
|
||||||
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
|
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
|
||||||
|
|
||||||
assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));
|
assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));
|
||||||
|
|
||||||
let data2 = vec!(vec![1., 1.], vec![2., 2.], vec![3., 3.], vec![4., 4.], vec![5., 5.]);
|
let data2 = vec![
|
||||||
|
vec![1., 1.],
|
||||||
|
vec![2., 2.],
|
||||||
|
vec![3., 3.],
|
||||||
|
vec![4., 4.],
|
||||||
|
vec![5., 5.],
|
||||||
|
];
|
||||||
|
|
||||||
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
|
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
|
||||||
|
|
||||||
@@ -105,22 +110,22 @@ mod tests {
|
|||||||
fn knn_point_eq() {
|
fn knn_point_eq() {
|
||||||
let point1 = KNNPoint {
|
let point1 = KNNPoint {
|
||||||
distance: 10.,
|
distance: 10.,
|
||||||
index: Some(0)
|
index: Some(0),
|
||||||
};
|
};
|
||||||
|
|
||||||
let point2 = KNNPoint {
|
let point2 = KNNPoint {
|
||||||
distance: 100.,
|
distance: 100.,
|
||||||
index: Some(1)
|
index: Some(1),
|
||||||
};
|
};
|
||||||
|
|
||||||
let point3 = KNNPoint {
|
let point3 = KNNPoint {
|
||||||
distance: 10.,
|
distance: 10.,
|
||||||
index: Some(2)
|
index: Some(2),
|
||||||
};
|
};
|
||||||
|
|
||||||
let point_inf = KNNPoint {
|
let point_inf = KNNPoint {
|
||||||
distance: std::f64::INFINITY,
|
distance: std::f64::INFINITY,
|
||||||
index: Some(3)
|
index: Some(3),
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(point2 > point1);
|
assert!(point2 > point1);
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
|
pub mod bbd_tree;
|
||||||
pub mod cover_tree;
|
pub mod cover_tree;
|
||||||
pub mod linear_search;
|
pub mod linear_search;
|
||||||
pub mod bbd_tree;
|
|
||||||
@@ -5,17 +5,16 @@ pub struct HeapSelect<T: PartialOrd> {
|
|||||||
k: usize,
|
k: usize,
|
||||||
n: usize,
|
n: usize,
|
||||||
sorted: bool,
|
sorted: bool,
|
||||||
heap: Vec<T>
|
heap: Vec<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: PartialOrd> HeapSelect<T> {
|
impl<'a, T: PartialOrd> HeapSelect<T> {
|
||||||
|
|
||||||
pub fn with_capacity(k: usize) -> HeapSelect<T> {
|
pub fn with_capacity(k: usize) -> HeapSelect<T> {
|
||||||
HeapSelect {
|
HeapSelect {
|
||||||
k: k,
|
k: k,
|
||||||
n: 0,
|
n: 0,
|
||||||
sorted: false,
|
sorted: false,
|
||||||
heap: Vec::<T>::new()
|
heap: Vec::<T>::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,7 +62,6 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
|
|||||||
self.heap.swap(k, j);
|
self.heap.swap(k, j);
|
||||||
k = j;
|
k = j;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(self) -> Vec<T> {
|
pub fn get(self) -> Vec<T> {
|
||||||
@@ -95,7 +93,6 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
|
|||||||
inc /= 3
|
inc /= 3
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -150,5 +147,4 @@ mod tests {
|
|||||||
HeapSelect::shuffle_sort(&mut v3, 3);
|
HeapSelect::shuffle_sort(&mut v3, 3);
|
||||||
assert_eq!(vec![3, 4, 5, 2, 1], v3);
|
assert_eq!(vec![3, 4, 5, 2, 1], v3);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -5,7 +5,6 @@ pub trait QuickArgSort {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Float> QuickArgSort for Vec<T> {
|
impl<T: Float> QuickArgSort for Vec<T> {
|
||||||
|
|
||||||
fn quick_argsort(&mut self) -> Vec<usize> {
|
fn quick_argsort(&mut self) -> Vec<usize> {
|
||||||
let stack_size = 64;
|
let stack_size = 64;
|
||||||
let mut jstack = -1;
|
let mut jstack = -1;
|
||||||
@@ -112,7 +111,13 @@ mod tests {
|
|||||||
let mut arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
|
let mut arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
|
||||||
assert_eq!(vec![1, 2, 0, 3, 5, 7, 6, 8, 4], arr1.quick_argsort());
|
assert_eq!(vec![1, 2, 0, 3, 5, 7, 6, 8, 4], arr1.quick_argsort());
|
||||||
|
|
||||||
let mut arr2 = vec![0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6, 1.0, 1.3, 1.4];
|
let mut arr2 = vec![
|
||||||
assert_eq!(vec![9, 7, 1, 8, 0, 2, 4, 3, 6, 5, 17, 18, 15, 13, 19, 10, 14, 11, 12, 16], arr2.quick_argsort());
|
0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6,
|
||||||
|
1.0, 1.3, 1.4,
|
||||||
|
];
|
||||||
|
assert_eq!(
|
||||||
|
vec![9, 7, 1, 8, 0, 2, 4, 3, 6, 5, 17, 18, 15, 13, 19, 10, 14, 11, 12, 16],
|
||||||
|
arr2.quick_argsort()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
+24
-27
@@ -1,15 +1,15 @@
|
|||||||
extern crate rand;
|
extern crate rand;
|
||||||
|
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use std::iter::Sum;
|
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
use std::iter::Sum;
|
||||||
|
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::distance::euclidian::*;
|
use crate::math::distance::euclidian::*;
|
||||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct KMeans<T: FloatExt> {
|
pub struct KMeans<T: FloatExt> {
|
||||||
@@ -17,24 +17,25 @@ pub struct KMeans<T: FloatExt> {
|
|||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
size: Vec<usize>,
|
size: Vec<usize>,
|
||||||
distortion: T,
|
distortion: T,
|
||||||
centroids: Vec<Vec<T>>
|
centroids: Vec<Vec<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> PartialEq for KMeans<T> {
|
impl<T: FloatExt> PartialEq for KMeans<T> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
if self.k != other.k ||
|
if self.k != other.k
|
||||||
self.size != other.size ||
|
|| self.size != other.size
|
||||||
self.centroids.len() != other.centroids.len() {
|
|| self.centroids.len() != other.centroids.len()
|
||||||
|
{
|
||||||
false
|
false
|
||||||
} else {
|
} else {
|
||||||
let n_centroids = self.centroids.len();
|
let n_centroids = self.centroids.len();
|
||||||
for i in 0..n_centroids {
|
for i in 0..n_centroids {
|
||||||
if self.centroids[i].len() != other.centroids[i].len() {
|
if self.centroids[i].len() != other.centroids[i].len() {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
for j in 0..self.centroids[i].len() {
|
for j in 0..self.centroids[i].len() {
|
||||||
if (self.centroids[i][j] - other.centroids[i][j]).abs() > T::epsilon() {
|
if (self.centroids[i][j] - other.centroids[i][j]).abs() > T::epsilon() {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -45,20 +46,17 @@ impl<T: FloatExt> PartialEq for KMeans<T> {
|
|||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct KMeansParameters {
|
pub struct KMeansParameters {
|
||||||
pub max_iter: usize
|
pub max_iter: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for KMeansParameters {
|
impl Default for KMeansParameters {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
KMeansParameters {
|
KMeansParameters { max_iter: 100 }
|
||||||
max_iter: 100
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + Sum> KMeans<T> {
|
impl<T: FloatExt + Sum> KMeans<T> {
|
||||||
pub fn new<M: Matrix<T>>(data: &M, k: usize, parameters: KMeansParameters) -> KMeans<T> {
|
pub fn new<M: Matrix<T>>(data: &M, k: usize, parameters: KMeansParameters) -> KMeans<T> {
|
||||||
|
|
||||||
let bbd = BBDTree::new(data);
|
let bbd = BBDTree::new(data);
|
||||||
|
|
||||||
if k < 2 {
|
if k < 2 {
|
||||||
@@ -66,7 +64,10 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
|||||||
}
|
}
|
||||||
|
|
||||||
if parameters.max_iter <= 0 {
|
if parameters.max_iter <= 0 {
|
||||||
panic!("Invalid maximum number of iterations: {}", parameters.max_iter);
|
panic!(
|
||||||
|
"Invalid maximum number of iterations: {}",
|
||||||
|
parameters.max_iter
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let (n, d) = data.shape();
|
let (n, d) = data.shape();
|
||||||
@@ -108,7 +109,6 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
|||||||
} else {
|
} else {
|
||||||
distortion = dist;
|
distortion = dist;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
KMeans {
|
KMeans {
|
||||||
@@ -116,7 +116,7 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
|||||||
y: y,
|
y: y,
|
||||||
size: size,
|
size: size,
|
||||||
distortion: distortion,
|
distortion: distortion,
|
||||||
centroids: centroids
|
centroids: centroids,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,7 +125,6 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
|||||||
let mut result = M::zeros(1, n);
|
let mut result = M::zeros(1, n);
|
||||||
|
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
|
|
||||||
let mut min_dist = T::max_value();
|
let mut min_dist = T::max_value();
|
||||||
let mut best_cluster = 0;
|
let mut best_cluster = 0;
|
||||||
|
|
||||||
@@ -193,10 +192,8 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
|||||||
|
|
||||||
y
|
y
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -224,7 +221,8 @@ mod tests {
|
|||||||
&[6.3, 3.3, 4.7, 1.6],
|
&[6.3, 3.3, 4.7, 1.6],
|
||||||
&[4.9, 2.4, 3.3, 1.0],
|
&[4.9, 2.4, 3.3, 1.0],
|
||||||
&[6.6, 2.9, 4.6, 1.3],
|
&[6.6, 2.9, 4.6, 1.3],
|
||||||
&[5.2, 2.7, 3.9, 1.4]]);
|
&[5.2, 2.7, 3.9, 1.4],
|
||||||
|
]);
|
||||||
|
|
||||||
let kmeans = KMeans::new(&x, 2, Default::default());
|
let kmeans = KMeans::new(&x, 2, Default::default());
|
||||||
|
|
||||||
@@ -233,7 +231,6 @@ mod tests {
|
|||||||
for i in 0..y.len() {
|
for i in 0..y.len() {
|
||||||
assert_eq!(y[i] as usize, kmeans.y[i]);
|
assert_eq!(y[i] as usize, kmeans.y[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -258,14 +255,14 @@ mod tests {
|
|||||||
&[6.3, 3.3, 4.7, 1.6],
|
&[6.3, 3.3, 4.7, 1.6],
|
||||||
&[4.9, 2.4, 3.3, 1.0],
|
&[4.9, 2.4, 3.3, 1.0],
|
||||||
&[6.6, 2.9, 4.6, 1.3],
|
&[6.6, 2.9, 4.6, 1.3],
|
||||||
&[5.2, 2.7, 3.9, 1.4]]);
|
&[5.2, 2.7, 3.9, 1.4],
|
||||||
|
]);
|
||||||
|
|
||||||
let kmeans = KMeans::new(&x, 2, Default::default());
|
let kmeans = KMeans::new(&x, 2, Default::default());
|
||||||
|
|
||||||
let deserialized_kmeans: KMeans<f64> = serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();
|
let deserialized_kmeans: KMeans<f64> =
|
||||||
|
serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();
|
||||||
|
|
||||||
assert_eq!(kmeans, deserialized_kmeans);
|
assert_eq!(kmeans, deserialized_kmeans);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
+11
-6
@@ -1,13 +1,18 @@
|
|||||||
use num_traits::{Num, ToPrimitive, FromPrimitive, Zero, One};
|
use ndarray::ScalarOperand;
|
||||||
use ndarray::{ScalarOperand};
|
use num_traits::{FromPrimitive, Num, One, ToPrimitive, Zero};
|
||||||
use std::hash::Hash;
|
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
use std::hash::Hash;
|
||||||
|
|
||||||
pub trait AnyNumber: Num + ScalarOperand + ToPrimitive + FromPrimitive {}
|
pub trait AnyNumber: Num + ScalarOperand + ToPrimitive + FromPrimitive {}
|
||||||
|
|
||||||
pub trait Nominal: PartialEq + Zero + One + Eq + Hash + ToPrimitive + FromPrimitive + Debug + 'static + Clone{}
|
pub trait Nominal:
|
||||||
|
PartialEq + Zero + One + Eq + Hash + ToPrimitive + FromPrimitive + Debug + 'static + Clone
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
impl<T> AnyNumber for T where T: Num + ScalarOperand + ToPrimitive + FromPrimitive {}
|
impl<T> AnyNumber for T where T: Num + ScalarOperand + ToPrimitive + FromPrimitive {}
|
||||||
|
|
||||||
impl<T> Nominal for T where T: PartialEq + Zero + One + Eq + Hash + ToPrimitive + Debug + FromPrimitive + 'static + Clone {}
|
impl<T> Nominal for T where
|
||||||
|
T: PartialEq + Zero + One + Eq + Hash + ToPrimitive + Debug + FromPrimitive + 'static + Clone
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|||||||
+106
-44
@@ -1,9 +1,9 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use crate::linalg::{Matrix};
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct PCA<T: FloatExt, M: Matrix<T>> {
|
pub struct PCA<T: FloatExt, M: Matrix<T>> {
|
||||||
@@ -11,42 +11,41 @@ pub struct PCA<T: FloatExt, M: Matrix<T>> {
|
|||||||
eigenvalues: Vec<T>,
|
eigenvalues: Vec<T>,
|
||||||
projection: M,
|
projection: M,
|
||||||
mu: Vec<T>,
|
mu: Vec<T>,
|
||||||
pmu: Vec<T>
|
pmu: Vec<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, M: Matrix<T>> PartialEq for PCA<T, M> {
|
impl<T: FloatExt, M: Matrix<T>> PartialEq for PCA<T, M> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
if self.eigenvectors != other.eigenvectors ||
|
if self.eigenvectors != other.eigenvectors
|
||||||
self.eigenvalues.len() != other.eigenvalues.len() {
|
|| self.eigenvalues.len() != other.eigenvalues.len()
|
||||||
return false
|
{
|
||||||
|
return false;
|
||||||
} else {
|
} else {
|
||||||
for i in 0..self.eigenvalues.len() {
|
for i in 0..self.eigenvalues.len() {
|
||||||
if (self.eigenvalues[i] - other.eigenvalues[i]).abs() > T::epsilon() {
|
if (self.eigenvalues[i] - other.eigenvalues[i]).abs() > T::epsilon() {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct PCAParameters {
|
pub struct PCAParameters {
|
||||||
use_correlation_matrix: bool
|
use_correlation_matrix: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for PCAParameters {
|
impl Default for PCAParameters {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
PCAParameters {
|
PCAParameters {
|
||||||
use_correlation_matrix: false
|
use_correlation_matrix: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, M: Matrix<T>> PCA<T, M> {
|
impl<T: FloatExt, M: Matrix<T>> PCA<T, M> {
|
||||||
|
|
||||||
pub fn new(data: &M, n_components: usize, parameters: PCAParameters) -> PCA<T, M> {
|
pub fn new(data: &M, n_components: usize, parameters: PCAParameters) -> PCA<T, M> {
|
||||||
|
|
||||||
let (m, n) = data.shape();
|
let (m, n) = data.shape();
|
||||||
|
|
||||||
let mu = data.column_mean();
|
let mu = data.column_mean();
|
||||||
@@ -63,7 +62,6 @@ impl<T: FloatExt, M: Matrix<T>> PCA<T, M> {
|
|||||||
let mut eigenvectors;
|
let mut eigenvectors;
|
||||||
|
|
||||||
if m > n && !parameters.use_correlation_matrix {
|
if m > n && !parameters.use_correlation_matrix {
|
||||||
|
|
||||||
let svd = x.svd();
|
let svd = x.svd();
|
||||||
eigenvalues = svd.s;
|
eigenvalues = svd.s;
|
||||||
for i in 0..eigenvalues.len() {
|
for i in 0..eigenvalues.len() {
|
||||||
@@ -114,13 +112,11 @@ impl<T: FloatExt, M: Matrix<T>> PCA<T, M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
let evd = cov.evd(true);
|
let evd = cov.evd(true);
|
||||||
|
|
||||||
eigenvalues = evd.d;
|
eigenvalues = evd.d;
|
||||||
|
|
||||||
eigenvectors = evd.V;
|
eigenvectors = evd.V;
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,7 +139,7 @@ impl<T: FloatExt, M: Matrix<T>> PCA<T, M> {
|
|||||||
eigenvalues: eigenvalues,
|
eigenvalues: eigenvalues,
|
||||||
projection: projection.transpose(),
|
projection: projection.transpose(),
|
||||||
mu: mu,
|
mu: mu,
|
||||||
pmu: pmu
|
pmu: pmu,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -151,7 +147,11 @@ impl<T: FloatExt, M: Matrix<T>> PCA<T, M> {
|
|||||||
let (nrows, ncols) = x.shape();
|
let (nrows, ncols) = x.shape();
|
||||||
let (_, n_components) = self.projection.shape();
|
let (_, n_components) = self.projection.shape();
|
||||||
if ncols != self.mu.len() {
|
if ncols != self.mu.len() {
|
||||||
panic!("Invalid input vector size: {}, expected: {}", ncols, self.mu.len());
|
panic!(
|
||||||
|
"Invalid input vector size: {}, expected: {}",
|
||||||
|
ncols,
|
||||||
|
self.mu.len()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut x_transformed = x.dot(&self.projection);
|
let mut x_transformed = x.dot(&self.projection);
|
||||||
@@ -162,7 +162,6 @@ impl<T: FloatExt, M: Matrix<T>> PCA<T, M> {
|
|||||||
}
|
}
|
||||||
x_transformed
|
x_transformed
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -221,19 +220,39 @@ mod tests {
|
|||||||
&[4.0, 145.0, 73.0, 26.2],
|
&[4.0, 145.0, 73.0, 26.2],
|
||||||
&[5.7, 81.0, 39.0, 9.3],
|
&[5.7, 81.0, 39.0, 9.3],
|
||||||
&[2.6, 53.0, 66.0, 10.8],
|
&[2.6, 53.0, 66.0, 10.8],
|
||||||
&[6.8, 161.0, 60.0, 15.6]])
|
&[6.8, 161.0, 60.0, 15.6],
|
||||||
|
])
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_covariance() {
|
fn decompose_covariance() {
|
||||||
|
|
||||||
let us_arrests = us_arrests_data();
|
let us_arrests = us_arrests_data();
|
||||||
|
|
||||||
let expected_eigenvectors = DenseMatrix::from_array(&[
|
let expected_eigenvectors = DenseMatrix::from_array(&[
|
||||||
&[-0.0417043206282872, -0.0448216562696701, -0.0798906594208108, -0.994921731246978],
|
&[
|
||||||
&[-0.995221281426497, -0.058760027857223, 0.0675697350838043, 0.0389382976351601],
|
-0.0417043206282872,
|
||||||
&[-0.0463357461197108, 0.97685747990989, 0.200546287353866, -0.0581691430589319],
|
-0.0448216562696701,
|
||||||
&[-0.075155500585547, 0.200718066450337, -0.974080592182491, 0.0723250196376097]
|
-0.0798906594208108,
|
||||||
|
-0.994921731246978,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
-0.995221281426497,
|
||||||
|
-0.058760027857223,
|
||||||
|
0.0675697350838043,
|
||||||
|
0.0389382976351601,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
-0.0463357461197108,
|
||||||
|
0.97685747990989,
|
||||||
|
0.200546287353866,
|
||||||
|
-0.0581691430589319,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
-0.075155500585547,
|
||||||
|
0.200718066450337,
|
||||||
|
-0.974080592182491,
|
||||||
|
0.0723250196376097,
|
||||||
|
],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let expected_projection = DenseMatrix::from_array(&[
|
let expected_projection = DenseMatrix::from_array(&[
|
||||||
@@ -286,14 +305,22 @@ mod tests {
|
|||||||
&[25.0758, 9.968, -4.7811, 2.6911],
|
&[25.0758, 9.968, -4.7811, 2.6911],
|
||||||
&[91.5446, -22.9529, 0.402, -0.7369],
|
&[91.5446, -22.9529, 0.402, -0.7369],
|
||||||
&[118.1763, 5.5076, 2.7113, -0.205],
|
&[118.1763, 5.5076, 2.7113, -0.205],
|
||||||
&[10.4345, -5.9245, 3.7944, 0.5179]
|
&[10.4345, -5.9245, 3.7944, 0.5179],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let expected_eigenvalues: Vec<f64> = vec![343544.6277001563, 9897.625949808047, 2063.519887011604, 302.04806302399646];
|
let expected_eigenvalues: Vec<f64> = vec![
|
||||||
|
343544.6277001563,
|
||||||
|
9897.625949808047,
|
||||||
|
2063.519887011604,
|
||||||
|
302.04806302399646,
|
||||||
|
];
|
||||||
|
|
||||||
let pca = PCA::new(&us_arrests, 4, Default::default());
|
let pca = PCA::new(&us_arrests, 4, Default::default());
|
||||||
|
|
||||||
assert!(pca.eigenvectors.abs().approximate_eq(&expected_eigenvectors.abs(), 1e-4));
|
assert!(pca
|
||||||
|
.eigenvectors
|
||||||
|
.abs()
|
||||||
|
.approximate_eq(&expected_eigenvectors.abs(), 1e-4));
|
||||||
|
|
||||||
for i in 0..pca.eigenvalues.len() {
|
for i in 0..pca.eigenvalues.len() {
|
||||||
assert_eq!(pca.eigenvalues[i].abs(), expected_eigenvalues[i].abs());
|
assert_eq!(pca.eigenvalues[i].abs(), expected_eigenvalues[i].abs());
|
||||||
@@ -301,20 +328,40 @@ mod tests {
|
|||||||
|
|
||||||
let us_arrests_t = pca.transform(&us_arrests);
|
let us_arrests_t = pca.transform(&us_arrests);
|
||||||
|
|
||||||
assert!(us_arrests_t.abs().approximate_eq(&expected_projection.abs(), 1e-4));
|
assert!(us_arrests_t
|
||||||
|
.abs()
|
||||||
|
.approximate_eq(&expected_projection.abs(), 1e-4));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_correlation() {
|
fn decompose_correlation() {
|
||||||
|
|
||||||
let us_arrests = us_arrests_data();
|
let us_arrests = us_arrests_data();
|
||||||
|
|
||||||
let expected_eigenvectors = DenseMatrix::from_array(&[
|
let expected_eigenvectors = DenseMatrix::from_array(&[
|
||||||
&[0.124288601688222, -0.0969866877028367, 0.0791404742697482, -0.150572299008293],
|
&[
|
||||||
&[0.00706888610512014, -0.00227861130898090, 0.00325028101296307, 0.00901099154845273],
|
0.124288601688222,
|
||||||
&[0.0194141494466002, 0.060910660326921, 0.0263806464184195, -0.0093429458365566],
|
-0.0969866877028367,
|
||||||
&[0.0586084532558777, 0.0180450999787168, -0.0881962972508558, -0.0096011588898465]
|
0.0791404742697482,
|
||||||
|
-0.150572299008293,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
0.00706888610512014,
|
||||||
|
-0.00227861130898090,
|
||||||
|
0.00325028101296307,
|
||||||
|
0.00901099154845273,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
0.0194141494466002,
|
||||||
|
0.060910660326921,
|
||||||
|
0.0263806464184195,
|
||||||
|
-0.0093429458365566,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
0.0586084532558777,
|
||||||
|
0.0180450999787168,
|
||||||
|
-0.0881962972508558,
|
||||||
|
-0.0096011588898465,
|
||||||
|
],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let expected_projection = DenseMatrix::from_array(&[
|
let expected_projection = DenseMatrix::from_array(&[
|
||||||
@@ -367,14 +414,28 @@ mod tests {
|
|||||||
&[-0.2169, 0.9701, -0.6249, 0.2208],
|
&[-0.2169, 0.9701, -0.6249, 0.2208],
|
||||||
&[-2.1086, -1.4248, -0.1048, -0.1319],
|
&[-2.1086, -1.4248, -0.1048, -0.1319],
|
||||||
&[-2.0797, 0.6113, 0.1389, -0.1841],
|
&[-2.0797, 0.6113, 0.1389, -0.1841],
|
||||||
&[-0.6294, -0.321, 0.2407, 0.1667]
|
&[-0.6294, -0.321, 0.2407, 0.1667],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let expected_eigenvalues: Vec<f64> = vec![2.480241579149493, 0.9897651525398419, 0.35656318058083064, 0.1734300877298357];
|
let expected_eigenvalues: Vec<f64> = vec![
|
||||||
|
2.480241579149493,
|
||||||
|
0.9897651525398419,
|
||||||
|
0.35656318058083064,
|
||||||
|
0.1734300877298357,
|
||||||
|
];
|
||||||
|
|
||||||
let pca = PCA::new(&us_arrests, 4, PCAParameters{use_correlation_matrix: true});
|
let pca = PCA::new(
|
||||||
|
&us_arrests,
|
||||||
|
4,
|
||||||
|
PCAParameters {
|
||||||
|
use_correlation_matrix: true,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
assert!(pca.eigenvectors.abs().approximate_eq(&expected_eigenvectors.abs(), 1e-4));
|
assert!(pca
|
||||||
|
.eigenvectors
|
||||||
|
.abs()
|
||||||
|
.approximate_eq(&expected_eigenvectors.abs(), 1e-4));
|
||||||
|
|
||||||
for i in 0..pca.eigenvalues.len() {
|
for i in 0..pca.eigenvalues.len() {
|
||||||
assert_eq!(pca.eigenvalues[i].abs(), expected_eigenvalues[i].abs());
|
assert_eq!(pca.eigenvalues[i].abs(), expected_eigenvalues[i].abs());
|
||||||
@@ -382,8 +443,9 @@ mod tests {
|
|||||||
|
|
||||||
let us_arrests_t = pca.transform(&us_arrests);
|
let us_arrests_t = pca.transform(&us_arrests);
|
||||||
|
|
||||||
assert!(us_arrests_t.abs().approximate_eq(&expected_projection.abs(), 1e-4));
|
assert!(us_arrests_t
|
||||||
|
.abs()
|
||||||
|
.approximate_eq(&expected_projection.abs(), 1e-4));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -408,14 +470,14 @@ mod tests {
|
|||||||
&[6.3, 3.3, 4.7, 1.6],
|
&[6.3, 3.3, 4.7, 1.6],
|
||||||
&[4.9, 2.4, 3.3, 1.0],
|
&[4.9, 2.4, 3.3, 1.0],
|
||||||
&[6.6, 2.9, 4.6, 1.3],
|
&[6.6, 2.9, 4.6, 1.3],
|
||||||
&[5.2, 2.7, 3.9, 1.4]]);
|
&[5.2, 2.7, 3.9, 1.4],
|
||||||
|
]);
|
||||||
|
|
||||||
let pca = PCA::new(&iris, 4, Default::default());
|
let pca = PCA::new(&iris, 4, Default::default());
|
||||||
|
|
||||||
let deserialized_pca: PCA<f64, DenseMatrix<f64>> = serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
|
let deserialized_pca: PCA<f64, DenseMatrix<f64>> =
|
||||||
|
serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
|
||||||
|
|
||||||
assert_eq!(pca, deserialized_pca);
|
assert_eq!(pca, deserialized_pca);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -4,11 +4,13 @@ use std::default::Default;
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max};
|
use crate::math::num::FloatExt;
|
||||||
|
use crate::tree::decision_tree_classifier::{
|
||||||
|
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub struct RandomForestClassifierParameters {
|
pub struct RandomForestClassifierParameters {
|
||||||
@@ -17,30 +19,29 @@ pub struct RandomForestClassifierParameters {
|
|||||||
pub min_samples_leaf: usize,
|
pub min_samples_leaf: usize,
|
||||||
pub min_samples_split: usize,
|
pub min_samples_split: usize,
|
||||||
pub n_trees: u16,
|
pub n_trees: u16,
|
||||||
pub mtry: Option<usize>
|
pub mtry: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct RandomForestClassifier<T: FloatExt> {
|
pub struct RandomForestClassifier<T: FloatExt> {
|
||||||
parameters: RandomForestClassifierParameters,
|
parameters: RandomForestClassifierParameters,
|
||||||
trees: Vec<DecisionTreeClassifier<T>>,
|
trees: Vec<DecisionTreeClassifier<T>>,
|
||||||
classes: Vec<T>
|
classes: Vec<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> PartialEq for RandomForestClassifier<T> {
|
impl<T: FloatExt> PartialEq for RandomForestClassifier<T> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
if self.classes.len() != other.classes.len() ||
|
if self.classes.len() != other.classes.len() || self.trees.len() != other.trees.len() {
|
||||||
self.trees.len() != other.trees.len() {
|
return false;
|
||||||
return false
|
|
||||||
} else {
|
} else {
|
||||||
for i in 0..self.classes.len() {
|
for i in 0..self.classes.len() {
|
||||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for i in 0..self.trees.len() {
|
for i in 0..self.trees.len() {
|
||||||
if self.trees[i] != other.trees[i] {
|
if self.trees[i] != other.trees[i] {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
true
|
true
|
||||||
@@ -56,14 +57,17 @@ impl Default for RandomForestClassifierParameters {
|
|||||||
min_samples_leaf: 1,
|
min_samples_leaf: 1,
|
||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 100,
|
n_trees: 100,
|
||||||
mtry: Option::None
|
mtry: Option::None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> RandomForestClassifier<T> {
|
impl<T: FloatExt> RandomForestClassifier<T> {
|
||||||
|
pub fn fit<M: Matrix<T>>(
|
||||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: RandomForestClassifierParameters) -> RandomForestClassifier<T> {
|
x: &M,
|
||||||
|
y: &M::RowVector,
|
||||||
|
parameters: RandomForestClassifierParameters,
|
||||||
|
) -> RandomForestClassifier<T> {
|
||||||
let (_, num_attributes) = x.shape();
|
let (_, num_attributes) = x.shape();
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let (_, y_ncols) = y_m.shape();
|
let (_, y_ncols) = y_m.shape();
|
||||||
@@ -75,7 +79,13 @@ impl<T: FloatExt> RandomForestClassifier<T> {
|
|||||||
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let mtry = parameters.mtry.unwrap_or((T::from(num_attributes).unwrap()).sqrt().floor().to_usize().unwrap());
|
let mtry = parameters.mtry.unwrap_or(
|
||||||
|
(T::from(num_attributes).unwrap())
|
||||||
|
.sqrt()
|
||||||
|
.floor()
|
||||||
|
.to_usize()
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
let classes = y_m.unique();
|
let classes = y_m.unique();
|
||||||
let k = classes.len();
|
let k = classes.len();
|
||||||
@@ -87,7 +97,7 @@ impl<T: FloatExt> RandomForestClassifier<T> {
|
|||||||
criterion: parameters.criterion.clone(),
|
criterion: parameters.criterion.clone(),
|
||||||
max_depth: parameters.max_depth,
|
max_depth: parameters.max_depth,
|
||||||
min_samples_leaf: parameters.min_samples_leaf,
|
min_samples_leaf: parameters.min_samples_leaf,
|
||||||
min_samples_split: parameters.min_samples_split
|
min_samples_split: parameters.min_samples_split,
|
||||||
};
|
};
|
||||||
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params);
|
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params);
|
||||||
trees.push(tree);
|
trees.push(tree);
|
||||||
@@ -96,7 +106,7 @@ impl<T: FloatExt> RandomForestClassifier<T> {
|
|||||||
RandomForestClassifier {
|
RandomForestClassifier {
|
||||||
parameters: parameters,
|
parameters: parameters,
|
||||||
trees: trees,
|
trees: trees,
|
||||||
classes
|
classes,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,8 +129,7 @@ impl<T: FloatExt> RandomForestClassifier<T> {
|
|||||||
result[tree.predict_for_row(x, row)] += 1;
|
result[tree.predict_for_row(x, row)] += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
return which_max(&result)
|
return which_max(&result);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample_with_replacement(y: &Vec<usize>, num_classes: usize) -> Vec<usize> {
|
fn sample_with_replacement(y: &Vec<usize>, num_classes: usize) -> Vec<usize> {
|
||||||
@@ -146,7 +155,6 @@ impl<T: FloatExt> RandomForestClassifier<T> {
|
|||||||
}
|
}
|
||||||
samples
|
samples
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -156,7 +164,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_iris() {
|
fn fit_predict_iris() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
&[4.9, 3.0, 1.4, 0.2],
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
@@ -177,20 +184,26 @@ mod tests {
|
|||||||
&[6.3, 3.3, 4.7, 1.6],
|
&[6.3, 3.3, 4.7, 1.6],
|
||||||
&[4.9, 2.4, 3.3, 1.0],
|
&[4.9, 2.4, 3.3, 1.0],
|
||||||
&[6.6, 2.9, 4.6, 1.3],
|
&[6.6, 2.9, 4.6, 1.3],
|
||||||
&[5.2, 2.7, 3.9, 1.4]]);
|
&[5.2, 2.7, 3.9, 1.4],
|
||||||
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
]);
|
||||||
|
let y = vec![
|
||||||
|
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
|
];
|
||||||
|
|
||||||
let classifier = RandomForestClassifier::fit(&x, &y, RandomForestClassifierParameters{
|
let classifier = RandomForestClassifier::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
RandomForestClassifierParameters {
|
||||||
criterion: SplitCriterion::Gini,
|
criterion: SplitCriterion::Gini,
|
||||||
max_depth: None,
|
max_depth: None,
|
||||||
min_samples_leaf: 1,
|
min_samples_leaf: 1,
|
||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 1000,
|
n_trees: 1000,
|
||||||
mtry: Option::None
|
mtry: Option::None,
|
||||||
});
|
},
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(y, classifier.predict(&x));
|
assert_eq!(y, classifier.predict(&x));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -215,15 +228,17 @@ mod tests {
|
|||||||
&[6.3, 3.3, 4.7, 1.6],
|
&[6.3, 3.3, 4.7, 1.6],
|
||||||
&[4.9, 2.4, 3.3, 1.0],
|
&[4.9, 2.4, 3.3, 1.0],
|
||||||
&[6.6, 2.9, 4.6, 1.3],
|
&[6.6, 2.9, 4.6, 1.3],
|
||||||
&[5.2, 2.7, 3.9, 1.4]]);
|
&[5.2, 2.7, 3.9, 1.4],
|
||||||
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
]);
|
||||||
|
let y = vec![
|
||||||
|
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
|
];
|
||||||
|
|
||||||
let forest = RandomForestClassifier::fit(&x, &y, Default::default());
|
let forest = RandomForestClassifier::fit(&x, &y, Default::default());
|
||||||
|
|
||||||
let deserialized_forest: RandomForestClassifier<f64> = bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
let deserialized_forest: RandomForestClassifier<f64> =
|
||||||
|
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
||||||
|
|
||||||
assert_eq!(forest, deserialized_forest);
|
assert_eq!(forest, deserialized_forest);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -4,11 +4,13 @@ use std::default::Default;
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::tree::decision_tree_regressor::{DecisionTreeRegressor, DecisionTreeRegressorParameters};
|
use crate::math::num::FloatExt;
|
||||||
|
use crate::tree::decision_tree_regressor::{
|
||||||
|
DecisionTreeRegressor, DecisionTreeRegressorParameters,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub struct RandomForestRegressorParameters {
|
pub struct RandomForestRegressorParameters {
|
||||||
@@ -16,13 +18,13 @@ pub struct RandomForestRegressorParameters {
|
|||||||
pub min_samples_leaf: usize,
|
pub min_samples_leaf: usize,
|
||||||
pub min_samples_split: usize,
|
pub min_samples_split: usize,
|
||||||
pub n_trees: usize,
|
pub n_trees: usize,
|
||||||
pub mtry: Option<usize>
|
pub mtry: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct RandomForestRegressor<T: FloatExt> {
|
pub struct RandomForestRegressor<T: FloatExt> {
|
||||||
parameters: RandomForestRegressorParameters,
|
parameters: RandomForestRegressorParameters,
|
||||||
trees: Vec<DecisionTreeRegressor<T>>
|
trees: Vec<DecisionTreeRegressor<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RandomForestRegressorParameters {
|
impl Default for RandomForestRegressorParameters {
|
||||||
@@ -32,7 +34,7 @@ impl Default for RandomForestRegressorParameters {
|
|||||||
min_samples_leaf: 1,
|
min_samples_leaf: 1,
|
||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 10,
|
n_trees: 10,
|
||||||
mtry: Option::None
|
mtry: Option::None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -40,11 +42,11 @@ impl Default for RandomForestRegressorParameters {
|
|||||||
impl<T: FloatExt> PartialEq for RandomForestRegressor<T> {
|
impl<T: FloatExt> PartialEq for RandomForestRegressor<T> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
if self.trees.len() != other.trees.len() {
|
if self.trees.len() != other.trees.len() {
|
||||||
return false
|
return false;
|
||||||
} else {
|
} else {
|
||||||
for i in 0..self.trees.len() {
|
for i in 0..self.trees.len() {
|
||||||
if self.trees[i] != other.trees[i] {
|
if self.trees[i] != other.trees[i] {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
true
|
true
|
||||||
@@ -53,11 +55,16 @@ impl<T: FloatExt> PartialEq for RandomForestRegressor<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> RandomForestRegressor<T> {
|
impl<T: FloatExt> RandomForestRegressor<T> {
|
||||||
|
pub fn fit<M: Matrix<T>>(
|
||||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: RandomForestRegressorParameters) -> RandomForestRegressor<T> {
|
x: &M,
|
||||||
|
y: &M::RowVector,
|
||||||
|
parameters: RandomForestRegressorParameters,
|
||||||
|
) -> RandomForestRegressor<T> {
|
||||||
let (n_rows, num_attributes) = x.shape();
|
let (n_rows, num_attributes) = x.shape();
|
||||||
|
|
||||||
let mtry = parameters.mtry.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
let mtry = parameters
|
||||||
|
.mtry
|
||||||
|
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
||||||
|
|
||||||
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
|
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
|
||||||
|
|
||||||
@@ -66,7 +73,7 @@ impl<T: FloatExt> RandomForestRegressor<T> {
|
|||||||
let params = DecisionTreeRegressorParameters {
|
let params = DecisionTreeRegressorParameters {
|
||||||
max_depth: parameters.max_depth,
|
max_depth: parameters.max_depth,
|
||||||
min_samples_leaf: parameters.min_samples_leaf,
|
min_samples_leaf: parameters.min_samples_leaf,
|
||||||
min_samples_split: parameters.min_samples_split
|
min_samples_split: parameters.min_samples_split,
|
||||||
};
|
};
|
||||||
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params);
|
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params);
|
||||||
trees.push(tree);
|
trees.push(tree);
|
||||||
@@ -74,7 +81,7 @@ impl<T: FloatExt> RandomForestRegressor<T> {
|
|||||||
|
|
||||||
RandomForestRegressor {
|
RandomForestRegressor {
|
||||||
parameters: parameters,
|
parameters: parameters,
|
||||||
trees: trees
|
trees: trees,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,7 +98,6 @@ impl<T: FloatExt> RandomForestRegressor<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
||||||
|
|
||||||
let n_trees = self.trees.len();
|
let n_trees = self.trees.len();
|
||||||
|
|
||||||
let mut result = T::zero();
|
let mut result = T::zero();
|
||||||
@@ -101,7 +107,6 @@ impl<T: FloatExt> RandomForestRegressor<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
result / T::from(n_trees).unwrap()
|
result / T::from(n_trees).unwrap()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample_with_replacement(nrows: usize) -> Vec<usize> {
|
fn sample_with_replacement(nrows: usize) -> Vec<usize> {
|
||||||
@@ -113,7 +118,6 @@ impl<T: FloatExt> RandomForestRegressor<T> {
|
|||||||
}
|
}
|
||||||
samples
|
samples
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -124,7 +128,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_longley() {
|
fn fit_longley() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
@@ -141,27 +144,37 @@ mod tests {
|
|||||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||||
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||||
let y = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
]);
|
||||||
|
let y = vec![
|
||||||
|
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||||
|
114.2, 115.7, 116.9,
|
||||||
|
];
|
||||||
|
|
||||||
let expected_y: Vec<f64> = vec![85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.];
|
let expected_y: Vec<f64> = vec![
|
||||||
|
85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.,
|
||||||
|
];
|
||||||
|
|
||||||
let y_hat = RandomForestRegressor::fit(&x, &y,
|
let y_hat = RandomForestRegressor::fit(
|
||||||
RandomForestRegressorParameters{max_depth: None,
|
&x,
|
||||||
|
&y,
|
||||||
|
RandomForestRegressorParameters {
|
||||||
|
max_depth: None,
|
||||||
min_samples_leaf: 1,
|
min_samples_leaf: 1,
|
||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 1000,
|
n_trees: 1000,
|
||||||
mtry: Option::None}).predict(&x);
|
mtry: Option::None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.predict(&x);
|
||||||
|
|
||||||
for i in 0..y_hat.len() {
|
for i in 0..y_hat.len() {
|
||||||
assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
|
assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn my_fit_longley_ndarray() {
|
fn my_fit_longley_ndarray() {
|
||||||
|
|
||||||
let x = arr2(&[
|
let x = arr2(&[
|
||||||
[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||||
[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
@@ -178,22 +191,33 @@ mod tests {
|
|||||||
[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||||
[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||||
[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||||
[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||||
let y = arr1(&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]);
|
]);
|
||||||
|
let y = arr1(&[
|
||||||
|
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||||
|
114.2, 115.7, 116.9,
|
||||||
|
]);
|
||||||
|
|
||||||
let expected_y: Vec<f64> = vec![85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.];
|
let expected_y: Vec<f64> = vec![
|
||||||
|
85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.,
|
||||||
|
];
|
||||||
|
|
||||||
let y_hat = RandomForestRegressor::fit(&x, &y,
|
let y_hat = RandomForestRegressor::fit(
|
||||||
RandomForestRegressorParameters{max_depth: None,
|
&x,
|
||||||
|
&y,
|
||||||
|
RandomForestRegressorParameters {
|
||||||
|
max_depth: None,
|
||||||
min_samples_leaf: 1,
|
min_samples_leaf: 1,
|
||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 1000,
|
n_trees: 1000,
|
||||||
mtry: Option::None}).predict(&x);
|
mtry: Option::None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.predict(&x);
|
||||||
|
|
||||||
for i in 0..y_hat.len() {
|
for i in 0..y_hat.len() {
|
||||||
assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
|
assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -214,15 +238,18 @@ mod tests {
|
|||||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||||
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||||
let y = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
]);
|
||||||
|
let y = vec![
|
||||||
|
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||||
|
114.2, 115.7, 116.9,
|
||||||
|
];
|
||||||
|
|
||||||
let forest = RandomForestRegressor::fit(&x, &y, Default::default());
|
let forest = RandomForestRegressor::fit(&x, &y, Default::default());
|
||||||
|
|
||||||
let deserialized_forest: RandomForestRegressor<f64> = bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
let deserialized_forest: RandomForestRegressor<f64> =
|
||||||
|
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
||||||
|
|
||||||
assert_eq!(forest, deserialized_forest);
|
assert_eq!(forest, deserialized_forest);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
+9
-9
@@ -1,12 +1,12 @@
|
|||||||
pub mod linear;
|
|
||||||
pub mod neighbors;
|
|
||||||
pub mod ensemble;
|
|
||||||
pub mod tree;
|
|
||||||
pub mod cluster;
|
|
||||||
pub mod decomposition;
|
|
||||||
pub mod linalg;
|
|
||||||
pub mod math;
|
|
||||||
pub mod algorithm;
|
pub mod algorithm;
|
||||||
|
pub mod cluster;
|
||||||
pub mod common;
|
pub mod common;
|
||||||
pub mod optimization;
|
pub mod decomposition;
|
||||||
|
pub mod ensemble;
|
||||||
|
pub mod linalg;
|
||||||
|
pub mod linear;
|
||||||
|
pub mod math;
|
||||||
pub mod metrics;
|
pub mod metrics;
|
||||||
|
pub mod neighbors;
|
||||||
|
pub mod optimization;
|
||||||
|
pub mod tree;
|
||||||
|
|||||||
+39
-37
@@ -1,29 +1,24 @@
|
|||||||
#![allow(non_snake_case)]
|
#![allow(non_snake_case)]
|
||||||
|
|
||||||
use num::complex::Complex;
|
|
||||||
use crate::linalg::BaseMatrix;
|
use crate::linalg::BaseMatrix;
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
|
use num::complex::Complex;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct EVD<T: FloatExt, M: BaseMatrix<T>> {
|
pub struct EVD<T: FloatExt, M: BaseMatrix<T>> {
|
||||||
pub d: Vec<T>,
|
pub d: Vec<T>,
|
||||||
pub e: Vec<T>,
|
pub e: Vec<T>,
|
||||||
pub V: M
|
pub V: M,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, M: BaseMatrix<T>> EVD<T, M> {
|
impl<T: FloatExt, M: BaseMatrix<T>> EVD<T, M> {
|
||||||
pub fn new(V: M, d: Vec<T>, e: Vec<T>) -> EVD<T, M> {
|
pub fn new(V: M, d: Vec<T>, e: Vec<T>) -> EVD<T, M> {
|
||||||
EVD {
|
EVD { d: d, e: e, V: V }
|
||||||
d: d,
|
|
||||||
e: e,
|
|
||||||
V: V
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait EVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
pub trait EVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||||
|
|
||||||
fn evd(&self, symmetric: bool) -> EVD<T, Self> {
|
fn evd(&self, symmetric: bool) -> EVD<T, Self> {
|
||||||
self.clone().evd_mut(symmetric)
|
self.clone().evd_mut(symmetric)
|
||||||
}
|
}
|
||||||
@@ -45,7 +40,6 @@ pub trait EVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
|||||||
tred2(&mut V, &mut d, &mut e);
|
tred2(&mut V, &mut d, &mut e);
|
||||||
// Diagonalize.
|
// Diagonalize.
|
||||||
tql2(&mut V, &mut d, &mut e);
|
tql2(&mut V, &mut d, &mut e);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
let scale = balance(&mut self);
|
let scale = balance(&mut self);
|
||||||
|
|
||||||
@@ -60,16 +54,11 @@ pub trait EVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
|||||||
sort(&mut d, &mut e, &mut V);
|
sort(&mut d, &mut e, &mut V);
|
||||||
}
|
}
|
||||||
|
|
||||||
EVD {
|
EVD { V: V, d: d, e: e }
|
||||||
V: V,
|
|
||||||
d: d,
|
|
||||||
e: e
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tred2<T: FloatExt, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
|
fn tred2<T: FloatExt, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
|
||||||
|
|
||||||
let (n, _) = V.shape();
|
let (n, _) = V.shape();
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
d[i] = V.get(n - 1, i);
|
d[i] = V.get(n - 1, i);
|
||||||
@@ -527,7 +516,8 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
let u = A.get(m, m - 1).abs() * (q.abs() + r.abs());
|
let u = A.get(m, m - 1).abs() * (q.abs() + r.abs());
|
||||||
let v = p.abs() * (A.get(m - 1, m - 1).abs() + z.abs() + A.get(m + 1, m + 1).abs());
|
let v = p.abs()
|
||||||
|
* (A.get(m - 1, m - 1).abs() + z.abs() + A.get(m + 1, m + 1).abs());
|
||||||
if u <= T::epsilon() * v {
|
if u <= T::epsilon() * v {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -609,7 +599,7 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
|
|||||||
if l + 1 >= nn {
|
if l + 1 >= nn {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if anorm != T::zero() {
|
if anorm != T::zero() {
|
||||||
@@ -672,7 +662,8 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
|
|||||||
A.set(na, na, q / A.get(nn, na));
|
A.set(na, na, q / A.get(nn, na));
|
||||||
A.set(na, nn, -(A.get(nn, nn) - p) / A.get(nn, na));
|
A.set(na, nn, -(A.get(nn, nn) - p) / A.get(nn, na));
|
||||||
} else {
|
} else {
|
||||||
let temp = Complex::new(T::zero(), -A.get(na, nn)) / Complex::new(A.get(na, na) - p, q);
|
let temp = Complex::new(T::zero(), -A.get(na, nn))
|
||||||
|
/ Complex::new(A.get(na, na) - p, q);
|
||||||
A.set(na, na, temp.re);
|
A.set(na, na, temp.re);
|
||||||
A.set(na, nn, temp.im);
|
A.set(na, nn, temp.im);
|
||||||
}
|
}
|
||||||
@@ -700,19 +691,34 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
|
|||||||
} else {
|
} else {
|
||||||
let x = A.get(i, i + 1);
|
let x = A.get(i, i + 1);
|
||||||
let y = A.get(i + 1, i);
|
let y = A.get(i + 1, i);
|
||||||
let mut vr = (d[i] - p).powf(T::two()) + (e[i]).powf(T::two()) - q * q;
|
let mut vr =
|
||||||
|
(d[i] - p).powf(T::two()) + (e[i]).powf(T::two()) - q * q;
|
||||||
let vi = T::two() * q * (d[i] - p);
|
let vi = T::two() * q * (d[i] - p);
|
||||||
if vr == T::zero() && vi == T::zero() {
|
if vr == T::zero() && vi == T::zero() {
|
||||||
vr = T::epsilon() * anorm * (w.abs() + q.abs() + x.abs() + y.abs() + z.abs());
|
vr = T::epsilon()
|
||||||
|
* anorm
|
||||||
|
* (w.abs() + q.abs() + x.abs() + y.abs() + z.abs());
|
||||||
}
|
}
|
||||||
let temp = Complex::new(x * r - z * ra + q * sa, x * s - z * sa - q * ra) / Complex::new(vr, vi);
|
let temp =
|
||||||
|
Complex::new(x * r - z * ra + q * sa, x * s - z * sa - q * ra)
|
||||||
|
/ Complex::new(vr, vi);
|
||||||
A.set(i, na, temp.re);
|
A.set(i, na, temp.re);
|
||||||
A.set(i, nn, temp.im);
|
A.set(i, nn, temp.im);
|
||||||
if x.abs() > z.abs() + q.abs() {
|
if x.abs() > z.abs() + q.abs() {
|
||||||
A.set(i + 1, na, (-ra - w * A.get(i, na) + q * A.get(i, nn)) / x);
|
A.set(
|
||||||
A.set(i + 1, nn, (-sa - w * A.get(i, nn) - q * A.get(i, na)) / x);
|
i + 1,
|
||||||
|
na,
|
||||||
|
(-ra - w * A.get(i, na) + q * A.get(i, nn)) / x,
|
||||||
|
);
|
||||||
|
A.set(
|
||||||
|
i + 1,
|
||||||
|
nn,
|
||||||
|
(-sa - w * A.get(i, nn) - q * A.get(i, na)) / x,
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
let temp = Complex::new(-r - y * A.get(i, na), -s - y * A.get(i, nn)) / Complex::new(z, q);
|
let temp =
|
||||||
|
Complex::new(-r - y * A.get(i, na), -s - y * A.get(i, nn))
|
||||||
|
/ Complex::new(z, q);
|
||||||
A.set(i + 1, na, temp.re);
|
A.set(i + 1, na, temp.re);
|
||||||
A.set(i + 1, nn, temp.im);
|
A.set(i + 1, nn, temp.im);
|
||||||
}
|
}
|
||||||
@@ -787,18 +793,18 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_symmetric() {
|
fn decompose_symmetric() {
|
||||||
|
|
||||||
let A = DenseMatrix::from_array(&[
|
let A = DenseMatrix::from_array(&[
|
||||||
&[0.9000, 0.4000, 0.7000],
|
&[0.9000, 0.4000, 0.7000],
|
||||||
&[0.4000, 0.5000, 0.3000],
|
&[0.4000, 0.5000, 0.3000],
|
||||||
&[0.7000, 0.3000, 0.8000]]);
|
&[0.7000, 0.3000, 0.8000],
|
||||||
|
]);
|
||||||
|
|
||||||
let eigen_values: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
|
let eigen_values: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
|
||||||
|
|
||||||
let eigen_vectors = DenseMatrix::from_array(&[
|
let eigen_vectors = DenseMatrix::from_array(&[
|
||||||
&[0.6881997, -0.07121225, 0.7220180],
|
&[0.6881997, -0.07121225, 0.7220180],
|
||||||
&[0.3700456, 0.89044952, -0.2648886],
|
&[0.3700456, 0.89044952, -0.2648886],
|
||||||
&[0.6240573, -0.44947578, -0.6391588]
|
&[0.6240573, -0.44947578, -0.6391588],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let evd = A.evd(true);
|
let evd = A.evd(true);
|
||||||
@@ -810,23 +816,22 @@ mod tests {
|
|||||||
for i in 0..eigen_values.len() {
|
for i in 0..eigen_values.len() {
|
||||||
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
|
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_asymmetric() {
|
fn decompose_asymmetric() {
|
||||||
|
|
||||||
let A = DenseMatrix::from_array(&[
|
let A = DenseMatrix::from_array(&[
|
||||||
&[0.9000, 0.4000, 0.7000],
|
&[0.9000, 0.4000, 0.7000],
|
||||||
&[0.4000, 0.5000, 0.3000],
|
&[0.4000, 0.5000, 0.3000],
|
||||||
&[0.8000, 0.3000, 0.8000]]);
|
&[0.8000, 0.3000, 0.8000],
|
||||||
|
]);
|
||||||
|
|
||||||
let eigen_values: Vec<f64> = vec![1.79171122, 0.31908143, 0.08920735];
|
let eigen_values: Vec<f64> = vec![1.79171122, 0.31908143, 0.08920735];
|
||||||
|
|
||||||
let eigen_vectors = DenseMatrix::from_array(&[
|
let eigen_vectors = DenseMatrix::from_array(&[
|
||||||
&[0.7178958, 0.05322098, 0.6812010],
|
&[0.7178958, 0.05322098, 0.6812010],
|
||||||
&[0.3837711, -0.84702111, -0.1494582],
|
&[0.3837711, -0.84702111, -0.1494582],
|
||||||
&[0.6952105, 0.43984484, -0.7036135]
|
&[0.6952105, 0.43984484, -0.7036135],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let evd = A.evd(false);
|
let evd = A.evd(false);
|
||||||
@@ -838,17 +843,16 @@ mod tests {
|
|||||||
for i in 0..eigen_values.len() {
|
for i in 0..eigen_values.len() {
|
||||||
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
|
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_complex() {
|
fn decompose_complex() {
|
||||||
|
|
||||||
let A = DenseMatrix::from_array(&[
|
let A = DenseMatrix::from_array(&[
|
||||||
&[3.0, -2.0, 1.0, 1.0],
|
&[3.0, -2.0, 1.0, 1.0],
|
||||||
&[4.0, -1.0, 1.0, 1.0],
|
&[4.0, -1.0, 1.0, 1.0],
|
||||||
&[1.0, 1.0, 3.0, -2.0],
|
&[1.0, 1.0, 3.0, -2.0],
|
||||||
&[1.0, 1.0, 4.0, -1.0]]);
|
&[1.0, 1.0, 4.0, -1.0],
|
||||||
|
]);
|
||||||
|
|
||||||
let eigen_values_d: Vec<f64> = vec![0.0, 2.0, 2.0, 0.0];
|
let eigen_values_d: Vec<f64> = vec![0.0, 2.0, 2.0, 0.0];
|
||||||
let eigen_values_e: Vec<f64> = vec![2.2361, 0.9999, -0.9999, -2.2361];
|
let eigen_values_e: Vec<f64> = vec![2.2361, 0.9999, -0.9999, -2.2361];
|
||||||
@@ -857,7 +861,7 @@ mod tests {
|
|||||||
&[-0.9159, -0.1378, 0.3816, -0.0806],
|
&[-0.9159, -0.1378, 0.3816, -0.0806],
|
||||||
&[-0.6707, 0.1059, 0.901, 0.6289],
|
&[-0.6707, 0.1059, 0.901, 0.6289],
|
||||||
&[0.9159, -0.1378, 0.3816, 0.0806],
|
&[0.9159, -0.1378, 0.3816, 0.0806],
|
||||||
&[0.6707, 0.1059, 0.901, -0.6289]
|
&[0.6707, 0.1059, 0.901, -0.6289],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let evd = A.evd(false);
|
let evd = A.evd(false);
|
||||||
@@ -869,7 +873,5 @@ mod tests {
|
|||||||
for i in 0..eigen_values_e.len() {
|
for i in 0..eigen_values_e.len() {
|
||||||
assert!((eigen_values_e[i] - evd.e[i]).abs() < 1e-4);
|
assert!((eigen_values_e[i] - evd.e[i]).abs() < 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
+11
-17
@@ -3,8 +3,8 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::BaseMatrix;
|
use crate::linalg::BaseMatrix;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct LU<T: FloatExt, M: BaseMatrix<T>> {
|
pub struct LU<T: FloatExt, M: BaseMatrix<T>> {
|
||||||
@@ -12,12 +12,11 @@ pub struct LU<T: FloatExt, M: BaseMatrix<T>> {
|
|||||||
pivot: Vec<usize>,
|
pivot: Vec<usize>,
|
||||||
pivot_sign: i8,
|
pivot_sign: i8,
|
||||||
singular: bool,
|
singular: bool,
|
||||||
phantom: PhantomData<T>
|
phantom: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
|
impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
|
||||||
pub fn new(LU: M, pivot: Vec<usize>, pivot_sign: i8) -> LU<T, M> {
|
pub fn new(LU: M, pivot: Vec<usize>, pivot_sign: i8) -> LU<T, M> {
|
||||||
|
|
||||||
let (_, n) = LU.shape();
|
let (_, n) = LU.shape();
|
||||||
|
|
||||||
let mut singular = false;
|
let mut singular = false;
|
||||||
@@ -33,7 +32,7 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
|
|||||||
pivot: pivot,
|
pivot: pivot,
|
||||||
pivot_sign: pivot_sign,
|
pivot_sign: pivot_sign,
|
||||||
singular: singular,
|
singular: singular,
|
||||||
phantom: PhantomData
|
phantom: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,7 +105,10 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
|
|||||||
let (b_m, b_n) = b.shape();
|
let (b_m, b_n) = b.shape();
|
||||||
|
|
||||||
if b_m != m {
|
if b_m != m {
|
||||||
panic!("Row dimensions do not agree: A is {} x {}, but B is {} x {}", m, n, b_m, b_n);
|
panic!(
|
||||||
|
"Row dimensions do not agree: A is {} x {}, but B is {} x {}",
|
||||||
|
m, n, b_m, b_n
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.singular {
|
if self.singular {
|
||||||
@@ -148,19 +150,15 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b
|
b
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait LUDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
pub trait LUDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||||
|
|
||||||
fn lu(&self) -> LU<T, Self> {
|
fn lu(&self) -> LU<T, Self> {
|
||||||
self.clone().lu_mut()
|
self.clone().lu_mut()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn lu_mut(mut self) -> LU<T, Self> {
|
fn lu_mut(mut self) -> LU<T, Self> {
|
||||||
|
|
||||||
let (m, n) = self.shape();
|
let (m, n) = self.shape();
|
||||||
|
|
||||||
let mut piv = vec![0; m];
|
let mut piv = vec![0; m];
|
||||||
@@ -172,7 +170,6 @@ pub trait LUDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
|||||||
let mut LUcolj = vec![T::zero(); m];
|
let mut LUcolj = vec![T::zero(); m];
|
||||||
|
|
||||||
for j in 0..n {
|
for j in 0..n {
|
||||||
|
|
||||||
for i in 0..m {
|
for i in 0..m {
|
||||||
LUcolj[i] = self.get(i, j);
|
LUcolj[i] = self.get(i, j);
|
||||||
}
|
}
|
||||||
@@ -214,13 +211,10 @@ pub trait LUDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
LU::new(self, piv, pivsign)
|
LU::new(self, piv, pivsign)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn lu_solve_mut(self, b: Self) -> Self {
|
fn lu_solve_mut(self, b: Self) -> Self {
|
||||||
|
|
||||||
self.lu_mut().solve(b)
|
self.lu_mut().solve(b)
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -231,11 +225,11 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose() {
|
fn decompose() {
|
||||||
|
|
||||||
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||||
let expected_L = DenseMatrix::from_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0.2, 0.8, 1.]]);
|
let expected_L = DenseMatrix::from_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0.2, 0.8, 1.]]);
|
||||||
let expected_U = DenseMatrix::from_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
|
let expected_U = DenseMatrix::from_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
|
||||||
let expected_pivot = DenseMatrix::from_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
|
let expected_pivot =
|
||||||
|
DenseMatrix::from_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
|
||||||
let lu = a.lu();
|
let lu = a.lu();
|
||||||
assert!(lu.L().approximate_eq(&expected_L, 1e-4));
|
assert!(lu.L().approximate_eq(&expected_L, 1e-4));
|
||||||
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
|
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
|
||||||
@@ -244,9 +238,9 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn inverse() {
|
fn inverse() {
|
||||||
|
|
||||||
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||||
let expected = DenseMatrix::from_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
|
let expected =
|
||||||
|
DenseMatrix::from_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
|
||||||
let a_inv = a.lu().inverse();
|
let a_inv = a.lu().inverse();
|
||||||
println!("{}", a_inv);
|
println!("{}", a_inv);
|
||||||
assert!(a_inv.approximate_eq(&expected, 1e-4));
|
assert!(a_inv.approximate_eq(&expected, 1e-4));
|
||||||
|
|||||||
+19
-15
@@ -1,23 +1,22 @@
|
|||||||
pub mod naive;
|
|
||||||
pub mod qr;
|
|
||||||
pub mod svd;
|
|
||||||
pub mod evd;
|
pub mod evd;
|
||||||
pub mod lu;
|
pub mod lu;
|
||||||
pub mod ndarray_bindings;
|
pub mod naive;
|
||||||
pub mod nalgebra_bindings;
|
pub mod nalgebra_bindings;
|
||||||
|
pub mod ndarray_bindings;
|
||||||
|
pub mod qr;
|
||||||
|
pub mod svd;
|
||||||
|
|
||||||
use std::ops::Range;
|
|
||||||
use std::fmt::{Debug, Display};
|
use std::fmt::{Debug, Display};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
use std::ops::Range;
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use svd::SVDDecomposableMatrix;
|
|
||||||
use evd::EVDDecomposableMatrix;
|
use evd::EVDDecomposableMatrix;
|
||||||
use qr::QRDecomposableMatrix;
|
|
||||||
use lu::LUDecomposableMatrix;
|
use lu::LUDecomposableMatrix;
|
||||||
|
use qr::QRDecomposableMatrix;
|
||||||
|
use svd::SVDDecomposableMatrix;
|
||||||
|
|
||||||
pub trait BaseVector<T: FloatExt>: Clone + Debug {
|
pub trait BaseVector<T: FloatExt>: Clone + Debug {
|
||||||
|
|
||||||
fn get(&self, i: usize) -> T;
|
fn get(&self, i: usize) -> T;
|
||||||
|
|
||||||
fn set(&mut self, i: usize, x: T);
|
fn set(&mut self, i: usize, x: T);
|
||||||
@@ -26,7 +25,6 @@ pub trait BaseVector<T: FloatExt>: Clone + Debug {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub trait BaseMatrix<T: FloatExt>: Clone + Debug {
|
pub trait BaseMatrix<T: FloatExt>: Clone + Debug {
|
||||||
|
|
||||||
type RowVector: BaseVector<T> + Clone + Debug;
|
type RowVector: BaseVector<T> + Clone + Debug;
|
||||||
|
|
||||||
fn from_row_vector(vec: Self::RowVector) -> Self;
|
fn from_row_vector(vec: Self::RowVector) -> Self;
|
||||||
@@ -186,17 +184,25 @@ pub trait BaseMatrix<T: FloatExt>: Clone + Debug {
|
|||||||
fn unique(&self) -> Vec<T>;
|
fn unique(&self) -> Vec<T>;
|
||||||
|
|
||||||
fn cov(&self) -> Self;
|
fn cov(&self) -> Self;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Matrix<T: FloatExt>: BaseMatrix<T> + SVDDecomposableMatrix<T> + EVDDecomposableMatrix<T> + QRDecomposableMatrix<T> + LUDecomposableMatrix<T> + PartialEq + Display {}
|
pub trait Matrix<T: FloatExt>:
|
||||||
|
BaseMatrix<T>
|
||||||
|
+ SVDDecomposableMatrix<T>
|
||||||
|
+ EVDDecomposableMatrix<T>
|
||||||
|
+ QRDecomposableMatrix<T>
|
||||||
|
+ LUDecomposableMatrix<T>
|
||||||
|
+ PartialEq
|
||||||
|
+ Display
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
pub fn row_iter<F: FloatExt, M: BaseMatrix<F>>(m: &M) -> RowIter<F, M> {
|
pub fn row_iter<F: FloatExt, M: BaseMatrix<F>>(m: &M) -> RowIter<F, M> {
|
||||||
RowIter {
|
RowIter {
|
||||||
m: m,
|
m: m,
|
||||||
pos: 0,
|
pos: 0,
|
||||||
max_pos: m.shape().0,
|
max_pos: m.shape().0,
|
||||||
phantom: PhantomData
|
phantom: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,11 +210,10 @@ pub struct RowIter<'a, T: FloatExt, M: BaseMatrix<T>> {
|
|||||||
m: &'a M,
|
m: &'a M,
|
||||||
pos: usize,
|
pos: usize,
|
||||||
max_pos: usize,
|
max_pos: usize,
|
||||||
phantom: PhantomData<&'a T>
|
phantom: PhantomData<&'a T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: FloatExt, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
|
impl<'a, T: FloatExt, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
|
||||||
|
|
||||||
type Item = Vec<T>;
|
type Item = Vec<T>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Vec<T>> {
|
fn next(&mut self) -> Option<Vec<T>> {
|
||||||
@@ -221,5 +226,4 @@ impl<'a, T: FloatExt, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
|
|||||||
self.pos += 1;
|
self.pos += 1;
|
||||||
res
|
res
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
+124
-138
@@ -1,19 +1,19 @@
|
|||||||
extern crate num;
|
extern crate num;
|
||||||
use std::ops::Range;
|
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
use std::ops::Range;
|
||||||
|
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::de::{Deserializer, MapAccess, SeqAccess, Visitor};
|
||||||
use serde::ser::{Serializer, SerializeStruct};
|
use serde::ser::{SerializeStruct, Serializer};
|
||||||
use serde::de::{Deserializer, Visitor, SeqAccess, MapAccess};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::linalg::evd::EVDDecomposableMatrix;
|
||||||
|
use crate::linalg::lu::LUDecomposableMatrix;
|
||||||
|
use crate::linalg::qr::QRDecomposableMatrix;
|
||||||
|
use crate::linalg::svd::SVDDecomposableMatrix;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
pub use crate::linalg::{BaseMatrix, BaseVector};
|
pub use crate::linalg::{BaseMatrix, BaseVector};
|
||||||
use crate::linalg::svd::SVDDecomposableMatrix;
|
|
||||||
use crate::linalg::evd::EVDDecomposableMatrix;
|
|
||||||
use crate::linalg::qr::QRDecomposableMatrix;
|
|
||||||
use crate::linalg::lu::LUDecomposableMatrix;
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
impl<T: FloatExt> BaseVector<T> for Vec<T> {
|
impl<T: FloatExt> BaseVector<T> for Vec<T> {
|
||||||
@@ -31,30 +31,32 @@ impl<T: FloatExt> BaseVector<T> for Vec<T> {
|
|||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct DenseMatrix<T: FloatExt> {
|
pub struct DenseMatrix<T: FloatExt> {
|
||||||
|
|
||||||
ncols: usize,
|
ncols: usize,
|
||||||
nrows: usize,
|
nrows: usize,
|
||||||
values: Vec<T>
|
values: Vec<T>,
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> fmt::Display for DenseMatrix<T> {
|
impl<T: FloatExt> fmt::Display for DenseMatrix<T> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
let mut rows: Vec<Vec<f64>> = Vec::new();
|
let mut rows: Vec<Vec<f64>> = Vec::new();
|
||||||
for r in 0..self.nrows {
|
for r in 0..self.nrows {
|
||||||
rows.push(self.get_row_as_vec(r).iter().map(|x| (x.to_f64().unwrap() * 1e4).round() / 1e4 ).collect());
|
rows.push(
|
||||||
|
self.get_row_as_vec(r)
|
||||||
|
.iter()
|
||||||
|
.map(|x| (x.to_f64().unwrap() * 1e4).round() / 1e4)
|
||||||
|
.collect(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
write!(f, "{:?}", rows)
|
write!(f, "{:?}", rows)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> DenseMatrix<T> {
|
impl<T: FloatExt> DenseMatrix<T> {
|
||||||
|
|
||||||
fn new(nrows: usize, ncols: usize, values: Vec<T>) -> Self {
|
fn new(nrows: usize, ncols: usize, values: Vec<T>) -> Self {
|
||||||
DenseMatrix {
|
DenseMatrix {
|
||||||
ncols: ncols,
|
ncols: ncols,
|
||||||
nrows: nrows,
|
nrows: nrows,
|
||||||
values: values
|
values: values,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,11 +66,14 @@ impl<T: FloatExt> DenseMatrix<T> {
|
|||||||
|
|
||||||
pub fn from_vec(values: &Vec<Vec<T>>) -> DenseMatrix<T> {
|
pub fn from_vec(values: &Vec<Vec<T>>) -> DenseMatrix<T> {
|
||||||
let nrows = values.len();
|
let nrows = values.len();
|
||||||
let ncols = values.first().unwrap_or_else(|| panic!("Cannot create 2d matrix from an empty vector")).len();
|
let ncols = values
|
||||||
|
.first()
|
||||||
|
.unwrap_or_else(|| panic!("Cannot create 2d matrix from an empty vector"))
|
||||||
|
.len();
|
||||||
let mut m = DenseMatrix {
|
let mut m = DenseMatrix {
|
||||||
ncols: ncols,
|
ncols: ncols,
|
||||||
nrows: nrows,
|
nrows: nrows,
|
||||||
values: vec![T::zero(); ncols*nrows]
|
values: vec![T::zero(); ncols * nrows],
|
||||||
};
|
};
|
||||||
for row in 0..nrows {
|
for row in 0..nrows {
|
||||||
for col in 0..ncols {
|
for col in 0..ncols {
|
||||||
@@ -86,7 +91,7 @@ impl<T: FloatExt> DenseMatrix<T> {
|
|||||||
DenseMatrix {
|
DenseMatrix {
|
||||||
ncols: values.len(),
|
ncols: values.len(),
|
||||||
nrows: 1,
|
nrows: 1,
|
||||||
values: values
|
values: values,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,7 +108,6 @@ impl<T: FloatExt> DenseMatrix<T> {
|
|||||||
pub fn get_raw_values(&self) -> &Vec<T> {
|
pub fn get_raw_values(&self) -> &Vec<T> {
|
||||||
&self.values
|
&self.values
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
|
impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
|
||||||
@@ -111,13 +115,16 @@ impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for Dens
|
|||||||
where
|
where
|
||||||
D: Deserializer<'de>,
|
D: Deserializer<'de>,
|
||||||
{
|
{
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
#[serde(field_identifier, rename_all = "lowercase")]
|
#[serde(field_identifier, rename_all = "lowercase")]
|
||||||
enum Field { NRows, NCols, Values }
|
enum Field {
|
||||||
|
NRows,
|
||||||
|
NCols,
|
||||||
|
Values,
|
||||||
|
}
|
||||||
|
|
||||||
struct DenseMatrixVisitor<T: FloatExt + fmt::Debug> {
|
struct DenseMatrixVisitor<T: FloatExt + fmt::Debug> {
|
||||||
t: PhantomData<T>
|
t: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: FloatExt + fmt::Debug + Deserialize<'a>> Visitor<'a> for DenseMatrixVisitor<T> {
|
impl<'a, T: FloatExt + fmt::Debug + Deserialize<'a>> Visitor<'a> for DenseMatrixVisitor<T> {
|
||||||
@@ -131,11 +138,14 @@ impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for Dens
|
|||||||
where
|
where
|
||||||
V: SeqAccess<'a>,
|
V: SeqAccess<'a>,
|
||||||
{
|
{
|
||||||
let nrows = seq.next_element()?
|
let nrows = seq
|
||||||
|
.next_element()?
|
||||||
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
|
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
|
||||||
let ncols = seq.next_element()?
|
let ncols = seq
|
||||||
|
.next_element()?
|
||||||
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
|
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
|
||||||
let values = seq.next_element()?
|
let values = seq
|
||||||
|
.next_element()?
|
||||||
.ok_or_else(|| serde::de::Error::invalid_length(2, &self))?;
|
.ok_or_else(|| serde::de::Error::invalid_length(2, &self))?;
|
||||||
Ok(DenseMatrix::new(nrows, ncols, values))
|
Ok(DenseMatrix::new(nrows, ncols, values))
|
||||||
}
|
}
|
||||||
@@ -177,16 +187,19 @@ impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for Dens
|
|||||||
}
|
}
|
||||||
|
|
||||||
const FIELDS: &'static [&'static str] = &["nrows", "ncols", "values"];
|
const FIELDS: &'static [&'static str] = &["nrows", "ncols", "values"];
|
||||||
deserializer.deserialize_struct("DenseMatrix", FIELDS, DenseMatrixVisitor {
|
deserializer.deserialize_struct(
|
||||||
t: PhantomData
|
"DenseMatrix",
|
||||||
})
|
FIELDS,
|
||||||
|
DenseMatrixVisitor { t: PhantomData },
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
|
impl<T: FloatExt + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where
|
where
|
||||||
S: Serializer {
|
S: Serializer,
|
||||||
|
{
|
||||||
let (nrows, ncols) = self.shape();
|
let (nrows, ncols) = self.shape();
|
||||||
let mut state = serializer.serialize_struct("DenseMatrix", 3)?;
|
let mut state = serializer.serialize_struct("DenseMatrix", 3)?;
|
||||||
state.serialize_field("nrows", &nrows)?;
|
state.serialize_field("nrows", &nrows)?;
|
||||||
@@ -209,7 +222,7 @@ impl<T: FloatExt> Matrix<T> for DenseMatrix<T> {}
|
|||||||
impl<T: FloatExt> PartialEq for DenseMatrix<T> {
|
impl<T: FloatExt> PartialEq for DenseMatrix<T> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
if self.ncols != other.ncols || self.nrows != other.nrows {
|
if self.ncols != other.ncols || self.nrows != other.nrows {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
let len = self.values.len();
|
let len = self.values.len();
|
||||||
@@ -236,7 +249,6 @@ impl<T: FloatExt> Into<Vec<T>> for DenseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
||||||
|
|
||||||
type RowVector = Vec<T>;
|
type RowVector = Vec<T>;
|
||||||
|
|
||||||
fn from_row_vector(vec: Self::RowVector) -> Self {
|
fn from_row_vector(vec: Self::RowVector) -> Self {
|
||||||
@@ -249,7 +261,10 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
|
|
||||||
fn get(&self, row: usize, col: usize) -> T {
|
fn get(&self, row: usize, col: usize) -> T {
|
||||||
if row >= self.nrows || col >= self.ncols {
|
if row >= self.nrows || col >= self.ncols {
|
||||||
panic!("Invalid index ({},{}) for {}x{} matrix", row, col, self.nrows, self.ncols);
|
panic!(
|
||||||
|
"Invalid index ({},{}) for {}x{} matrix",
|
||||||
|
row, col, self.nrows, self.ncols
|
||||||
|
);
|
||||||
}
|
}
|
||||||
self.values[col * self.nrows + row]
|
self.values[col * self.nrows + row]
|
||||||
}
|
}
|
||||||
@@ -343,7 +358,6 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn dot(&self, other: &Self) -> Self {
|
fn dot(&self, other: &Self) -> Self {
|
||||||
|
|
||||||
if self.ncols != other.nrows {
|
if self.ncols != other.nrows {
|
||||||
panic!("Number of rows of A should equal number of columns of B");
|
panic!("Number of rows of A should equal number of columns of B");
|
||||||
}
|
}
|
||||||
@@ -380,7 +394,6 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self {
|
fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self {
|
||||||
|
|
||||||
let ncols = cols.len();
|
let ncols = cols.len();
|
||||||
let nrows = rows.len();
|
let nrows = rows.len();
|
||||||
|
|
||||||
@@ -397,13 +410,13 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
|
|
||||||
fn approximate_eq(&self, other: &Self, error: T) -> bool {
|
fn approximate_eq(&self, other: &Self, error: T) -> bool {
|
||||||
if self.ncols != other.ncols || self.nrows != other.nrows {
|
if self.ncols != other.ncols || self.nrows != other.nrows {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
for c in 0..self.ncols {
|
for c in 0..self.ncols {
|
||||||
for r in 0..self.nrows {
|
for r in 0..self.nrows {
|
||||||
if (self.get(r, c) - other.get(r, c)).abs() > error {
|
if (self.get(r, c) - other.get(r, c)).abs() > error {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -487,7 +500,7 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
let mut m = DenseMatrix {
|
let mut m = DenseMatrix {
|
||||||
ncols: self.nrows,
|
ncols: self.nrows,
|
||||||
nrows: self.ncols,
|
nrows: self.ncols,
|
||||||
values: vec![T::zero(); self.ncols * self.nrows]
|
values: vec![T::zero(); self.ncols * self.nrows],
|
||||||
};
|
};
|
||||||
for c in 0..self.ncols {
|
for c in 0..self.ncols {
|
||||||
for r in 0..self.nrows {
|
for r in 0..self.nrows {
|
||||||
@@ -495,17 +508,14 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
m
|
m
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rand(nrows: usize, ncols: usize) -> Self {
|
fn rand(nrows: usize, ncols: usize) -> Self {
|
||||||
let values: Vec<T> = (0..nrows*ncols).map(|_| {
|
let values: Vec<T> = (0..nrows * ncols).map(|_| T::rand()).collect();
|
||||||
T::rand()
|
|
||||||
}).collect();
|
|
||||||
DenseMatrix {
|
DenseMatrix {
|
||||||
ncols: ncols,
|
ncols: ncols,
|
||||||
nrows: nrows,
|
nrows: nrows,
|
||||||
values: values
|
values: values,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -520,13 +530,17 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn norm(&self, p: T) -> T {
|
fn norm(&self, p: T) -> T {
|
||||||
|
|
||||||
if p.is_infinite() && p.is_sign_positive() {
|
if p.is_infinite() && p.is_sign_positive() {
|
||||||
self.values.iter().map(|x| x.abs()).fold(T::neg_infinity(), |a, b| a.max(b))
|
self.values
|
||||||
|
.iter()
|
||||||
|
.map(|x| x.abs())
|
||||||
|
.fold(T::neg_infinity(), |a, b| a.max(b))
|
||||||
} else if p.is_infinite() && p.is_sign_negative() {
|
} else if p.is_infinite() && p.is_sign_negative() {
|
||||||
self.values.iter().map(|x| x.abs()).fold(T::infinity(), |a, b| a.min(b))
|
self.values
|
||||||
|
.iter()
|
||||||
|
.map(|x| x.abs())
|
||||||
|
.fold(T::infinity(), |a, b| a.min(b))
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
let mut norm = T::zero();
|
let mut norm = T::zero();
|
||||||
|
|
||||||
for xi in self.values.iter() {
|
for xi in self.values.iter() {
|
||||||
@@ -589,7 +603,10 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
|
|
||||||
fn reshape(&self, nrows: usize, ncols: usize) -> Self {
|
fn reshape(&self, nrows: usize, ncols: usize) -> Self {
|
||||||
if self.nrows * self.ncols != nrows * ncols {
|
if self.nrows * self.ncols != nrows * ncols {
|
||||||
panic!("Can't reshape {}x{} matrix into {}x{}.", self.nrows, self.ncols, nrows, ncols);
|
panic!(
|
||||||
|
"Can't reshape {}x{} matrix into {}x{}.",
|
||||||
|
self.nrows, self.ncols, nrows, ncols
|
||||||
|
);
|
||||||
}
|
}
|
||||||
let mut dst = DenseMatrix::zeros(nrows, ncols);
|
let mut dst = DenseMatrix::zeros(nrows, ncols);
|
||||||
let mut dst_r = 0;
|
let mut dst_r = 0;
|
||||||
@@ -609,9 +626,11 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn copy_from(&mut self, other: &Self) {
|
fn copy_from(&mut self, other: &Self) {
|
||||||
|
|
||||||
if self.nrows != other.nrows || self.ncols != other.ncols {
|
if self.nrows != other.nrows || self.ncols != other.ncols {
|
||||||
panic!("Can't copy {}x{} matrix into {}x{}.", self.nrows, self.ncols, other.nrows, other.ncols);
|
panic!(
|
||||||
|
"Can't copy {}x{} matrix into {}x{}.",
|
||||||
|
self.nrows, self.ncols, other.nrows, other.ncols
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
for i in 0..self.values.len() {
|
for i in 0..self.values.len() {
|
||||||
@@ -632,7 +651,6 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
max_diff = max_diff.max((self.values[i] - other.values[i]).abs());
|
max_diff = max_diff.max((self.values[i] - other.values[i]).abs());
|
||||||
}
|
}
|
||||||
max_diff
|
max_diff
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sum(&self) -> T {
|
fn sum(&self) -> T {
|
||||||
@@ -644,7 +662,11 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn softmax_mut(&mut self) {
|
fn softmax_mut(&mut self) {
|
||||||
let max = self.values.iter().map(|x| x.abs()).fold(T::neg_infinity(), |a, b| a.max(b));
|
let max = self
|
||||||
|
.values
|
||||||
|
.iter()
|
||||||
|
.map(|x| x.abs())
|
||||||
|
.fold(T::neg_infinity(), |a, b| a.max(b));
|
||||||
let mut z = T::zero();
|
let mut z = T::zero();
|
||||||
for r in 0..self.nrows {
|
for r in 0..self.nrows {
|
||||||
for c in 0..self.ncols {
|
for c in 0..self.ncols {
|
||||||
@@ -668,7 +690,6 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn argmax(&self) -> Vec<usize> {
|
fn argmax(&self) -> Vec<usize> {
|
||||||
|
|
||||||
let mut res = vec![0usize; self.nrows];
|
let mut res = vec![0usize; self.nrows];
|
||||||
|
|
||||||
for r in 0..self.nrows {
|
for r in 0..self.nrows {
|
||||||
@@ -685,7 +706,6 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
res
|
res
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn unique(&self) -> Vec<T> {
|
fn unique(&self) -> Vec<T> {
|
||||||
@@ -696,7 +716,6 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn cov(&self) -> Self {
|
fn cov(&self) -> Self {
|
||||||
|
|
||||||
let (m, n) = self.shape();
|
let (m, n) = self.shape();
|
||||||
|
|
||||||
let mu = self.column_mean();
|
let mu = self.column_mean();
|
||||||
@@ -722,7 +741,6 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
|
|
||||||
cov
|
cov
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -731,109 +749,71 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn from_to_row_vec() {
|
fn from_to_row_vec() {
|
||||||
|
|
||||||
let vec = vec![1., 2., 3.];
|
let vec = vec![1., 2., 3.];
|
||||||
assert_eq!(DenseMatrix::from_row_vector(vec.clone()), DenseMatrix::new(1, 3, vec![1., 2., 3.]));
|
assert_eq!(
|
||||||
assert_eq!(DenseMatrix::from_row_vector(vec.clone()).to_row_vector(), vec![1., 2., 3.]);
|
DenseMatrix::from_row_vector(vec.clone()),
|
||||||
|
DenseMatrix::new(1, 3, vec![1., 2., 3.])
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
DenseMatrix::from_row_vector(vec.clone()).to_row_vector(),
|
||||||
|
vec![1., 2., 3.]
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn h_stack() {
|
fn h_stack() {
|
||||||
|
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||||
let a = DenseMatrix::from_array(
|
let b = DenseMatrix::from_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||||
&[
|
let expected = DenseMatrix::from_array(&[
|
||||||
&[1., 2., 3.],
|
|
||||||
&[4., 5., 6.],
|
|
||||||
&[7., 8., 9.]]);
|
|
||||||
let b = DenseMatrix::from_array(
|
|
||||||
&[
|
|
||||||
&[1., 2., 3.],
|
|
||||||
&[4., 5., 6.]]);
|
|
||||||
let expected = DenseMatrix::from_array(
|
|
||||||
&[
|
|
||||||
&[1., 2., 3.],
|
&[1., 2., 3.],
|
||||||
&[4., 5., 6.],
|
&[4., 5., 6.],
|
||||||
&[7., 8., 9.],
|
&[7., 8., 9.],
|
||||||
&[1., 2., 3.],
|
&[1., 2., 3.],
|
||||||
&[4., 5., 6.]]);
|
&[4., 5., 6.],
|
||||||
|
]);
|
||||||
let result = a.h_stack(&b);
|
let result = a.h_stack(&b);
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn v_stack() {
|
fn v_stack() {
|
||||||
|
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||||
let a = DenseMatrix::from_array(
|
let b = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
|
||||||
&[
|
let expected = DenseMatrix::from_array(&[
|
||||||
&[1., 2., 3.],
|
|
||||||
&[4., 5., 6.],
|
|
||||||
&[7., 8., 9.]]);
|
|
||||||
let b = DenseMatrix::from_array(
|
|
||||||
&[
|
|
||||||
&[1., 2.],
|
|
||||||
&[3., 4.],
|
|
||||||
&[5., 6.]]);
|
|
||||||
let expected = DenseMatrix::from_array(
|
|
||||||
&[
|
|
||||||
&[1., 2., 3., 1., 2.],
|
&[1., 2., 3., 1., 2.],
|
||||||
&[4., 5., 6., 3., 4.],
|
&[4., 5., 6., 3., 4.],
|
||||||
&[7., 8., 9., 5., 6.]]);
|
&[7., 8., 9., 5., 6.],
|
||||||
|
]);
|
||||||
let result = a.v_stack(&b);
|
let result = a.v_stack(&b);
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn dot() {
|
fn dot() {
|
||||||
|
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||||
let a = DenseMatrix::from_array(
|
let b = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
|
||||||
&[
|
let expected = DenseMatrix::from_array(&[&[22., 28.], &[49., 64.]]);
|
||||||
&[1., 2., 3.],
|
|
||||||
&[4., 5., 6.]]);
|
|
||||||
let b = DenseMatrix::from_array(
|
|
||||||
&[
|
|
||||||
&[1., 2.],
|
|
||||||
&[3., 4.],
|
|
||||||
&[5., 6.]]);
|
|
||||||
let expected = DenseMatrix::from_array(
|
|
||||||
&[
|
|
||||||
&[22., 28.],
|
|
||||||
&[49., 64.]]);
|
|
||||||
let result = a.dot(&b);
|
let result = a.dot(&b);
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn slice() {
|
fn slice() {
|
||||||
|
let m = DenseMatrix::from_array(&[
|
||||||
let m = DenseMatrix::from_array(
|
|
||||||
&[
|
|
||||||
&[1., 2., 3., 1., 2.],
|
&[1., 2., 3., 1., 2.],
|
||||||
&[4., 5., 6., 3., 4.],
|
&[4., 5., 6., 3., 4.],
|
||||||
&[7., 8., 9., 5., 6.]]);
|
&[7., 8., 9., 5., 6.],
|
||||||
let expected = DenseMatrix::from_array(
|
]);
|
||||||
&[
|
let expected = DenseMatrix::from_array(&[&[2., 3.], &[5., 6.]]);
|
||||||
&[2., 3.],
|
|
||||||
&[5., 6.]]);
|
|
||||||
let result = m.slice(0..2, 1..3);
|
let result = m.slice(0..2, 1..3);
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn approximate_eq() {
|
fn approximate_eq() {
|
||||||
let m = DenseMatrix::from_array(
|
let m = DenseMatrix::from_array(&[&[2., 3.], &[5., 6.]]);
|
||||||
&[
|
let m_eq = DenseMatrix::from_array(&[&[2.5, 3.0], &[5., 5.5]]);
|
||||||
&[2., 3.],
|
let m_neq = DenseMatrix::from_array(&[&[3.0, 3.0], &[5., 6.5]]);
|
||||||
&[5., 6.]]);
|
|
||||||
let m_eq = DenseMatrix::from_array(
|
|
||||||
&[
|
|
||||||
&[2.5, 3.0],
|
|
||||||
&[5., 5.5]]);
|
|
||||||
let m_neq = DenseMatrix::from_array(
|
|
||||||
&[
|
|
||||||
&[3.0, 3.0],
|
|
||||||
&[5., 6.5]]);
|
|
||||||
assert!(m.approximate_eq(&m_eq, 0.5));
|
assert!(m.approximate_eq(&m_eq, 0.5));
|
||||||
assert!(!m.approximate_eq(&m_neq, 0.5));
|
assert!(!m.approximate_eq(&m_neq, 0.5));
|
||||||
}
|
}
|
||||||
@@ -873,7 +853,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn norm() {
|
fn norm() {
|
||||||
|
|
||||||
let v = DenseMatrix::vector_from_array(&[3., -2., 6.]);
|
let v = DenseMatrix::vector_from_array(&[3., -2., 6.]);
|
||||||
assert_eq!(v.norm(1.), 11.);
|
assert_eq!(v.norm(1.), 11.);
|
||||||
assert_eq!(v.norm(2.), 7.);
|
assert_eq!(v.norm(2.), 7.);
|
||||||
@@ -883,7 +862,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn softmax_mut() {
|
fn softmax_mut() {
|
||||||
|
|
||||||
let mut prob: DenseMatrix<f64> = DenseMatrix::vector_from_array(&[1., 2., 3.]);
|
let mut prob: DenseMatrix<f64> = DenseMatrix::vector_from_array(&[1., 2., 3.]);
|
||||||
prob.softmax_mut();
|
prob.softmax_mut();
|
||||||
assert!((prob.get(0, 0) - 0.09).abs() < 0.01);
|
assert!((prob.get(0, 0) - 0.09).abs() < 0.01);
|
||||||
@@ -893,20 +871,14 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn col_mean() {
|
fn col_mean() {
|
||||||
let a = DenseMatrix::from_array(&[
|
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||||
&[1., 2., 3.],
|
|
||||||
&[4., 5., 6.],
|
|
||||||
&[7., 8., 9.]]);
|
|
||||||
let res = a.column_mean();
|
let res = a.column_mean();
|
||||||
assert_eq!(res, vec![4., 5., 6.]);
|
assert_eq!(res, vec![4., 5., 6.]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn eye() {
|
fn eye() {
|
||||||
let a = DenseMatrix::from_array(&[
|
let a = DenseMatrix::from_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0., 0., 1.]]);
|
||||||
&[1., 0., 0.],
|
|
||||||
&[0., 1., 0.],
|
|
||||||
&[0., 0., 1.]]);
|
|
||||||
let res = DenseMatrix::eye(3);
|
let res = DenseMatrix::eye(3);
|
||||||
assert_eq!(res, a);
|
assert_eq!(res, a);
|
||||||
}
|
}
|
||||||
@@ -914,28 +886,42 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn to_from_json() {
|
fn to_from_json() {
|
||||||
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
let deserialized_a: DenseMatrix<f64> = serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap();
|
let deserialized_a: DenseMatrix<f64> =
|
||||||
|
serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap();
|
||||||
assert_eq!(a, deserialized_a);
|
assert_eq!(a, deserialized_a);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn to_from_bincode() {
|
fn to_from_bincode() {
|
||||||
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
let deserialized_a: DenseMatrix<f64> = bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap();
|
let deserialized_a: DenseMatrix<f64> =
|
||||||
|
bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap();
|
||||||
assert_eq!(a, deserialized_a);
|
assert_eq!(a, deserialized_a);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn to_string() {
|
fn to_string() {
|
||||||
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
assert_eq!(format!("{}", a), "[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]");
|
assert_eq!(
|
||||||
|
format!("{}", a),
|
||||||
|
"[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cov() {
|
fn cov() {
|
||||||
let a = DenseMatrix::from_array(&[&[64.0, 580.0, 29.0], &[66.0, 570.0, 33.0], &[68.0, 590.0, 37.0], &[69.0, 660.0, 46.0], &[73.0, 600.0, 55.0]]);
|
let a = DenseMatrix::from_array(&[
|
||||||
let expected = DenseMatrix::from_array(&[&[11.5, 50.0, 34.75], &[50.0, 1250.0, 205.0], &[34.75, 205.0, 110.0]]);
|
&[64.0, 580.0, 29.0],
|
||||||
|
&[66.0, 570.0, 33.0],
|
||||||
|
&[68.0, 590.0, 37.0],
|
||||||
|
&[69.0, 660.0, 46.0],
|
||||||
|
&[73.0, 600.0, 55.0],
|
||||||
|
]);
|
||||||
|
let expected = DenseMatrix::from_array(&[
|
||||||
|
&[11.5, 50.0, 34.75],
|
||||||
|
&[50.0, 1250.0, 205.0],
|
||||||
|
&[34.75, 205.0, 110.0],
|
||||||
|
]);
|
||||||
assert_eq!(a.cov(), expected);
|
assert_eq!(a.cov(), expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
use std::ops::{Range, AddAssign, SubAssign, MulAssign, DivAssign};
|
|
||||||
use std::iter::Sum;
|
use std::iter::Sum;
|
||||||
|
use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign};
|
||||||
|
|
||||||
use nalgebra::{MatrixMN, DMatrix, Matrix, Scalar, Dynamic, U1, VecStorage};
|
use nalgebra::{DMatrix, Dynamic, Matrix, MatrixMN, Scalar, VecStorage, U1};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::{BaseMatrix, BaseVector};
|
|
||||||
use crate::linalg::Matrix as SmartCoreMatrix;
|
|
||||||
use crate::linalg::svd::SVDDecomposableMatrix;
|
|
||||||
use crate::linalg::evd::EVDDecomposableMatrix;
|
use crate::linalg::evd::EVDDecomposableMatrix;
|
||||||
use crate::linalg::qr::QRDecomposableMatrix;
|
|
||||||
use crate::linalg::lu::LUDecomposableMatrix;
|
use crate::linalg::lu::LUDecomposableMatrix;
|
||||||
|
use crate::linalg::qr::QRDecomposableMatrix;
|
||||||
|
use crate::linalg::svd::SVDDecomposableMatrix;
|
||||||
|
use crate::linalg::Matrix as SmartCoreMatrix;
|
||||||
|
use crate::linalg::{BaseMatrix, BaseVector};
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
impl<T: FloatExt + 'static> BaseVector<T> for MatrixMN<T, U1, Dynamic> {
|
impl<T: FloatExt + 'static> BaseVector<T> for MatrixMN<T, U1, Dynamic> {
|
||||||
fn get(&self, i: usize) -> T {
|
fn get(&self, i: usize) -> T {
|
||||||
@@ -24,7 +24,8 @@ impl<T: FloatExt + 'static> BaseVector<T> for MatrixMN<T, U1, Dynamic> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static> BaseMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
|
||||||
|
BaseMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
||||||
{
|
{
|
||||||
type RowVector = MatrixMN<T, U1, Dynamic>;
|
type RowVector = MatrixMN<T, U1, Dynamic>;
|
||||||
|
|
||||||
@@ -171,9 +172,7 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn rand(nrows: usize, ncols: usize) -> Self {
|
fn rand(nrows: usize, ncols: usize) -> Self {
|
||||||
DMatrix::from_iterator(nrows, ncols, (0..nrows*ncols).map(|_| {
|
DMatrix::from_iterator(nrows, ncols, (0..nrows * ncols).map(|_| T::rand()))
|
||||||
T::rand()
|
|
||||||
}))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn norm2(&self) -> T {
|
fn norm2(&self) -> T {
|
||||||
@@ -200,7 +199,6 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
let mut norm = T::zero();
|
let mut norm = T::zero();
|
||||||
|
|
||||||
for xi in self.iter() {
|
for xi in self.iter() {
|
||||||
@@ -212,7 +210,6 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn column_mean(&self) -> Vec<T> {
|
fn column_mean(&self) -> Vec<T> {
|
||||||
|
|
||||||
let mut res = Vec::new();
|
let mut res = Vec::new();
|
||||||
|
|
||||||
for column in self.column_iter() {
|
for column in self.column_iter() {
|
||||||
@@ -282,7 +279,10 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn softmax_mut(&mut self) {
|
fn softmax_mut(&mut self) {
|
||||||
let max = self.iter().map(|x| x.abs()).fold(T::neg_infinity(), |a, b| a.max(b));
|
let max = self
|
||||||
|
.iter()
|
||||||
|
.map(|x| x.abs())
|
||||||
|
.fold(T::neg_infinity(), |a, b| a.max(b));
|
||||||
let mut z = T::zero();
|
let mut z = T::zero();
|
||||||
for r in 0..self.nrows() {
|
for r in 0..self.nrows() {
|
||||||
for c in 0..self.ncols() {
|
for c in 0..self.ncols() {
|
||||||
@@ -322,7 +322,6 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
|
|||||||
}
|
}
|
||||||
|
|
||||||
res
|
res
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn unique(&self) -> Vec<T> {
|
fn unique(&self) -> Vec<T> {
|
||||||
@@ -335,35 +334,49 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
|
|||||||
fn cov(&self) -> Self {
|
fn cov(&self) -> Self {
|
||||||
panic!("Not implemented");
|
panic!("Not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static> SVDDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>> {}
|
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
|
||||||
|
SVDDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static> EVDDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>> {}
|
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
|
||||||
|
EVDDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static> QRDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>> {}
|
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
|
||||||
|
QRDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static> LUDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>> {}
|
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
|
||||||
|
LUDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static> SmartCoreMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>> {}
|
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
|
||||||
|
SmartCoreMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use nalgebra::{Matrix2x3, DMatrix, RowDVector};
|
use nalgebra::{DMatrix, Matrix2x3, RowDVector};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_len() {
|
fn vec_len() {
|
||||||
let v = RowDVector::from_vec(vec!(1., 2., 3.));
|
let v = RowDVector::from_vec(vec![1., 2., 3.]);
|
||||||
assert_eq!(3, v.len());
|
assert_eq!(3, v.len());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn get_set_vector() {
|
fn get_set_vector() {
|
||||||
let mut v = RowDVector::from_vec(vec!(1., 2., 3., 4.));
|
let mut v = RowDVector::from_vec(vec![1., 2., 3., 4.]);
|
||||||
|
|
||||||
let expected = RowDVector::from_vec(vec!(1., 5., 3., 4.));
|
let expected = RowDVector::from_vec(vec![1., 5., 3., 4.]);
|
||||||
|
|
||||||
v.set(1, 5.);
|
v.set(1, 5.);
|
||||||
|
|
||||||
@@ -373,14 +386,9 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn get_set_dynamic() {
|
fn get_set_dynamic() {
|
||||||
let mut m = DMatrix::from_row_slice(
|
let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||||
2,
|
|
||||||
3,
|
|
||||||
&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
|
|
||||||
);
|
|
||||||
|
|
||||||
let expected = Matrix2x3::new(1., 2., 3., 4.,
|
let expected = Matrix2x3::new(1., 2., 3., 4., 10., 6.);
|
||||||
10., 6.);
|
|
||||||
|
|
||||||
m.set(1, 1, 10.);
|
m.set(1, 1, 10.);
|
||||||
|
|
||||||
@@ -390,11 +398,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn zeros() {
|
fn zeros() {
|
||||||
let expected = DMatrix::from_row_slice(
|
let expected = DMatrix::from_row_slice(2, 2, &[0., 0., 0., 0.]);
|
||||||
2,
|
|
||||||
2,
|
|
||||||
&[0., 0., 0., 0.],
|
|
||||||
);
|
|
||||||
|
|
||||||
let m: DMatrix<f64> = BaseMatrix::zeros(2, 2);
|
let m: DMatrix<f64> = BaseMatrix::zeros(2, 2);
|
||||||
|
|
||||||
@@ -403,11 +407,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn ones() {
|
fn ones() {
|
||||||
let expected = DMatrix::from_row_slice(
|
let expected = DMatrix::from_row_slice(2, 2, &[1., 1., 1., 1.]);
|
||||||
2,
|
|
||||||
2,
|
|
||||||
&[1., 1., 1., 1.],
|
|
||||||
);
|
|
||||||
|
|
||||||
let m: DMatrix<f64> = BaseMatrix::ones(2, 2);
|
let m: DMatrix<f64> = BaseMatrix::ones(2, 2);
|
||||||
|
|
||||||
@@ -432,17 +432,9 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn scalar_add_sub_mul_div() {
|
fn scalar_add_sub_mul_div() {
|
||||||
let mut m = DMatrix::from_row_slice(
|
let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||||
2,
|
|
||||||
3,
|
|
||||||
&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
|
|
||||||
);
|
|
||||||
|
|
||||||
let expected = DMatrix::from_row_slice(
|
let expected = DMatrix::from_row_slice(2, 3, &[0.6, 0.8, 1., 1.2, 1.4, 1.6]);
|
||||||
2,
|
|
||||||
3,
|
|
||||||
&[0.6, 0.8, 1., 1.2, 1.4, 1.6],
|
|
||||||
);
|
|
||||||
|
|
||||||
m.add_scalar_mut(3.0);
|
m.add_scalar_mut(3.0);
|
||||||
m.sub_scalar_mut(1.0);
|
m.sub_scalar_mut(1.0);
|
||||||
@@ -453,25 +445,13 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn add_sub_mul_div() {
|
fn add_sub_mul_div() {
|
||||||
let mut m = DMatrix::from_row_slice(
|
let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
|
||||||
2,
|
|
||||||
2,
|
|
||||||
&[1.0, 2.0, 3.0, 4.0],
|
|
||||||
);
|
|
||||||
|
|
||||||
let a = DMatrix::from_row_slice(
|
let a = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
|
||||||
2,
|
|
||||||
2,
|
|
||||||
&[1.0, 2.0, 3.0, 4.0],
|
|
||||||
);
|
|
||||||
|
|
||||||
let b: DMatrix<f64> = BaseMatrix::fill(2, 2, 10.);
|
let b: DMatrix<f64> = BaseMatrix::fill(2, 2, 10.);
|
||||||
|
|
||||||
let expected = DMatrix::from_row_slice(
|
let expected = DMatrix::from_row_slice(2, 2, &[0.1, 0.6, 1.5, 2.8]);
|
||||||
2,
|
|
||||||
2,
|
|
||||||
&[0.1, 0.6, 1.5, 2.8],
|
|
||||||
);
|
|
||||||
|
|
||||||
m.add_mut(&a);
|
m.add_mut(&a);
|
||||||
m.mul_mut(&a);
|
m.mul_mut(&a);
|
||||||
@@ -483,7 +463,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn to_from_row_vector() {
|
fn to_from_row_vector() {
|
||||||
let v = RowDVector::from_vec(vec!(1., 2., 3., 4.));
|
let v = RowDVector::from_vec(vec![1., 2., 3., 4.]);
|
||||||
let expected = v.clone();
|
let expected = v.clone();
|
||||||
let m: DMatrix<f64> = BaseMatrix::from_row_vector(v);
|
let m: DMatrix<f64> = BaseMatrix::from_row_vector(v);
|
||||||
assert_eq!(m.to_row_vector(), expected);
|
assert_eq!(m.to_row_vector(), expected);
|
||||||
@@ -491,11 +471,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn get_row_col_as_vec() {
|
fn get_row_col_as_vec() {
|
||||||
let m = DMatrix::from_row_slice(
|
let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
|
||||||
3,
|
|
||||||
3,
|
|
||||||
&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(m.get_row_as_vec(1), vec!(4., 5., 6.));
|
assert_eq!(m.get_row_as_vec(1), vec!(4., 5., 6.));
|
||||||
assert_eq!(m.get_col_as_vec(1), vec!(2., 5., 8.));
|
assert_eq!(m.get_col_as_vec(1), vec!(2., 5., 8.));
|
||||||
@@ -503,28 +479,16 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn to_raw_vector() {
|
fn to_raw_vector() {
|
||||||
let m = DMatrix::from_row_slice(
|
let m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||||
2,
|
|
||||||
3,
|
|
||||||
&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(m.to_raw_vector(), vec!(1., 2., 3., 4., 5., 6.));
|
assert_eq!(m.to_raw_vector(), vec!(1., 2., 3., 4., 5., 6.));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn element_add_sub_mul_div() {
|
fn element_add_sub_mul_div() {
|
||||||
let mut m = DMatrix::from_row_slice(
|
let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
|
||||||
2,
|
|
||||||
2,
|
|
||||||
&[1.0, 2.0, 3.0, 4.0],
|
|
||||||
);
|
|
||||||
|
|
||||||
let expected = DMatrix::from_row_slice(
|
let expected = DMatrix::from_row_slice(2, 2, &[4., 1., 6., 0.4]);
|
||||||
2,
|
|
||||||
2,
|
|
||||||
&[4., 1., 6., 0.4],
|
|
||||||
);
|
|
||||||
|
|
||||||
m.add_element_mut(0, 0, 3.0);
|
m.add_element_mut(0, 0, 3.0);
|
||||||
m.sub_element_mut(0, 1, 1.0);
|
m.sub_element_mut(0, 1, 1.0);
|
||||||
@@ -535,23 +499,21 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn vstack_hstack() {
|
fn vstack_hstack() {
|
||||||
|
|
||||||
let m1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
|
let m1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
|
||||||
let m2 = DMatrix::from_row_slice(2, 1, &[7., 8.]);
|
let m2 = DMatrix::from_row_slice(2, 1, &[7., 8.]);
|
||||||
|
|
||||||
let m3 = DMatrix::from_row_slice(1, 4, &[9., 10., 11., 12.]);
|
let m3 = DMatrix::from_row_slice(1, 4, &[9., 10., 11., 12.]);
|
||||||
|
|
||||||
let expected = DMatrix::from_row_slice(3, 4, &[1., 2., 3., 7., 4., 5., 6., 8., 9., 10., 11., 12.]);
|
let expected =
|
||||||
|
DMatrix::from_row_slice(3, 4, &[1., 2., 3., 7., 4., 5., 6., 8., 9., 10., 11., 12.]);
|
||||||
|
|
||||||
let result = m1.v_stack(&m2).h_stack(&m3);
|
let result = m1.v_stack(&m2).h_stack(&m3);
|
||||||
|
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn dot() {
|
fn dot() {
|
||||||
|
|
||||||
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
|
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
|
||||||
let b = DMatrix::from_row_slice(3, 2, &[1., 2., 3., 4., 5., 6.]);
|
let b = DMatrix::from_row_slice(3, 2, &[1., 2., 3., 4., 5., 6.]);
|
||||||
let expected = DMatrix::from_row_slice(2, 2, &[22., 28., 49., 64.]);
|
let expected = DMatrix::from_row_slice(2, 2, &[22., 28., 49., 64.]);
|
||||||
@@ -568,8 +530,11 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn slice() {
|
fn slice() {
|
||||||
|
let a = DMatrix::from_row_slice(
|
||||||
let a = DMatrix::from_row_slice(3, 5, &[1., 2., 3., 1., 2., 4., 5., 6., 3., 4., 7., 8., 9., 5., 6.]);
|
3,
|
||||||
|
5,
|
||||||
|
&[1., 2., 3., 1., 2., 4., 5., 6., 3., 4., 7., 8., 9., 5., 6.],
|
||||||
|
);
|
||||||
let expected = DMatrix::from_row_slice(2, 2, &[2., 3., 5., 6.]);
|
let expected = DMatrix::from_row_slice(2, 2, &[2., 3., 5., 6.]);
|
||||||
let result = BaseMatrix::slice(&a, 0..2, 1..3);
|
let result = BaseMatrix::slice(&a, 0..2, 1..3);
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
@@ -578,7 +543,11 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn approximate_eq() {
|
fn approximate_eq() {
|
||||||
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
|
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
|
||||||
let noise = DMatrix::from_row_slice(3, 3, &[1e-5, 2e-5, 3e-5, 4e-5, 5e-5, 6e-5, 7e-5, 8e-5, 9e-5]);
|
let noise = DMatrix::from_row_slice(
|
||||||
|
3,
|
||||||
|
3,
|
||||||
|
&[1e-5, 2e-5, 3e-5, 4e-5, 5e-5, 6e-5, 7e-5, 8e-5, 9e-5],
|
||||||
|
);
|
||||||
assert!(a.approximate_eq(&(&noise + &a), 1e-4));
|
assert!(a.approximate_eq(&(&noise + &a), 1e-4));
|
||||||
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
||||||
}
|
}
|
||||||
@@ -695,5 +664,4 @@ mod tests {
|
|||||||
assert_eq!(res.len(), 7);
|
assert_eq!(res.len(), 7);
|
||||||
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
|
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,20 +1,20 @@
|
|||||||
use std::ops::Range;
|
|
||||||
use std::iter::Sum;
|
use std::iter::Sum;
|
||||||
use std::ops::AddAssign;
|
use std::ops::AddAssign;
|
||||||
use std::ops::SubAssign;
|
|
||||||
use std::ops::MulAssign;
|
|
||||||
use std::ops::DivAssign;
|
use std::ops::DivAssign;
|
||||||
|
use std::ops::MulAssign;
|
||||||
|
use std::ops::Range;
|
||||||
|
use std::ops::SubAssign;
|
||||||
|
|
||||||
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
|
|
||||||
use ndarray::ScalarOperand;
|
use ndarray::ScalarOperand;
|
||||||
|
use ndarray::{s, stack, Array, ArrayBase, Axis, Ix1, Ix2, OwnedRepr};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::{BaseMatrix, BaseVector};
|
|
||||||
use crate::linalg::Matrix;
|
|
||||||
use crate::linalg::svd::SVDDecomposableMatrix;
|
|
||||||
use crate::linalg::evd::EVDDecomposableMatrix;
|
use crate::linalg::evd::EVDDecomposableMatrix;
|
||||||
use crate::linalg::qr::QRDecomposableMatrix;
|
|
||||||
use crate::linalg::lu::LUDecomposableMatrix;
|
use crate::linalg::lu::LUDecomposableMatrix;
|
||||||
|
use crate::linalg::qr::QRDecomposableMatrix;
|
||||||
|
use crate::linalg::svd::SVDDecomposableMatrix;
|
||||||
|
use crate::linalg::Matrix;
|
||||||
|
use crate::linalg::{BaseMatrix, BaseVector};
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
impl<T: FloatExt> BaseVector<T> for ArrayBase<OwnedRepr<T>, Ix1> {
|
impl<T: FloatExt> BaseVector<T> for ArrayBase<OwnedRepr<T>, Ix1> {
|
||||||
fn get(&self, i: usize) -> T {
|
fn get(&self, i: usize) -> T {
|
||||||
@@ -29,7 +29,8 @@ impl<T: FloatExt> BaseVector<T> for ArrayBase<OwnedRepr<T>, Ix1> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> BaseMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
|
||||||
|
BaseMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||||
{
|
{
|
||||||
type RowVector = ArrayBase<OwnedRepr<T>, Ix1>;
|
type RowVector = ArrayBase<OwnedRepr<T>, Ix1>;
|
||||||
|
|
||||||
@@ -152,9 +153,7 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn rand(nrows: usize, ncols: usize) -> Self {
|
fn rand(nrows: usize, ncols: usize) -> Self {
|
||||||
let values: Vec<T> = (0..nrows*ncols).map(|_| {
|
let values: Vec<T> = (0..nrows * ncols).map(|_| T::rand()).collect();
|
||||||
T::rand()
|
|
||||||
}).collect();
|
|
||||||
Array::from_shape_vec((nrows, ncols), values).unwrap()
|
Array::from_shape_vec((nrows, ncols), values).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -182,7 +181,6 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
let mut norm = T::zero();
|
let mut norm = T::zero();
|
||||||
|
|
||||||
for xi in self.iter() {
|
for xi in self.iter() {
|
||||||
@@ -247,7 +245,10 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn softmax_mut(&mut self) {
|
fn softmax_mut(&mut self) {
|
||||||
let max = self.iter().map(|x| x.abs()).fold(T::neg_infinity(), |a, b| a.max(b));
|
let max = self
|
||||||
|
.iter()
|
||||||
|
.map(|x| x.abs())
|
||||||
|
.fold(T::neg_infinity(), |a, b| a.max(b));
|
||||||
let mut z = T::zero();
|
let mut z = T::zero();
|
||||||
for r in 0..self.nrows() {
|
for r in 0..self.nrows() {
|
||||||
for c in 0..self.ncols() {
|
for c in 0..self.ncols() {
|
||||||
@@ -289,7 +290,6 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
|
|||||||
}
|
}
|
||||||
|
|
||||||
res
|
res
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn unique(&self) -> Vec<T> {
|
fn unique(&self) -> Vec<T> {
|
||||||
@@ -302,18 +302,32 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
|
|||||||
fn cov(&self) -> Self {
|
fn cov(&self) -> Self {
|
||||||
panic!("Not implemented");
|
panic!("Not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> SVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
|
||||||
|
SVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> EVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
|
||||||
|
EVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> QRDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
|
||||||
|
QRDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> LUDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
|
||||||
|
LUDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T>
|
||||||
|
for ArrayBase<OwnedRepr<T>, Ix2>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
@@ -339,126 +353,98 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn from_to_row_vec() {
|
fn from_to_row_vec() {
|
||||||
|
|
||||||
let vec = arr1(&[1., 2., 3.]);
|
let vec = arr1(&[1., 2., 3.]);
|
||||||
assert_eq!(Array2::from_row_vector(vec.clone()), arr2(&[[1., 2., 3.]]));
|
assert_eq!(Array2::from_row_vector(vec.clone()), arr2(&[[1., 2., 3.]]));
|
||||||
assert_eq!(Array2::from_row_vector(vec.clone()).to_row_vector(), arr1(&[1., 2., 3.]));
|
assert_eq!(
|
||||||
|
Array2::from_row_vector(vec.clone()).to_row_vector(),
|
||||||
|
arr1(&[1., 2., 3.])
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn add_mut() {
|
fn add_mut() {
|
||||||
|
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
let mut a1 = arr2(&[[ 1., 2., 3.],
|
|
||||||
[4., 5., 6.]]);
|
|
||||||
let a2 = a1.clone();
|
let a2 = a1.clone();
|
||||||
let a3 = a1.clone() + a2.clone();
|
let a3 = a1.clone() + a2.clone();
|
||||||
a1.add_mut(&a2);
|
a1.add_mut(&a2);
|
||||||
|
|
||||||
assert_eq!(a1, a3);
|
assert_eq!(a1, a3);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn sub_mut() {
|
fn sub_mut() {
|
||||||
|
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
let mut a1 = arr2(&[[ 1., 2., 3.],
|
|
||||||
[4., 5., 6.]]);
|
|
||||||
let a2 = a1.clone();
|
let a2 = a1.clone();
|
||||||
let a3 = a1.clone() - a2.clone();
|
let a3 = a1.clone() - a2.clone();
|
||||||
a1.sub_mut(&a2);
|
a1.sub_mut(&a2);
|
||||||
|
|
||||||
assert_eq!(a1, a3);
|
assert_eq!(a1, a3);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn mul_mut() {
|
fn mul_mut() {
|
||||||
|
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
let mut a1 = arr2(&[[ 1., 2., 3.],
|
|
||||||
[4., 5., 6.]]);
|
|
||||||
let a2 = a1.clone();
|
let a2 = a1.clone();
|
||||||
let a3 = a1.clone() * a2.clone();
|
let a3 = a1.clone() * a2.clone();
|
||||||
a1.mul_mut(&a2);
|
a1.mul_mut(&a2);
|
||||||
|
|
||||||
assert_eq!(a1, a3);
|
assert_eq!(a1, a3);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn div_mut() {
|
fn div_mut() {
|
||||||
|
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
let mut a1 = arr2(&[[ 1., 2., 3.],
|
|
||||||
[4., 5., 6.]]);
|
|
||||||
let a2 = a1.clone();
|
let a2 = a1.clone();
|
||||||
let a3 = a1.clone() / a2.clone();
|
let a3 = a1.clone() / a2.clone();
|
||||||
a1.div_mut(&a2);
|
a1.div_mut(&a2);
|
||||||
|
|
||||||
assert_eq!(a1, a3);
|
assert_eq!(a1, a3);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn div_element_mut() {
|
fn div_element_mut() {
|
||||||
|
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
let mut a = arr2(&[[ 1., 2., 3.],
|
|
||||||
[4., 5., 6.]]);
|
|
||||||
a.div_element_mut(1, 1, 5.);
|
a.div_element_mut(1, 1, 5.);
|
||||||
|
|
||||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 1.);
|
assert_eq!(BaseMatrix::get(&a, 1, 1), 1.);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn mul_element_mut() {
|
fn mul_element_mut() {
|
||||||
|
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
let mut a = arr2(&[[ 1., 2., 3.],
|
|
||||||
[4., 5., 6.]]);
|
|
||||||
a.mul_element_mut(1, 1, 5.);
|
a.mul_element_mut(1, 1, 5.);
|
||||||
|
|
||||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 25.);
|
assert_eq!(BaseMatrix::get(&a, 1, 1), 25.);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn add_element_mut() {
|
fn add_element_mut() {
|
||||||
|
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
let mut a = arr2(&[[ 1., 2., 3.],
|
|
||||||
[4., 5., 6.]]);
|
|
||||||
a.add_element_mut(1, 1, 5.);
|
a.add_element_mut(1, 1, 5.);
|
||||||
|
|
||||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 10.);
|
assert_eq!(BaseMatrix::get(&a, 1, 1), 10.);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn sub_element_mut() {
|
fn sub_element_mut() {
|
||||||
|
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
let mut a = arr2(&[[ 1., 2., 3.],
|
|
||||||
[4., 5., 6.]]);
|
|
||||||
a.sub_element_mut(1, 1, 5.);
|
a.sub_element_mut(1, 1, 5.);
|
||||||
|
|
||||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 0.);
|
assert_eq!(BaseMatrix::get(&a, 1, 1), 0.);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn vstack_hstack() {
|
fn vstack_hstack() {
|
||||||
|
let a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
let a1 = arr2(&[[1., 2., 3.],
|
|
||||||
[4., 5., 6.]]);
|
|
||||||
let a2 = arr2(&[[7.], [8.]]);
|
let a2 = arr2(&[[7.], [8.]]);
|
||||||
|
|
||||||
let a3 = arr2(&[[9., 10., 11., 12.]]);
|
let a3 = arr2(&[[9., 10., 11., 12.]]);
|
||||||
|
|
||||||
let expected = arr2(&[[1., 2., 3., 7.],
|
let expected = arr2(&[[1., 2., 3., 7.], [4., 5., 6., 8.], [9., 10., 11., 12.]]);
|
||||||
[4., 5., 6., 8.],
|
|
||||||
[9., 10., 11., 12.]]);
|
|
||||||
|
|
||||||
let result = a1.v_stack(&a2).h_stack(&a3);
|
let result = a1.v_stack(&a2).h_stack(&a3);
|
||||||
|
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -482,17 +468,9 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn dot() {
|
fn dot() {
|
||||||
|
let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
let a = arr2(&[
|
let b = arr2(&[[1., 2.], [3., 4.], [5., 6.]]);
|
||||||
[1., 2., 3.],
|
let expected = arr2(&[[22., 28.], [49., 64.]]);
|
||||||
[4., 5., 6.]]);
|
|
||||||
let b = arr2(&[
|
|
||||||
[1., 2.],
|
|
||||||
[3., 4.],
|
|
||||||
[5., 6.]]);
|
|
||||||
let expected = arr2(&[
|
|
||||||
[22., 28.],
|
|
||||||
[49., 64.]]);
|
|
||||||
let result = BaseMatrix::dot(&a, &b);
|
let result = BaseMatrix::dot(&a, &b);
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
@@ -506,16 +484,12 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn slice() {
|
fn slice() {
|
||||||
|
let a = arr2(&[
|
||||||
let a = arr2(
|
|
||||||
&[
|
|
||||||
[1., 2., 3., 1., 2.],
|
[1., 2., 3., 1., 2.],
|
||||||
[4., 5., 6., 3., 4.],
|
[4., 5., 6., 3., 4.],
|
||||||
[7., 8., 9., 5., 6.]]);
|
[7., 8., 9., 5., 6.],
|
||||||
let expected = arr2(
|
]);
|
||||||
&[
|
let expected = arr2(&[[2., 3.], [5., 6.]]);
|
||||||
[2., 3.],
|
|
||||||
[5., 6.]]);
|
|
||||||
let result = BaseMatrix::slice(&a, 0..2, 1..3);
|
let result = BaseMatrix::slice(&a, 0..2, 1..3);
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
@@ -633,18 +607,14 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn col_mean() {
|
fn col_mean() {
|
||||||
let a = arr2(&[[1., 2., 3.],
|
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
||||||
[4., 5., 6.],
|
|
||||||
[7., 8., 9.]]);
|
|
||||||
let res = a.column_mean();
|
let res = a.column_mean();
|
||||||
assert_eq!(res, vec![4., 5., 6.]);
|
assert_eq!(res, vec![4., 5., 6.]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn eye() {
|
fn eye() {
|
||||||
let a = arr2(&[[1., 0., 0.],
|
let a = arr2(&[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]);
|
||||||
[0., 1., 0.],
|
|
||||||
[0., 0., 1.]]);
|
|
||||||
let res: Array2<f64> = BaseMatrix::eye(3);
|
let res: Array2<f64> = BaseMatrix::eye(3);
|
||||||
assert_eq!(res, a);
|
assert_eq!(res, a);
|
||||||
}
|
}
|
||||||
@@ -661,12 +631,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn approximate_eq() {
|
fn approximate_eq() {
|
||||||
let a = arr2(&[[1., 2., 3.],
|
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
||||||
[4., 5., 6.],
|
let noise = arr2(&[[1e-5, 2e-5, 3e-5], [4e-5, 5e-5, 6e-5], [7e-5, 8e-5, 9e-5]]);
|
||||||
[7., 8., 9.]]);
|
|
||||||
let noise = arr2(&[[1e-5, 2e-5, 3e-5],
|
|
||||||
[4e-5, 5e-5, 6e-5],
|
|
||||||
[7e-5, 8e-5, 9e-5]]);
|
|
||||||
assert!(a.approximate_eq(&(&noise + &a), 1e-4));
|
assert!(a.approximate_eq(&(&noise + &a), 1e-4));
|
||||||
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
||||||
}
|
}
|
||||||
|
|||||||
+12
-18
@@ -2,19 +2,18 @@
|
|||||||
|
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::BaseMatrix;
|
use crate::linalg::BaseMatrix;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct QR<T: FloatExt, M: BaseMatrix<T>> {
|
pub struct QR<T: FloatExt, M: BaseMatrix<T>> {
|
||||||
QR: M,
|
QR: M,
|
||||||
tau: Vec<T>,
|
tau: Vec<T>,
|
||||||
singular: bool
|
singular: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, M: BaseMatrix<T>> QR<T, M> {
|
impl<T: FloatExt, M: BaseMatrix<T>> QR<T, M> {
|
||||||
pub fn new(QR: M, tau: Vec<T>) -> QR<T, M> {
|
pub fn new(QR: M, tau: Vec<T>) -> QR<T, M> {
|
||||||
|
|
||||||
let mut singular = false;
|
let mut singular = false;
|
||||||
for j in 0..tau.len() {
|
for j in 0..tau.len() {
|
||||||
if tau[j] == T::zero() {
|
if tau[j] == T::zero() {
|
||||||
@@ -26,7 +25,7 @@ impl<T: FloatExt, M: BaseMatrix<T>> QR<T, M> {
|
|||||||
QR {
|
QR {
|
||||||
QR: QR,
|
QR: QR,
|
||||||
tau: tau,
|
tau: tau,
|
||||||
singular: singular
|
singular: singular,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,12 +69,14 @@ impl<T: FloatExt, M: BaseMatrix<T>> QR<T, M> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn solve(&self, mut b: M) -> M {
|
fn solve(&self, mut b: M) -> M {
|
||||||
|
|
||||||
let (m, n) = self.QR.shape();
|
let (m, n) = self.QR.shape();
|
||||||
let (b_nrows, b_ncols) = b.shape();
|
let (b_nrows, b_ncols) = b.shape();
|
||||||
|
|
||||||
if b_nrows != m {
|
if b_nrows != m {
|
||||||
panic!("Row dimensions do not agree: A is {} x {}, but B is {} x {}", m, n, b_nrows, b_ncols);
|
panic!(
|
||||||
|
"Row dimensions do not agree: A is {} x {}, but B is {} x {}",
|
||||||
|
m, n, b_nrows, b_ncols
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.singular {
|
if self.singular {
|
||||||
@@ -108,18 +109,15 @@ impl<T: FloatExt, M: BaseMatrix<T>> QR<T, M> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b
|
b
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait QRDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
pub trait QRDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||||
|
|
||||||
fn qr(&self) -> QR<T, Self> {
|
fn qr(&self) -> QR<T, Self> {
|
||||||
self.clone().qr_mut()
|
self.clone().qr_mut()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn qr_mut(mut self) -> QR<T, Self> {
|
fn qr_mut(mut self) -> QR<T, Self> {
|
||||||
|
|
||||||
let (m, n) = self.shape();
|
let (m, n) = self.shape();
|
||||||
|
|
||||||
let mut r_diagonal: Vec<T> = vec![T::zero(); n];
|
let mut r_diagonal: Vec<T> = vec![T::zero(); n];
|
||||||
@@ -131,7 +129,6 @@ pub trait QRDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if nrm.abs() > T::epsilon() {
|
if nrm.abs() > T::epsilon() {
|
||||||
|
|
||||||
if self.get(k, k) < T::zero() {
|
if self.get(k, k) < T::zero() {
|
||||||
nrm = -nrm;
|
nrm = -nrm;
|
||||||
}
|
}
|
||||||
@@ -155,13 +152,10 @@ pub trait QRDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
QR::new(self, r_diagonal)
|
QR::new(self, r_diagonal)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn qr_solve_mut(self, b: Self) -> Self {
|
fn qr_solve_mut(self, b: Self) -> Self {
|
||||||
|
|
||||||
self.qr_mut().solve(b)
|
self.qr_mut().solve(b)
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,16 +166,17 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose() {
|
fn decompose() {
|
||||||
|
|
||||||
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
let q = DenseMatrix::from_array(&[
|
let q = DenseMatrix::from_array(&[
|
||||||
&[-0.7448, 0.2436, 0.6212],
|
&[-0.7448, 0.2436, 0.6212],
|
||||||
&[-0.331, -0.9432, -0.027],
|
&[-0.331, -0.9432, -0.027],
|
||||||
&[-0.5793, 0.2257, -0.7832]]);
|
&[-0.5793, 0.2257, -0.7832],
|
||||||
|
]);
|
||||||
let r = DenseMatrix::from_array(&[
|
let r = DenseMatrix::from_array(&[
|
||||||
&[-1.2083, -0.6373, -1.0842],
|
&[-1.2083, -0.6373, -1.0842],
|
||||||
&[0.0, -0.3064, 0.0682],
|
&[0.0, -0.3064, 0.0682],
|
||||||
&[0.0, 0.0, -0.1999]]);
|
&[0.0, 0.0, -0.1999],
|
||||||
|
]);
|
||||||
let qr = a.qr();
|
let qr = a.qr();
|
||||||
assert!(qr.Q().abs().approximate_eq(&q.abs(), 1e-4));
|
assert!(qr.Q().abs().approximate_eq(&q.abs(), 1e-4));
|
||||||
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
|
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
|
||||||
@@ -189,13 +184,12 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn qr_solve_mut() {
|
fn qr_solve_mut() {
|
||||||
|
|
||||||
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
let b = DenseMatrix::from_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
|
let b = DenseMatrix::from_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
|
||||||
let expected_w = DenseMatrix::from_array(&[
|
let expected_w = DenseMatrix::from_array(&[
|
||||||
&[-0.2027027, -1.2837838],
|
&[-0.2027027, -1.2837838],
|
||||||
&[0.8783784, 2.2297297],
|
&[0.8783784, 2.2297297],
|
||||||
&[0.4729730, 0.6621622]
|
&[0.4729730, 0.6621622],
|
||||||
]);
|
]);
|
||||||
let w = a.qr_solve_mut(b);
|
let w = a.qr_solve_mut(b);
|
||||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||||
|
|||||||
+192
-45
@@ -12,11 +12,10 @@ pub struct SVD<T: FloatExt, M: SVDDecomposableMatrix<T>> {
|
|||||||
full: bool,
|
full: bool,
|
||||||
m: usize,
|
m: usize,
|
||||||
n: usize,
|
n: usize,
|
||||||
tol: T
|
tol: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||||
|
|
||||||
fn svd_solve_mut(self, b: Self) -> Self {
|
fn svd_solve_mut(self, b: Self) -> Self {
|
||||||
self.svd_mut().solve(b)
|
self.svd_mut().solve(b)
|
||||||
}
|
}
|
||||||
@@ -30,7 +29,6 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn svd_mut(self) -> SVD<T, Self> {
|
fn svd_mut(self) -> SVD<T, Self> {
|
||||||
|
|
||||||
let mut U = self;
|
let mut U = self;
|
||||||
|
|
||||||
let (m, n) = U.shape();
|
let (m, n) = U.shape();
|
||||||
@@ -55,7 +53,6 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if scale.abs() > T::epsilon() {
|
if scale.abs() > T::epsilon() {
|
||||||
|
|
||||||
for k in i..m {
|
for k in i..m {
|
||||||
U.div_element_mut(k, i, scale);
|
U.div_element_mut(k, i, scale);
|
||||||
s = s + U.get(k, i) * U.get(k, i);
|
s = s + U.get(k, i) * U.get(k, i);
|
||||||
@@ -123,7 +120,6 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
anorm = T::max(anorm, w[i].abs() + rv1[i].abs());
|
anorm = T::max(anorm, w[i].abs() + rv1[i].abs());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -339,7 +335,6 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
|||||||
for k in 0..n {
|
for k in 0..n {
|
||||||
v.set(k, j, sv[k]);
|
v.set(k, j, sv[k]);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
if inc <= 1 {
|
if inc <= 1 {
|
||||||
break;
|
break;
|
||||||
@@ -369,7 +364,6 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
SVD::new(U, v, w)
|
SVD::new(U, v, w)
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -386,7 +380,7 @@ impl<T: FloatExt, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
|||||||
full: full,
|
full: full,
|
||||||
m: m,
|
m: m,
|
||||||
n: n,
|
n: n,
|
||||||
tol: tol
|
tol: tol,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -394,7 +388,11 @@ impl<T: FloatExt, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
|||||||
let p = b.shape().1;
|
let p = b.shape().1;
|
||||||
|
|
||||||
if self.U.shape().0 != b.shape().0 {
|
if self.U.shape().0 != b.shape().0 {
|
||||||
panic!("Dimensions do not agree. U.nrows should equal b.nrows but is {}, {}", self.U.shape().0, b.shape().0);
|
panic!(
|
||||||
|
"Dimensions do not agree. U.nrows should equal b.nrows but is {}, {}",
|
||||||
|
self.U.shape().0,
|
||||||
|
b.shape().0
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
for k in 0..p {
|
for k in 0..p {
|
||||||
@@ -430,24 +428,24 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_symmetric() {
|
fn decompose_symmetric() {
|
||||||
|
|
||||||
let A = DenseMatrix::from_array(&[
|
let A = DenseMatrix::from_array(&[
|
||||||
&[0.9000, 0.4000, 0.7000],
|
&[0.9000, 0.4000, 0.7000],
|
||||||
&[0.4000, 0.5000, 0.3000],
|
&[0.4000, 0.5000, 0.3000],
|
||||||
&[0.7000, 0.3000, 0.8000]]);
|
&[0.7000, 0.3000, 0.8000],
|
||||||
|
]);
|
||||||
|
|
||||||
let s: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
|
let s: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
|
||||||
|
|
||||||
let U = DenseMatrix::from_array(&[
|
let U = DenseMatrix::from_array(&[
|
||||||
&[0.6881997, -0.07121225, 0.7220180],
|
&[0.6881997, -0.07121225, 0.7220180],
|
||||||
&[0.3700456, 0.89044952, -0.2648886],
|
&[0.3700456, 0.89044952, -0.2648886],
|
||||||
&[0.6240573, -0.44947578, -0.639158]
|
&[0.6240573, -0.44947578, -0.639158],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let V = DenseMatrix::from_array(&[
|
let V = DenseMatrix::from_array(&[
|
||||||
&[0.6881997, -0.07121225, 0.7220180],
|
&[0.6881997, -0.07121225, 0.7220180],
|
||||||
&[0.3700456, 0.89044952, -0.2648886],
|
&[0.3700456, 0.89044952, -0.2648886],
|
||||||
&[0.6240573, -0.44947578, -0.6391588]
|
&[0.6240573, -0.44947578, -0.6391588],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let svd = A.svd();
|
let svd = A.svd();
|
||||||
@@ -457,42 +455,198 @@ mod tests {
|
|||||||
for i in 0..s.len() {
|
for i in 0..s.len() {
|
||||||
assert!((s[i] - svd.s[i]).abs() < 1e-4);
|
assert!((s[i] - svd.s[i]).abs() < 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_asymmetric() {
|
fn decompose_asymmetric() {
|
||||||
|
|
||||||
let A = DenseMatrix::from_array(&[
|
let A = DenseMatrix::from_array(&[
|
||||||
&[1.19720880, -1.8391378, 0.3019585, -1.1165701, -1.7210814, 0.4918882, -0.04247433],
|
&[
|
||||||
&[0.06605075, 1.0315583, 0.8294362, -0.3646043, -1.6038017, -0.9188110, -0.63760340],
|
1.19720880,
|
||||||
&[-1.02637715, 1.0747931, -0.8089055, -0.4726863, -0.2064826, -0.3325532, 0.17966051],
|
-1.8391378,
|
||||||
&[-1.45817729, -0.8942353, 0.3459245, 1.5068363, -2.0180708, -0.3696350, -1.19575563],
|
0.3019585,
|
||||||
&[-0.07318103, -0.2783787, 1.2237598, 0.1995332, 0.2545336, -0.1392502, -1.88207227],
|
-1.1165701,
|
||||||
&[0.88248425, -0.9360321, 0.1393172, 0.1393281, -0.3277873, -0.5553013, 1.63805985],
|
-1.7210814,
|
||||||
&[0.12641406, -0.8710055, -0.2712301, 0.2296515, 1.1781535, -0.2158704, -0.27529472]
|
0.4918882,
|
||||||
|
-0.04247433,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
0.06605075,
|
||||||
|
1.0315583,
|
||||||
|
0.8294362,
|
||||||
|
-0.3646043,
|
||||||
|
-1.6038017,
|
||||||
|
-0.9188110,
|
||||||
|
-0.63760340,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
-1.02637715,
|
||||||
|
1.0747931,
|
||||||
|
-0.8089055,
|
||||||
|
-0.4726863,
|
||||||
|
-0.2064826,
|
||||||
|
-0.3325532,
|
||||||
|
0.17966051,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
-1.45817729,
|
||||||
|
-0.8942353,
|
||||||
|
0.3459245,
|
||||||
|
1.5068363,
|
||||||
|
-2.0180708,
|
||||||
|
-0.3696350,
|
||||||
|
-1.19575563,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
-0.07318103,
|
||||||
|
-0.2783787,
|
||||||
|
1.2237598,
|
||||||
|
0.1995332,
|
||||||
|
0.2545336,
|
||||||
|
-0.1392502,
|
||||||
|
-1.88207227,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
0.88248425, -0.9360321, 0.1393172, 0.1393281, -0.3277873, -0.5553013, 1.63805985,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
0.12641406,
|
||||||
|
-0.8710055,
|
||||||
|
-0.2712301,
|
||||||
|
0.2296515,
|
||||||
|
1.1781535,
|
||||||
|
-0.2158704,
|
||||||
|
-0.27529472,
|
||||||
|
],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let s: Vec<f64> = vec![3.8589375, 3.4396766, 2.6487176, 2.2317399, 1.5165054, 0.8109055, 0.2706515];
|
let s: Vec<f64> = vec![
|
||||||
|
3.8589375, 3.4396766, 2.6487176, 2.2317399, 1.5165054, 0.8109055, 0.2706515,
|
||||||
|
];
|
||||||
|
|
||||||
let U = DenseMatrix::from_array(&[
|
let U = DenseMatrix::from_array(&[
|
||||||
&[-0.3082776, 0.77676231, 0.01330514, 0.23231424, -0.47682758, 0.13927109, 0.02640713],
|
&[
|
||||||
&[-0.4013477, -0.09112050, 0.48754440, 0.47371793, 0.40636608, 0.24600706, -0.37796295],
|
-0.3082776,
|
||||||
&[0.0599719, -0.31406586, 0.45428229, -0.08071283, -0.38432597, 0.57320261, 0.45673993],
|
0.77676231,
|
||||||
&[-0.7694214, -0.12681435, -0.05536793, -0.62189972, -0.02075522, -0.01724911, -0.03681864],
|
0.01330514,
|
||||||
&[-0.3319069, -0.17984404, -0.54466777, 0.45335157, 0.19377726, 0.12333423, 0.55003852],
|
0.23231424,
|
||||||
&[0.1259351, 0.49087824, 0.16349687, -0.32080176, 0.64828744, 0.20643772, 0.38812467],
|
-0.47682758,
|
||||||
&[0.1491884, 0.01768604, -0.47884363, -0.14108924, 0.03922507, 0.73034065, -0.43965505]
|
0.13927109,
|
||||||
|
0.02640713,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
-0.4013477,
|
||||||
|
-0.09112050,
|
||||||
|
0.48754440,
|
||||||
|
0.47371793,
|
||||||
|
0.40636608,
|
||||||
|
0.24600706,
|
||||||
|
-0.37796295,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
0.0599719,
|
||||||
|
-0.31406586,
|
||||||
|
0.45428229,
|
||||||
|
-0.08071283,
|
||||||
|
-0.38432597,
|
||||||
|
0.57320261,
|
||||||
|
0.45673993,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
-0.7694214,
|
||||||
|
-0.12681435,
|
||||||
|
-0.05536793,
|
||||||
|
-0.62189972,
|
||||||
|
-0.02075522,
|
||||||
|
-0.01724911,
|
||||||
|
-0.03681864,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
-0.3319069,
|
||||||
|
-0.17984404,
|
||||||
|
-0.54466777,
|
||||||
|
0.45335157,
|
||||||
|
0.19377726,
|
||||||
|
0.12333423,
|
||||||
|
0.55003852,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
0.1259351,
|
||||||
|
0.49087824,
|
||||||
|
0.16349687,
|
||||||
|
-0.32080176,
|
||||||
|
0.64828744,
|
||||||
|
0.20643772,
|
||||||
|
0.38812467,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
0.1491884,
|
||||||
|
0.01768604,
|
||||||
|
-0.47884363,
|
||||||
|
-0.14108924,
|
||||||
|
0.03922507,
|
||||||
|
0.73034065,
|
||||||
|
-0.43965505,
|
||||||
|
],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let V = DenseMatrix::from_array(&[
|
let V = DenseMatrix::from_array(&[
|
||||||
&[-0.2122609, -0.54650056, 0.08071332, -0.43239135, -0.2925067, 0.1414550, 0.59769207],
|
&[
|
||||||
&[-0.1943605, 0.63132116, -0.54059857, -0.37089970, -0.1363031, 0.2892641, 0.17774114],
|
-0.2122609,
|
||||||
&[0.3031265, -0.06182488, 0.18579097, -0.38606409, -0.5364911, 0.2983466, -0.58642548],
|
-0.54650056,
|
||||||
&[0.1844063, 0.24425278, 0.25923756, 0.59043765, -0.4435443, 0.3959057, 0.37019098],
|
0.08071332,
|
||||||
&[-0.7164205, 0.30694911, 0.58264743, -0.07458095, -0.1142140, -0.1311972, -0.13124764],
|
-0.43239135,
|
||||||
&[-0.1103067, -0.10633600, 0.18257905, -0.03638501, 0.5722925, 0.7784398, -0.09153611],
|
-0.2925067,
|
||||||
&[-0.5156083, -0.36573746, -0.47613340, 0.41342817, -0.2659765, 0.1654796, -0.32346758]
|
0.1414550,
|
||||||
|
0.59769207,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
-0.1943605,
|
||||||
|
0.63132116,
|
||||||
|
-0.54059857,
|
||||||
|
-0.37089970,
|
||||||
|
-0.1363031,
|
||||||
|
0.2892641,
|
||||||
|
0.17774114,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
0.3031265,
|
||||||
|
-0.06182488,
|
||||||
|
0.18579097,
|
||||||
|
-0.38606409,
|
||||||
|
-0.5364911,
|
||||||
|
0.2983466,
|
||||||
|
-0.58642548,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
0.1844063, 0.24425278, 0.25923756, 0.59043765, -0.4435443, 0.3959057, 0.37019098,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
-0.7164205,
|
||||||
|
0.30694911,
|
||||||
|
0.58264743,
|
||||||
|
-0.07458095,
|
||||||
|
-0.1142140,
|
||||||
|
-0.1311972,
|
||||||
|
-0.13124764,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
-0.1103067,
|
||||||
|
-0.10633600,
|
||||||
|
0.18257905,
|
||||||
|
-0.03638501,
|
||||||
|
0.5722925,
|
||||||
|
0.7784398,
|
||||||
|
-0.09153611,
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
-0.5156083,
|
||||||
|
-0.36573746,
|
||||||
|
-0.47613340,
|
||||||
|
0.41342817,
|
||||||
|
-0.2659765,
|
||||||
|
0.1654796,
|
||||||
|
-0.32346758,
|
||||||
|
],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let svd = A.svd();
|
let svd = A.svd();
|
||||||
@@ -502,21 +656,14 @@ mod tests {
|
|||||||
for i in 0..s.len() {
|
for i in 0..s.len() {
|
||||||
assert!((s[i] - svd.s[i]).abs() < 1e-4);
|
assert!((s[i] - svd.s[i]).abs() < 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn solve() {
|
fn solve() {
|
||||||
|
|
||||||
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
let b = DenseMatrix::from_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
|
let b = DenseMatrix::from_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
|
||||||
let expected_w = DenseMatrix::from_array(&[
|
let expected_w = DenseMatrix::from_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
|
||||||
&[-0.20, -1.28],
|
|
||||||
&[0.87, 2.22],
|
|
||||||
&[0.47, 0.66]
|
|
||||||
]);
|
|
||||||
let w = a.svd_solve_mut(b);
|
let w = a.svd_solve_mut(b);
|
||||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,34 +1,32 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub enum LinearRegressionSolver {
|
pub enum LinearRegressionSolver {
|
||||||
QR,
|
QR,
|
||||||
SVD
|
SVD,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct LinearRegression<T: FloatExt, M: Matrix<T>> {
|
pub struct LinearRegression<T: FloatExt, M: Matrix<T>> {
|
||||||
coefficients: M,
|
coefficients: M,
|
||||||
intercept: T,
|
intercept: T,
|
||||||
solver: LinearRegressionSolver
|
solver: LinearRegressionSolver,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
impl<T: FloatExt, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.coefficients == other.coefficients &&
|
self.coefficients == other.coefficients
|
||||||
(self.intercept - other.intercept).abs() <= T::epsilon()
|
&& (self.intercept - other.intercept).abs() <= T::epsilon()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
|
impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
|
||||||
|
|
||||||
pub fn fit(x: &M, y: &M::RowVector, solver: LinearRegressionSolver) -> LinearRegression<T, M> {
|
pub fn fit(x: &M, y: &M::RowVector, solver: LinearRegressionSolver) -> LinearRegression<T, M> {
|
||||||
|
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let b = y_m.transpose();
|
let b = y_m.transpose();
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
@@ -42,7 +40,7 @@ impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
|
|||||||
|
|
||||||
let w = match solver {
|
let w = match solver {
|
||||||
LinearRegressionSolver::QR => a.qr_solve_mut(b),
|
LinearRegressionSolver::QR => a.qr_solve_mut(b),
|
||||||
LinearRegressionSolver::SVD => a.svd_solve_mut(b)
|
LinearRegressionSolver::SVD => a.svd_solve_mut(b),
|
||||||
};
|
};
|
||||||
|
|
||||||
let wights = w.slice(0..num_attributes, 0..1);
|
let wights = w.slice(0..num_attributes, 0..1);
|
||||||
@@ -50,7 +48,7 @@ impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
|
|||||||
LinearRegression {
|
LinearRegression {
|
||||||
intercept: w.get(num_attributes, 0),
|
intercept: w.get(num_attributes, 0),
|
||||||
coefficients: wights,
|
coefficients: wights,
|
||||||
solver: solver
|
solver: solver,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,50 +58,54 @@ impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
|
|||||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||||
y_hat.transpose().to_row_vector()
|
y_hat.transpose().to_row_vector()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use nalgebra::{DMatrix, RowDVector};
|
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
|
use nalgebra::{DMatrix, RowDVector};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn ols_fit_predict() {
|
fn ols_fit_predict() {
|
||||||
|
let x = DMatrix::from_row_slice(
|
||||||
|
16,
|
||||||
|
6,
|
||||||
|
&[
|
||||||
|
234.289, 235.6, 159.0, 107.608, 1947., 60.323, 259.426, 232.5, 145.6, 108.632,
|
||||||
|
1948., 61.122, 258.054, 368.2, 161.6, 109.773, 1949., 60.171, 284.599, 335.1,
|
||||||
|
165.0, 110.929, 1950., 61.187, 328.975, 209.9, 309.9, 112.075, 1951., 63.221,
|
||||||
|
346.999, 193.2, 359.4, 113.270, 1952., 63.639, 365.385, 187.0, 354.7, 115.094,
|
||||||
|
1953., 64.989, 363.112, 357.8, 335.0, 116.219, 1954., 63.761, 397.469, 290.4,
|
||||||
|
304.8, 117.388, 1955., 66.019, 419.180, 282.2, 285.7, 118.734, 1956., 67.857,
|
||||||
|
442.769, 293.6, 279.8, 120.445, 1957., 68.169, 444.546, 468.1, 263.7, 121.950,
|
||||||
|
1958., 66.513, 482.704, 381.3, 255.2, 123.366, 1959., 68.655, 502.601, 393.1,
|
||||||
|
251.4, 125.368, 1960., 69.564, 518.173, 480.6, 257.2, 127.852, 1961., 69.331,
|
||||||
|
554.894, 400.7, 282.7, 130.081, 1962., 70.551,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
let x = DMatrix::from_row_slice(16, 6, &[
|
let y: RowDVector<f64> = RowDVector::from_vec(vec![
|
||||||
234.289, 235.6, 159.0, 107.608, 1947., 60.323,
|
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||||
259.426, 232.5, 145.6, 108.632, 1948., 61.122,
|
114.2, 115.7, 116.9,
|
||||||
258.054, 368.2, 161.6, 109.773, 1949., 60.171,
|
]);
|
||||||
284.599, 335.1, 165.0, 110.929, 1950., 61.187,
|
|
||||||
328.975, 209.9, 309.9, 112.075, 1951., 63.221,
|
|
||||||
346.999, 193.2, 359.4, 113.270, 1952., 63.639,
|
|
||||||
365.385, 187.0, 354.7, 115.094, 1953., 64.989,
|
|
||||||
363.112, 357.8, 335.0, 116.219, 1954., 63.761,
|
|
||||||
397.469, 290.4, 304.8, 117.388, 1955., 66.019,
|
|
||||||
419.180, 282.2, 285.7, 118.734, 1956., 67.857,
|
|
||||||
442.769, 293.6, 279.8, 120.445, 1957., 68.169,
|
|
||||||
444.546, 468.1, 263.7, 121.950, 1958., 66.513,
|
|
||||||
482.704, 381.3, 255.2, 123.366, 1959., 68.655,
|
|
||||||
502.601, 393.1, 251.4, 125.368, 1960., 69.564,
|
|
||||||
518.173, 480.6, 257.2, 127.852, 1961., 69.331,
|
|
||||||
554.894, 400.7, 282.7, 130.081, 1962., 70.551]);
|
|
||||||
|
|
||||||
let y: RowDVector<f64> = RowDVector::from_vec(vec!(83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9));
|
|
||||||
|
|
||||||
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
|
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
|
||||||
|
|
||||||
let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x);
|
let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x);
|
||||||
|
|
||||||
assert!(y.iter().zip(y_hat_qr.iter()).all(|(&a, &b)| (a - b).abs() <= 5.0));
|
assert!(y
|
||||||
assert!(y.iter().zip(y_hat_svd.iter()).all(|(&a, &b)| (a - b).abs() <= 5.0));
|
.iter()
|
||||||
|
.zip(y_hat_qr.iter())
|
||||||
|
.all(|(&a, &b)| (a - b).abs() <= 5.0));
|
||||||
|
assert!(y
|
||||||
|
.iter()
|
||||||
|
.zip(y_hat_svd.iter())
|
||||||
|
.all(|(&a, &b)| (a - b).abs() <= 5.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn ols_fit_predict_nalgebra() {
|
fn ols_fit_predict_nalgebra() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
@@ -120,17 +122,26 @@ mod tests {
|
|||||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||||
|
]);
|
||||||
|
|
||||||
let y: Vec<f64> = vec!(83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9);
|
let y: Vec<f64> = vec![
|
||||||
|
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||||
|
114.2, 115.7, 116.9,
|
||||||
|
];
|
||||||
|
|
||||||
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
|
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
|
||||||
|
|
||||||
let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x);
|
let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x);
|
||||||
|
|
||||||
assert!(y.iter().zip(y_hat_qr.iter()).all(|(&a, &b)| (a - b).abs() <= 5.0));
|
assert!(y
|
||||||
assert!(y.iter().zip(y_hat_svd.iter()).all(|(&a, &b)| (a - b).abs() <= 5.0));
|
.iter()
|
||||||
|
.zip(y_hat_qr.iter())
|
||||||
|
.all(|(&a, &b)| (a - b).abs() <= 5.0));
|
||||||
|
assert!(y
|
||||||
|
.iter()
|
||||||
|
.zip(y_hat_svd.iter())
|
||||||
|
.all(|(&a, &b)| (a - b).abs() <= 5.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -151,13 +162,18 @@ mod tests {
|
|||||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||||
|
]);
|
||||||
|
|
||||||
let y = vec!(83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9);
|
let y = vec![
|
||||||
|
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||||
|
114.2, 115.7, 116.9,
|
||||||
|
];
|
||||||
|
|
||||||
let lr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR);
|
let lr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR);
|
||||||
|
|
||||||
let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> = serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> =
|
||||||
|
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||||
|
|
||||||
assert_eq!(lr, deserialized_lr);
|
assert_eq!(lr, deserialized_lr);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,21 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::optimization::FunctionOrder;
|
use crate::math::num::FloatExt;
|
||||||
|
use crate::optimization::first_order::lbfgs::LBFGS;
|
||||||
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
||||||
use crate::optimization::line_search::Backtracking;
|
use crate::optimization::line_search::Backtracking;
|
||||||
use crate::optimization::first_order::lbfgs::LBFGS;
|
use crate::optimization::FunctionOrder;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct LogisticRegression<T: FloatExt, M: Matrix<T>> {
|
pub struct LogisticRegression<T: FloatExt, M: Matrix<T>> {
|
||||||
weights: M,
|
weights: M,
|
||||||
classes: Vec<T>,
|
classes: Vec<T>,
|
||||||
num_attributes: usize,
|
num_attributes: usize,
|
||||||
num_classes: usize
|
num_classes: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
trait ObjectiveFunction<T: FloatExt, M: Matrix<T>> {
|
trait ObjectiveFunction<T: FloatExt, M: Matrix<T>> {
|
||||||
@@ -36,31 +36,29 @@ trait ObjectiveFunction<T: FloatExt, M: Matrix<T>> {
|
|||||||
struct BinaryObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
|
struct BinaryObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
|
||||||
x: &'a M,
|
x: &'a M,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
phantom: PhantomData<&'a T>
|
phantom: PhantomData<&'a T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
impl<T: FloatExt, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
if self.num_classes != other.num_classes
|
||||||
if self.num_classes != other.num_classes ||
|
|| self.num_attributes != other.num_attributes
|
||||||
self.num_attributes != other.num_attributes ||
|
|| self.classes.len() != other.classes.len()
|
||||||
self.classes.len() != other.classes.len() {
|
{
|
||||||
return false
|
return false;
|
||||||
} else {
|
} else {
|
||||||
for i in 0..self.classes.len() {
|
for i in 0..self.classes.len() {
|
||||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return self.weights == other.weights
|
return self.weights == other.weights;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveFunction<'a, T, M> {
|
impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveFunction<'a, T, M> {
|
||||||
|
|
||||||
fn f(&self, w_bias: &M) -> T {
|
fn f(&self, w_bias: &M) -> T {
|
||||||
let mut f = T::zero();
|
let mut f = T::zero();
|
||||||
let (n, _) = self.x.shape();
|
let (n, _) = self.x.shape();
|
||||||
@@ -74,13 +72,11 @@ impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveF
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn df(&self, g: &mut M, w_bias: &M) {
|
fn df(&self, g: &mut M, w_bias: &M) {
|
||||||
|
|
||||||
g.copy_from(&M::zeros(1, g.shape().1));
|
g.copy_from(&M::zeros(1, g.shape().1));
|
||||||
|
|
||||||
let (n, p) = self.x.shape();
|
let (n, p) = self.x.shape();
|
||||||
|
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
|
|
||||||
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
|
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
|
||||||
|
|
||||||
let dyi = (T::from(self.y[i]).unwrap()) - wx.sigmoid();
|
let dyi = (T::from(self.y[i]).unwrap()) - wx.sigmoid();
|
||||||
@@ -89,27 +85,30 @@ impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveF
|
|||||||
}
|
}
|
||||||
g.set(0, p, g.get(0, p) - dyi);
|
g.set(0, p, g.get(0, p) - dyi);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct MultiClassObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
|
struct MultiClassObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
|
||||||
x: &'a M,
|
x: &'a M,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
k: usize,
|
k: usize,
|
||||||
phantom: PhantomData<&'a T>
|
phantom: PhantomData<&'a T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for MultiClassObjectiveFunction<'a, T, M> {
|
impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M>
|
||||||
|
for MultiClassObjectiveFunction<'a, T, M>
|
||||||
|
{
|
||||||
fn f(&self, w_bias: &M) -> T {
|
fn f(&self, w_bias: &M) -> T {
|
||||||
let mut f = T::zero();
|
let mut f = T::zero();
|
||||||
let mut prob = M::zeros(1, self.k);
|
let mut prob = M::zeros(1, self.k);
|
||||||
let (n, p) = self.x.shape();
|
let (n, p) = self.x.shape();
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
for j in 0..self.k {
|
for j in 0..self.k {
|
||||||
prob.set(0, j, MultiClassObjectiveFunction::partial_dot(w_bias, self.x, j * (p + 1), i));
|
prob.set(
|
||||||
|
0,
|
||||||
|
j,
|
||||||
|
MultiClassObjectiveFunction::partial_dot(w_bias, self.x, j * (p + 1), i),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
prob.softmax_mut();
|
prob.softmax_mut();
|
||||||
f = f - prob.get(0, self.y[i]).ln();
|
f = f - prob.get(0, self.y[i]).ln();
|
||||||
@@ -119,7 +118,6 @@ impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for MultiClassObject
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn df(&self, g: &mut M, w: &M) {
|
fn df(&self, g: &mut M, w: &M) {
|
||||||
|
|
||||||
g.copy_from(&M::zeros(1, g.shape().1));
|
g.copy_from(&M::zeros(1, g.shape().1));
|
||||||
|
|
||||||
let mut prob = M::zeros(1, self.k);
|
let mut prob = M::zeros(1, self.k);
|
||||||
@@ -127,7 +125,11 @@ impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for MultiClassObject
|
|||||||
|
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
for j in 0..self.k {
|
for j in 0..self.k {
|
||||||
prob.set(0, j, MultiClassObjectiveFunction::partial_dot(w, self.x, j * (p + 1), i));
|
prob.set(
|
||||||
|
0,
|
||||||
|
j,
|
||||||
|
MultiClassObjectiveFunction::partial_dot(w, self.x, j * (p + 1), i),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
prob.softmax_mut();
|
prob.softmax_mut();
|
||||||
@@ -142,15 +144,11 @@ impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for MultiClassObject
|
|||||||
g.set(0, j * (p + 1) + p, g.get(0, j * (p + 1) + p) - yi);
|
g.set(0, j * (p + 1) + p, g.get(0, j * (p + 1) + p) - yi);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
|
impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
|
||||||
|
|
||||||
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M> {
|
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M> {
|
||||||
|
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let (_, y_nrows) = y_m.shape();
|
let (_, y_nrows) = y_m.shape();
|
||||||
@@ -171,17 +169,14 @@ impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if k < 2 {
|
if k < 2 {
|
||||||
|
|
||||||
panic!("Incorrect number of classes: {}", k);
|
panic!("Incorrect number of classes: {}", k);
|
||||||
|
|
||||||
} else if k == 2 {
|
} else if k == 2 {
|
||||||
|
|
||||||
let x0 = M::zeros(1, num_attributes + 1);
|
let x0 = M::zeros(1, num_attributes + 1);
|
||||||
|
|
||||||
let objective = BinaryObjectiveFunction {
|
let objective = BinaryObjectiveFunction {
|
||||||
x: x,
|
x: x,
|
||||||
y: yi,
|
y: yi,
|
||||||
phantom: PhantomData
|
phantom: PhantomData,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = LogisticRegression::minimize(x0, objective);
|
let result = LogisticRegression::minimize(x0, objective);
|
||||||
@@ -192,16 +187,14 @@ impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
num_attributes: num_attributes,
|
num_attributes: num_attributes,
|
||||||
num_classes: k,
|
num_classes: k,
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
let x0 = M::zeros(1, (num_attributes + 1) * k);
|
let x0 = M::zeros(1, (num_attributes + 1) * k);
|
||||||
|
|
||||||
let objective = MultiClassObjectiveFunction {
|
let objective = MultiClassObjectiveFunction {
|
||||||
x: x,
|
x: x,
|
||||||
y: yi,
|
y: yi,
|
||||||
k: k,
|
k: k,
|
||||||
phantom: PhantomData
|
phantom: PhantomData,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = LogisticRegression::minimize(x0, objective);
|
let result = LogisticRegression::minimize(x0, objective);
|
||||||
@@ -212,11 +205,9 @@ impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
weights: weights,
|
weights: weights,
|
||||||
classes: classes,
|
classes: classes,
|
||||||
num_attributes: num_attributes,
|
num_attributes: num_attributes,
|
||||||
num_classes: k
|
num_classes: k,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn predict(&self, x: &M) -> M::RowVector {
|
pub fn predict(&self, x: &M) -> M::RowVector {
|
||||||
@@ -227,9 +218,12 @@ impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
||||||
let y_hat: Vec<T> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
|
let y_hat: Vec<T> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
result.set(0, i, self.classes[if y_hat[i].sigmoid() > T::half() { 1 } else { 0 }]);
|
result.set(
|
||||||
|
0,
|
||||||
|
i,
|
||||||
|
self.classes[if y_hat[i].sigmoid() > T::half() { 1 } else { 0 }],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
let (nrows, _) = x.shape();
|
let (nrows, _) = x.shape();
|
||||||
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
||||||
@@ -243,21 +237,21 @@ impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn coefficients(&self) -> M {
|
pub fn coefficients(&self) -> M {
|
||||||
self.weights.slice(0..self.num_classes, 0..self.num_attributes)
|
self.weights
|
||||||
|
.slice(0..self.num_classes, 0..self.num_attributes)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn intercept(&self) -> M {
|
pub fn intercept(&self) -> M {
|
||||||
self.weights.slice(0..self.num_classes, self.num_attributes..self.num_attributes+1)
|
self.weights.slice(
|
||||||
|
0..self.num_classes,
|
||||||
|
self.num_attributes..self.num_attributes + 1,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn minimize(x0: M, objective: impl ObjectiveFunction<T, M>) -> OptimizerResult<T, M> {
|
fn minimize(x0: M, objective: impl ObjectiveFunction<T, M>) -> OptimizerResult<T, M> {
|
||||||
let f = |w: &M| -> T {
|
let f = |w: &M| -> T { objective.f(w) };
|
||||||
objective.f(w)
|
|
||||||
};
|
|
||||||
|
|
||||||
let df = |g: &mut M, w: &M| {
|
let df = |g: &mut M, w: &M| objective.df(g, w);
|
||||||
objective.df(g, w)
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut ls: Backtracking<T> = Default::default();
|
let mut ls: Backtracking<T> = Default::default();
|
||||||
ls.order = FunctionOrder::THIRD;
|
ls.order = FunctionOrder::THIRD;
|
||||||
@@ -265,19 +259,17 @@ impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
|
|
||||||
optimizer.optimize(&f, &df, &x0, &ls)
|
optimizer.optimize(&f, &df, &x0, &ls)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
use ndarray::{arr1, arr2, Array1};
|
|
||||||
use crate::metrics::*;
|
use crate::metrics::*;
|
||||||
|
use ndarray::{arr1, arr2, Array1};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn multiclass_objective_f() {
|
fn multiclass_objective_f() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[1., -5.],
|
&[1., -5.],
|
||||||
&[2., 5.],
|
&[2., 5.],
|
||||||
@@ -293,7 +285,8 @@ mod tests {
|
|||||||
&[9., 5.],
|
&[9., 5.],
|
||||||
&[10., -2.],
|
&[10., -2.],
|
||||||
&[8., 2.],
|
&[8., 2.],
|
||||||
&[ 9., 0.]]);
|
&[9., 0.],
|
||||||
|
]);
|
||||||
|
|
||||||
let y = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];
|
let y = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];
|
||||||
|
|
||||||
@@ -301,24 +294,31 @@ mod tests {
|
|||||||
x: &x,
|
x: &x,
|
||||||
y: y,
|
y: y,
|
||||||
k: 3,
|
k: 3,
|
||||||
phantom: PhantomData
|
phantom: PhantomData,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 9);
|
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 9);
|
||||||
|
|
||||||
objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]));
|
objective.df(
|
||||||
objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]));
|
&mut g,
|
||||||
|
&DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]),
|
||||||
|
);
|
||||||
|
objective.df(
|
||||||
|
&mut g,
|
||||||
|
&DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]),
|
||||||
|
);
|
||||||
|
|
||||||
assert!((g.get(0, 0) + 33.000068218163484).abs() < std::f64::EPSILON);
|
assert!((g.get(0, 0) + 33.000068218163484).abs() < std::f64::EPSILON);
|
||||||
|
|
||||||
let f = objective.f(&DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]));
|
let f = objective.f(&DenseMatrix::vector_from_array(&[
|
||||||
|
1., 2., 3., 4., 5., 6., 7., 8., 9.,
|
||||||
|
]));
|
||||||
|
|
||||||
assert!((f - 408.0052230582765).abs() < std::f64::EPSILON);
|
assert!((f - 408.0052230582765).abs() < std::f64::EPSILON);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn binary_objective_f() {
|
fn binary_objective_f() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[1., -5.],
|
&[1., -5.],
|
||||||
&[2., 5.],
|
&[2., 5.],
|
||||||
@@ -334,14 +334,15 @@ mod tests {
|
|||||||
&[9., 5.],
|
&[9., 5.],
|
||||||
&[10., -2.],
|
&[10., -2.],
|
||||||
&[8., 2.],
|
&[8., 2.],
|
||||||
&[ 9., 0.]]);
|
&[9., 0.],
|
||||||
|
]);
|
||||||
|
|
||||||
let y = vec![0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1];
|
let y = vec![0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1];
|
||||||
|
|
||||||
let objective = BinaryObjectiveFunction {
|
let objective = BinaryObjectiveFunction {
|
||||||
x: &x,
|
x: &x,
|
||||||
y: y,
|
y: y,
|
||||||
phantom: PhantomData
|
phantom: PhantomData,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 3);
|
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 3);
|
||||||
@@ -360,7 +361,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn lr_fit_predict() {
|
fn lr_fit_predict() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[1., -5.],
|
&[1., -5.],
|
||||||
&[2., 5.],
|
&[2., 5.],
|
||||||
@@ -376,7 +376,8 @@ mod tests {
|
|||||||
&[9., 5.],
|
&[9., 5.],
|
||||||
&[10., -2.],
|
&[10., -2.],
|
||||||
&[8., 2.],
|
&[8., 2.],
|
||||||
&[ 9., 0.]]);
|
&[9., 0.],
|
||||||
|
]);
|
||||||
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||||
|
|
||||||
let lr = LogisticRegression::fit(&x, &y);
|
let lr = LogisticRegression::fit(&x, &y);
|
||||||
@@ -389,9 +390,10 @@ mod tests {
|
|||||||
|
|
||||||
let y_hat = lr.predict(&x);
|
let y_hat = lr.predict(&x);
|
||||||
|
|
||||||
assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
|
assert_eq!(
|
||||||
|
y_hat,
|
||||||
|
vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -411,12 +413,14 @@ mod tests {
|
|||||||
&[9., 5.],
|
&[9., 5.],
|
||||||
&[10., -2.],
|
&[10., -2.],
|
||||||
&[8., 2.],
|
&[8., 2.],
|
||||||
&[ 9., 0.]]);
|
&[9., 0.],
|
||||||
|
]);
|
||||||
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||||
|
|
||||||
let lr = LogisticRegression::fit(&x, &y);
|
let lr = LogisticRegression::fit(&x, &y);
|
||||||
|
|
||||||
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> = serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> =
|
||||||
|
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||||
|
|
||||||
assert_eq!(lr, deserialized_lr);
|
assert_eq!(lr, deserialized_lr);
|
||||||
}
|
}
|
||||||
@@ -443,17 +447,22 @@ mod tests {
|
|||||||
[6.3, 3.3, 4.7, 1.6],
|
[6.3, 3.3, 4.7, 1.6],
|
||||||
[4.9, 2.4, 3.3, 1.0],
|
[4.9, 2.4, 3.3, 1.0],
|
||||||
[6.6, 2.9, 4.6, 1.3],
|
[6.6, 2.9, 4.6, 1.3],
|
||||||
[5.2, 2.7, 3.9, 1.4]]);
|
[5.2, 2.7, 3.9, 1.4],
|
||||||
let y: Array1<f64> = arr1(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]);
|
]);
|
||||||
|
let y: Array1<f64> = arr1(&[
|
||||||
|
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
|
]);
|
||||||
|
|
||||||
let lr = LogisticRegression::fit(&x, &y);
|
let lr = LogisticRegression::fit(&x, &y);
|
||||||
|
|
||||||
let y_hat = lr.predict(&x);
|
let y_hat = lr.predict(&x);
|
||||||
|
|
||||||
let error: f64 = y.into_iter().zip(y_hat.into_iter()).map(|(&a, &b)| (a - b).abs()).sum();
|
let error: f64 = y
|
||||||
|
.into_iter()
|
||||||
|
.zip(y_hat.into_iter())
|
||||||
|
.map(|(&a, &b)| (a - b).abs())
|
||||||
|
.sum();
|
||||||
|
|
||||||
assert!(error <= 1.0);
|
assert!(error <= 1.0);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,12 +1,11 @@
|
|||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
use super::Distance;
|
use super::Distance;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct Euclidian {
|
pub struct Euclidian {}
|
||||||
}
|
|
||||||
|
|
||||||
impl Euclidian {
|
impl Euclidian {
|
||||||
pub fn squared_distance<T: FloatExt>(x: &Vec<T>, y: &Vec<T>) -> T {
|
pub fn squared_distance<T: FloatExt>(x: &Vec<T>, y: &Vec<T>) -> T {
|
||||||
@@ -21,18 +20,14 @@ impl Euclidian {
|
|||||||
|
|
||||||
sum
|
sum
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> Distance<Vec<T>, T> for Euclidian {
|
impl<T: FloatExt> Distance<Vec<T>, T> for Euclidian {
|
||||||
|
|
||||||
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
|
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
|
||||||
Euclidian::squared_distance(x, y).sqrt()
|
Euclidian::squared_distance(x, y).sqrt()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -46,5 +41,4 @@ mod tests {
|
|||||||
|
|
||||||
assert!((l2 - 5.19615242).abs() < 1e-8);
|
assert!((l2 - 5.19615242).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,15 +1,13 @@
|
|||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
use super::Distance;
|
use super::Distance;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct Hamming {
|
pub struct Hamming {}
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: PartialEq, F: FloatExt> Distance<Vec<T>, F> for Hamming {
|
impl<T: PartialEq, F: FloatExt> Distance<Vec<T>, F> for Hamming {
|
||||||
|
|
||||||
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> F {
|
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> F {
|
||||||
if x.len() != y.len() {
|
if x.len() != y.len() {
|
||||||
panic!("Input vector sizes are different");
|
panic!("Input vector sizes are different");
|
||||||
@@ -24,10 +22,8 @@ impl<T: PartialEq, F: FloatExt> Distance<Vec<T>, F> for Hamming {
|
|||||||
|
|
||||||
F::from_i64(dist).unwrap() / F::from_usize(x.len()).unwrap()
|
F::from_i64(dist).unwrap() / F::from_usize(x.len()).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -41,5 +37,4 @@ mod tests {
|
|||||||
|
|
||||||
assert!((h - 0.42857142).abs() < 1e-8);
|
assert!((h - 0.42857142).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
@@ -13,7 +13,7 @@ use crate::linalg::Matrix;
|
|||||||
pub struct Mahalanobis<T: FloatExt, M: Matrix<T>> {
|
pub struct Mahalanobis<T: FloatExt, M: Matrix<T>> {
|
||||||
pub sigma: M,
|
pub sigma: M,
|
||||||
pub sigmaInv: M,
|
pub sigmaInv: M,
|
||||||
t: PhantomData<T>
|
t: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, M: Matrix<T>> Mahalanobis<T, M> {
|
impl<T: FloatExt, M: Matrix<T>> Mahalanobis<T, M> {
|
||||||
@@ -23,7 +23,7 @@ impl<T: FloatExt, M: Matrix<T>> Mahalanobis<T, M> {
|
|||||||
Mahalanobis {
|
Mahalanobis {
|
||||||
sigma: sigma,
|
sigma: sigma,
|
||||||
sigmaInv: sigmaInv,
|
sigmaInv: sigmaInv,
|
||||||
t: PhantomData
|
t: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,21 +33,30 @@ impl<T: FloatExt, M: Matrix<T>> Mahalanobis<T, M> {
|
|||||||
Mahalanobis {
|
Mahalanobis {
|
||||||
sigma: sigma,
|
sigma: sigma,
|
||||||
sigmaInv: sigmaInv,
|
sigmaInv: sigmaInv,
|
||||||
t: PhantomData
|
t: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, M: Matrix<T>> Distance<Vec<T>, T> for Mahalanobis<T, M> {
|
impl<T: FloatExt, M: Matrix<T>> Distance<Vec<T>, T> for Mahalanobis<T, M> {
|
||||||
|
|
||||||
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
|
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
|
||||||
let (nrows, ncols) = self.sigma.shape();
|
let (nrows, ncols) = self.sigma.shape();
|
||||||
if x.len() != nrows {
|
if x.len() != nrows {
|
||||||
panic!("Array x[{}] has different dimension with Sigma[{}][{}].", x.len(), nrows, ncols);
|
panic!(
|
||||||
|
"Array x[{}] has different dimension with Sigma[{}][{}].",
|
||||||
|
x.len(),
|
||||||
|
nrows,
|
||||||
|
ncols
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if y.len() != nrows {
|
if y.len() != nrows {
|
||||||
panic!("Array y[{}] has different dimension with Sigma[{}][{}].", y.len(), nrows, ncols);
|
panic!(
|
||||||
|
"Array y[{}] has different dimension with Sigma[{}][{}].",
|
||||||
|
y.len(),
|
||||||
|
nrows,
|
||||||
|
ncols
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("{}", self.sigmaInv);
|
println!("{}", self.sigmaInv);
|
||||||
@@ -68,10 +77,8 @@ impl<T: FloatExt, M: Matrix<T>> Distance<Vec<T>, T> for Mahalanobis<T, M> {
|
|||||||
|
|
||||||
s.sqrt()
|
s.sqrt()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -84,7 +91,8 @@ mod tests {
|
|||||||
&[66., 570., 33.],
|
&[66., 570., 33.],
|
||||||
&[68., 590., 37.],
|
&[68., 590., 37.],
|
||||||
&[69., 660., 46.],
|
&[69., 660., 46.],
|
||||||
&[ 73., 600., 55.]]);
|
&[73., 600., 55.],
|
||||||
|
]);
|
||||||
|
|
||||||
let a = data.column_mean();
|
let a = data.column_mean();
|
||||||
let b = vec![66., 640., 44.];
|
let b = vec![66., 640., 44.];
|
||||||
@@ -93,5 +101,4 @@ mod tests {
|
|||||||
|
|
||||||
println!("{}", mahalanobis.distance(&a, &b));
|
println!("{}", mahalanobis.distance(&a, &b));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,15 +1,13 @@
|
|||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
use super::Distance;
|
use super::Distance;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct Manhattan {
|
pub struct Manhattan {}
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: FloatExt> Distance<Vec<T>, T> for Manhattan {
|
impl<T: FloatExt> Distance<Vec<T>, T> for Manhattan {
|
||||||
|
|
||||||
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
|
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
|
||||||
if x.len() != y.len() {
|
if x.len() != y.len() {
|
||||||
panic!("Input vector sizes are different");
|
panic!("Input vector sizes are different");
|
||||||
@@ -22,10 +20,8 @@ impl<T: FloatExt> Distance<Vec<T>, T> for Manhattan {
|
|||||||
|
|
||||||
dist
|
dist
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -39,5 +35,4 @@ mod tests {
|
|||||||
|
|
||||||
assert!((l1 - 9.0).abs() < 1e-8);
|
assert!((l1 - 9.0).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
@@ -6,11 +6,10 @@ use super::Distance;
|
|||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct Minkowski<T: FloatExt> {
|
pub struct Minkowski<T: FloatExt> {
|
||||||
pub p: T
|
pub p: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> Distance<Vec<T>, T> for Minkowski<T> {
|
impl<T: FloatExt> Distance<Vec<T>, T> for Minkowski<T> {
|
||||||
|
|
||||||
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
|
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
|
||||||
if x.len() != y.len() {
|
if x.len() != y.len() {
|
||||||
panic!("Input vector sizes are different");
|
panic!("Input vector sizes are different");
|
||||||
@@ -27,10 +26,8 @@ impl<T: FloatExt> Distance<Vec<T>, T> for Minkowski<T> {
|
|||||||
|
|
||||||
dist.powf(T::one() / self.p)
|
dist.powf(T::one() / self.p)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -57,7 +54,4 @@ mod tests {
|
|||||||
|
|
||||||
let _: f64 = Minkowski { p: 0.0 }.distance(&a, &b);
|
let _: f64 = Minkowski { p: 0.0 }.distance(&a, &b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
pub mod euclidian;
|
pub mod euclidian;
|
||||||
pub mod minkowski;
|
|
||||||
pub mod manhattan;
|
|
||||||
pub mod hamming;
|
pub mod hamming;
|
||||||
pub mod mahalanobis;
|
pub mod mahalanobis;
|
||||||
|
pub mod manhattan;
|
||||||
|
pub mod minkowski;
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
@@ -10,8 +10,7 @@ pub trait Distance<T, F: FloatExt>{
|
|||||||
fn distance(&self, a: &T, b: &T) -> F;
|
fn distance(&self, a: &T, b: &T) -> F;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Distances{
|
pub struct Distances {}
|
||||||
}
|
|
||||||
|
|
||||||
impl Distances {
|
impl Distances {
|
||||||
pub fn euclidian() -> euclidian::Euclidian {
|
pub fn euclidian() -> euclidian::Euclidian {
|
||||||
|
|||||||
+3
-11
@@ -1,9 +1,8 @@
|
|||||||
use std::fmt::{Debug, Display};
|
|
||||||
use num_traits::{Float, FromPrimitive};
|
use num_traits::{Float, FromPrimitive};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
|
use std::fmt::{Debug, Display};
|
||||||
|
|
||||||
pub trait FloatExt: Float + FromPrimitive + Debug + Display + Copy {
|
pub trait FloatExt: Float + FromPrimitive + Debug + Display + Copy {
|
||||||
|
|
||||||
fn copysign(self, sign: Self) -> Self;
|
fn copysign(self, sign: Self) -> Self;
|
||||||
|
|
||||||
fn ln_1pe(self) -> Self;
|
fn ln_1pe(self) -> Self;
|
||||||
@@ -15,7 +14,6 @@ pub trait FloatExt: Float + FromPrimitive + Debug + Display + Copy {
|
|||||||
fn two() -> Self;
|
fn two() -> Self;
|
||||||
|
|
||||||
fn half() -> Self;
|
fn half() -> Self;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FloatExt for f64 {
|
impl FloatExt for f64 {
|
||||||
@@ -29,19 +27,16 @@ impl FloatExt for f64 {
|
|||||||
} else {
|
} else {
|
||||||
return self.exp().ln_1p();
|
return self.exp().ln_1p();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sigmoid(self) -> f64 {
|
fn sigmoid(self) -> f64 {
|
||||||
|
|
||||||
if self < -40. {
|
if self < -40. {
|
||||||
return 0.;
|
return 0.;
|
||||||
} else if self > 40. {
|
} else if self > 40. {
|
||||||
return 1.;
|
return 1.;
|
||||||
} else {
|
} else {
|
||||||
return 1. / (1. + f64::exp(-self))
|
return 1. / (1. + f64::exp(-self));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rand() -> f64 {
|
fn rand() -> f64 {
|
||||||
@@ -69,19 +64,16 @@ impl FloatExt for f32 {
|
|||||||
} else {
|
} else {
|
||||||
return self.exp().ln_1p();
|
return self.exp().ln_1p();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sigmoid(self) -> f32 {
|
fn sigmoid(self) -> f32 {
|
||||||
|
|
||||||
if self < -40. {
|
if self < -40. {
|
||||||
return 0.;
|
return 0.;
|
||||||
} else if self > 40. {
|
} else if self > 40. {
|
||||||
return 1.;
|
return 1.;
|
||||||
} else {
|
} else {
|
||||||
return 1. / (1. + f32::exp(-self))
|
return 1. / (1. + f32::exp(-self));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rand() -> f32 {
|
fn rand() -> f32 {
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct Accuracy {}
|
pub struct Accuracy {}
|
||||||
@@ -9,7 +9,11 @@ pub struct Accuracy{}
|
|||||||
impl Accuracy {
|
impl Accuracy {
|
||||||
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T {
|
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T {
|
||||||
if y_true.len() != y_prod.len() {
|
if y_true.len() != y_prod.len() {
|
||||||
panic!("The vector sizes don't match: {} != {}", y_true.len(), y_prod.len());
|
panic!(
|
||||||
|
"The vector sizes don't match: {} != {}",
|
||||||
|
y_true.len(),
|
||||||
|
y_prod.len()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let n = y_true.len();
|
let n = y_true.len();
|
||||||
@@ -23,7 +27,6 @@ impl Accuracy {
|
|||||||
|
|
||||||
T::from_i64(positive).unwrap() / T::from_usize(n).unwrap()
|
T::from_i64(positive).unwrap() / T::from_usize(n).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -41,5 +44,4 @@ mod tests {
|
|||||||
assert!((score1 - 0.5).abs() < 1e-8);
|
assert!((score1 - 0.5).abs() < 1e-8);
|
||||||
assert!((score2 - 1.0).abs() < 1e-8);
|
assert!((score2 - 1.0).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
+2
-2
@@ -1,9 +1,9 @@
|
|||||||
pub mod accuracy;
|
pub mod accuracy;
|
||||||
pub mod recall;
|
|
||||||
pub mod precision;
|
pub mod precision;
|
||||||
|
pub mod recall;
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
pub struct ClassificationMetrics {}
|
pub struct ClassificationMetrics {}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct Precision {}
|
pub struct Precision {}
|
||||||
@@ -9,7 +9,11 @@ pub struct Precision{}
|
|||||||
impl Precision {
|
impl Precision {
|
||||||
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T {
|
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T {
|
||||||
if y_true.len() != y_prod.len() {
|
if y_true.len() != y_prod.len() {
|
||||||
panic!("The vector sizes don't match: {} != {}", y_true.len(), y_prod.len());
|
panic!(
|
||||||
|
"The vector sizes don't match: {} != {}",
|
||||||
|
y_true.len(),
|
||||||
|
y_prod.len()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut tp = 0;
|
let mut tp = 0;
|
||||||
@@ -17,11 +21,17 @@ impl Precision {
|
|||||||
let n = y_true.len();
|
let n = y_true.len();
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
|
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
|
||||||
panic!("Precision can only be applied to binary classification: {}", y_true.get(i));
|
panic!(
|
||||||
|
"Precision can only be applied to binary classification: {}",
|
||||||
|
y_true.get(i)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if y_prod.get(i) != T::zero() && y_prod.get(i) != T::one() {
|
if y_prod.get(i) != T::zero() && y_prod.get(i) != T::one() {
|
||||||
panic!("Precision can only be applied to binary classification: {}", y_prod.get(i));
|
panic!(
|
||||||
|
"Precision can only be applied to binary classification: {}",
|
||||||
|
y_prod.get(i)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if y_prod.get(i) == T::one() {
|
if y_prod.get(i) == T::one() {
|
||||||
@@ -35,7 +45,6 @@ impl Precision {
|
|||||||
|
|
||||||
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
|
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -53,5 +62,4 @@ mod tests {
|
|||||||
assert!((score1 - 0.5).abs() < 1e-8);
|
assert!((score1 - 0.5).abs() < 1e-8);
|
||||||
assert!((score2 - 1.0).abs() < 1e-8);
|
assert!((score2 - 1.0).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
+15
-7
@@ -1,7 +1,7 @@
|
|||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct Recall {}
|
pub struct Recall {}
|
||||||
@@ -9,7 +9,11 @@ pub struct Recall{}
|
|||||||
impl Recall {
|
impl Recall {
|
||||||
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T {
|
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T {
|
||||||
if y_true.len() != y_prod.len() {
|
if y_true.len() != y_prod.len() {
|
||||||
panic!("The vector sizes don't match: {} != {}", y_true.len(), y_prod.len());
|
panic!(
|
||||||
|
"The vector sizes don't match: {} != {}",
|
||||||
|
y_true.len(),
|
||||||
|
y_prod.len()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut tp = 0;
|
let mut tp = 0;
|
||||||
@@ -17,11 +21,17 @@ impl Recall {
|
|||||||
let n = y_true.len();
|
let n = y_true.len();
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
|
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
|
||||||
panic!("Recall can only be applied to binary classification: {}", y_true.get(i));
|
panic!(
|
||||||
|
"Recall can only be applied to binary classification: {}",
|
||||||
|
y_true.get(i)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if y_prod.get(i) != T::zero() && y_prod.get(i) != T::one() {
|
if y_prod.get(i) != T::zero() && y_prod.get(i) != T::one() {
|
||||||
panic!("Recall can only be applied to binary classification: {}", y_prod.get(i));
|
panic!(
|
||||||
|
"Recall can only be applied to binary classification: {}",
|
||||||
|
y_prod.get(i)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if y_true.get(i) == T::one() {
|
if y_true.get(i) == T::one() {
|
||||||
@@ -35,7 +45,6 @@ impl Recall {
|
|||||||
|
|
||||||
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
|
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -53,5 +62,4 @@ mod tests {
|
|||||||
assert!((score1 - 0.5).abs() < 1e-8);
|
assert!((score1 - 0.5).abs() < 1e-8);
|
||||||
assert!((score2 - 1.0).abs() < 1e-8);
|
assert!((score2 - 1.0).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
+64
-46
@@ -1,66 +1,70 @@
|
|||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::math::distance::Distance;
|
|
||||||
use crate::linalg::{Matrix, row_iter};
|
|
||||||
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
|
||||||
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
||||||
|
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||||
|
use crate::linalg::{row_iter, Matrix};
|
||||||
|
use crate::math::distance::Distance;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct KNNClassifier<T: FloatExt, D: Distance<Vec<T>, T>> {
|
pub struct KNNClassifier<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||||
classes: Vec<T>,
|
classes: Vec<T>,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
knn_algorithm: KNNAlgorithmV<T, D>,
|
knn_algorithm: KNNAlgorithmV<T, D>,
|
||||||
k: usize
|
k: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum KNNAlgorithmName {
|
pub enum KNNAlgorithmName {
|
||||||
LinearSearch,
|
LinearSearch,
|
||||||
CoverTree
|
CoverTree,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub enum KNNAlgorithmV<T: FloatExt, D: Distance<Vec<T>, T>> {
|
pub enum KNNAlgorithmV<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||||
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||||
CoverTree(CoverTree<Vec<T>, T, D>)
|
CoverTree(CoverTree<Vec<T>, T, D>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl KNNAlgorithmName {
|
impl KNNAlgorithmName {
|
||||||
|
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(
|
||||||
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(&self, data: Vec<Vec<T>>, distance: D) -> KNNAlgorithmV<T, D> {
|
&self,
|
||||||
|
data: Vec<Vec<T>>,
|
||||||
|
distance: D,
|
||||||
|
) -> KNNAlgorithmV<T, D> {
|
||||||
match *self {
|
match *self {
|
||||||
KNNAlgorithmName::LinearSearch => KNNAlgorithmV::LinearSearch(LinearKNNSearch::new(data, distance)),
|
KNNAlgorithmName::LinearSearch => {
|
||||||
|
KNNAlgorithmV::LinearSearch(LinearKNNSearch::new(data, distance))
|
||||||
|
}
|
||||||
KNNAlgorithmName::CoverTree => KNNAlgorithmV::CoverTree(CoverTree::new(data, distance)),
|
KNNAlgorithmName::CoverTree => KNNAlgorithmV::CoverTree(CoverTree::new(data, distance)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithmV<T, D> {
|
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithmV<T, D> {
|
||||||
fn find(&self, from: &Vec<T>, k: usize) -> Vec<usize> {
|
fn find(&self, from: &Vec<T>, k: usize) -> Vec<usize> {
|
||||||
match *self {
|
match *self {
|
||||||
KNNAlgorithmV::LinearSearch(ref linear) => linear.find(from, k),
|
KNNAlgorithmV::LinearSearch(ref linear) => linear.find(from, k),
|
||||||
KNNAlgorithmV::CoverTree(ref cover) => cover.find(from, k)
|
KNNAlgorithmV::CoverTree(ref cover) => cover.find(from, k),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
|
impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
if self.classes.len() != other.classes.len() ||
|
if self.classes.len() != other.classes.len()
|
||||||
self.k != other.k ||
|
|| self.k != other.k
|
||||||
self.y.len() != other.y.len() {
|
|| self.y.len() != other.y.len()
|
||||||
return false
|
{
|
||||||
|
return false;
|
||||||
} else {
|
} else {
|
||||||
for i in 0..self.classes.len() {
|
for i in 0..self.classes.len() {
|
||||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for i in 0..self.y.len() {
|
for i in 0..self.y.len() {
|
||||||
if self.y[i] != other.y[i] {
|
if self.y[i] != other.y[i] {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
true
|
true
|
||||||
@@ -69,9 +73,13 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||||
|
pub fn fit<M: Matrix<T>>(
|
||||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, k: usize, distance: D, algorithm: KNNAlgorithmName) -> KNNClassifier<T, D> {
|
x: &M,
|
||||||
|
y: &M::RowVector,
|
||||||
|
k: usize,
|
||||||
|
distance: D,
|
||||||
|
algorithm: KNNAlgorithmName,
|
||||||
|
) -> KNNClassifier<T, D> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
|
|
||||||
let (_, y_n) = y_m.shape();
|
let (_, y_n) = y_m.shape();
|
||||||
@@ -87,24 +95,35 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
|||||||
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
assert!(x_n == y_n, format!("Size of x should equal size of y; |x|=[{}], |y|=[{}]", x_n, y_n));
|
assert!(
|
||||||
|
x_n == y_n,
|
||||||
|
format!(
|
||||||
|
"Size of x should equal size of y; |x|=[{}], |y|=[{}]",
|
||||||
|
x_n, y_n
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
assert!(k > 1, format!("k should be > 1, k=[{}]", k));
|
assert!(k > 1, format!("k should be > 1, k=[{}]", k));
|
||||||
|
|
||||||
KNNClassifier{classes:classes, y: yi, k: k, knn_algorithm: algorithm.fit(data, distance)}
|
KNNClassifier {
|
||||||
|
classes: classes,
|
||||||
|
y: yi,
|
||||||
|
k: k,
|
||||||
|
knn_algorithm: algorithm.fit(data, distance),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||||
let mut result = M::zeros(1, x.shape().0);
|
let mut result = M::zeros(1, x.shape().0);
|
||||||
|
|
||||||
row_iter(x).enumerate().for_each(|(i, x)| result.set(0, i, self.classes[self.predict_for_row(x)]));
|
row_iter(x)
|
||||||
|
.enumerate()
|
||||||
|
.for_each(|(i, x)| result.set(0, i, self.classes[self.predict_for_row(x)]));
|
||||||
|
|
||||||
result.to_row_vector()
|
result.to_row_vector()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn predict_for_row(&self, x: Vec<T>) -> usize {
|
fn predict_for_row(&self, x: Vec<T>) -> usize {
|
||||||
|
|
||||||
let idxs = self.knn_algorithm.find(&x, self.k);
|
let idxs = self.knn_algorithm.find(&x, self.k);
|
||||||
let mut c = vec![0; self.classes.len()];
|
let mut c = vec![0; self.classes.len()];
|
||||||
let mut max_c = 0;
|
let mut max_c = 0;
|
||||||
@@ -118,27 +137,26 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
max_i
|
max_i
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::math::distance::Distances;
|
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
use crate::math::distance::Distances;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn knn_fit_predict() {
|
fn knn_fit_predict() {
|
||||||
let x = DenseMatrix::from_array(&[
|
let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
&[1., 2.],
|
|
||||||
&[3., 4.],
|
|
||||||
&[5., 6.],
|
|
||||||
&[7., 8.],
|
|
||||||
&[9., 10.]]);
|
|
||||||
let y = vec![2., 2., 2., 3., 3.];
|
let y = vec![2., 2., 2., 3., 3.];
|
||||||
let knn = KNNClassifier::fit(&x, &y, 3, Distances::euclidian(), KNNAlgorithmName::LinearSearch);
|
let knn = KNNClassifier::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
3,
|
||||||
|
Distances::euclidian(),
|
||||||
|
KNNAlgorithmName::LinearSearch,
|
||||||
|
);
|
||||||
let r = knn.predict(&x);
|
let r = knn.predict(&x);
|
||||||
assert_eq!(5, Vec::len(&r));
|
assert_eq!(5, Vec::len(&r));
|
||||||
assert_eq!(y.to_vec(), r);
|
assert_eq!(y.to_vec(), r);
|
||||||
@@ -146,19 +164,19 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_array(&[
|
let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
&[1., 2.],
|
|
||||||
&[3., 4.],
|
|
||||||
&[5., 6.],
|
|
||||||
&[7., 8.],
|
|
||||||
&[9., 10.]]);
|
|
||||||
let y = vec![2., 2., 2., 3., 3.];
|
let y = vec![2., 2., 2., 3., 3.];
|
||||||
|
|
||||||
let knn = KNNClassifier::fit(&x, &y, 3, Distances::euclidian(), KNNAlgorithmName::CoverTree);
|
let knn = KNNClassifier::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
3,
|
||||||
|
Distances::euclidian(),
|
||||||
|
KNNAlgorithmName::CoverTree,
|
||||||
|
);
|
||||||
|
|
||||||
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||||
|
|
||||||
assert_eq!(knn, deserialized_knn);
|
assert_eq!(knn, deserialized_knn);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,15 +1,15 @@
|
|||||||
use std::default::Default;
|
use std::default::Default;
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::optimization::{F, DF};
|
use crate::math::num::FloatExt;
|
||||||
use crate::optimization::line_search::LineSearchMethod;
|
|
||||||
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
||||||
|
use crate::optimization::line_search::LineSearchMethod;
|
||||||
|
use crate::optimization::{DF, F};
|
||||||
|
|
||||||
pub struct GradientDescent<T: FloatExt> {
|
pub struct GradientDescent<T: FloatExt> {
|
||||||
pub max_iter: usize,
|
pub max_iter: usize,
|
||||||
pub g_rtol: T,
|
pub g_rtol: T,
|
||||||
pub g_atol: T
|
pub g_atol: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> Default for GradientDescent<T> {
|
impl<T: FloatExt> Default for GradientDescent<T> {
|
||||||
@@ -17,16 +17,19 @@ impl<T: FloatExt> Default for GradientDescent<T> {
|
|||||||
GradientDescent {
|
GradientDescent {
|
||||||
max_iter: 10000,
|
max_iter: 10000,
|
||||||
g_rtol: T::epsilon().sqrt(),
|
g_rtol: T::epsilon().sqrt(),
|
||||||
g_atol: T::epsilon()
|
g_atol: T::epsilon(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> FirstOrderOptimizer<T> for GradientDescent<T>
|
impl<T: FloatExt> FirstOrderOptimizer<T> for GradientDescent<T> {
|
||||||
{
|
fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(
|
||||||
|
&self,
|
||||||
fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(&self, f: &'a F<T, X>, df: &'a DF<X>, x0: &X, ls: &'a LS) -> OptimizerResult<T, X> {
|
f: &'a F<T, X>,
|
||||||
|
df: &'a DF<X>,
|
||||||
|
x0: &X,
|
||||||
|
ls: &'a LS,
|
||||||
|
) -> OptimizerResult<T, X> {
|
||||||
let mut x = x0.clone();
|
let mut x = x0.clone();
|
||||||
let mut fx = f(&x);
|
let mut fx = f(&x);
|
||||||
|
|
||||||
@@ -73,7 +76,7 @@ impl<T: FloatExt> FirstOrderOptimizer<T> for GradientDescent<T>
|
|||||||
OptimizerResult {
|
OptimizerResult {
|
||||||
x: x,
|
x: x,
|
||||||
f_x: f_x,
|
f_x: f_x,
|
||||||
iterations: iter
|
iterations: iter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -87,14 +90,18 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn gradient_descent() {
|
fn gradient_descent() {
|
||||||
|
|
||||||
let x0 = DenseMatrix::vector_from_array(&[-1., 1.]);
|
let x0 = DenseMatrix::vector_from_array(&[-1., 1.]);
|
||||||
let f = |x: &DenseMatrix<f64>| {
|
let f = |x: &DenseMatrix<f64>| {
|
||||||
(1.0 - x.get(0, 0)).powf(2.) + 100.0 * (x.get(0, 1) - x.get(0, 0).powf(2.)).powf(2.)
|
(1.0 - x.get(0, 0)).powf(2.) + 100.0 * (x.get(0, 1) - x.get(0, 0).powf(2.)).powf(2.)
|
||||||
};
|
};
|
||||||
|
|
||||||
let df = |g: &mut DenseMatrix<f64>, x: &DenseMatrix<f64>| {
|
let df = |g: &mut DenseMatrix<f64>, x: &DenseMatrix<f64>| {
|
||||||
g.set(0, 0, -2. * (1. - x.get(0, 0)) - 400. * (x.get(0, 1) - x.get(0, 0).powf(2.)) * x.get(0, 0));
|
g.set(
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
-2. * (1. - x.get(0, 0))
|
||||||
|
- 400. * (x.get(0, 1) - x.get(0, 0).powf(2.)) * x.get(0, 0),
|
||||||
|
);
|
||||||
g.set(0, 1, 200. * (x.get(0, 1) - x.get(0, 0).powf(2.)));
|
g.set(0, 1, 200. * (x.get(0, 1) - x.get(0, 0).powf(2.)));
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -107,7 +114,5 @@ mod tests {
|
|||||||
assert!((result.f_x - 0.0).abs() < 1e-5);
|
assert!((result.f_x - 0.0).abs() < 1e-5);
|
||||||
assert!((result.x.get(0, 0) - 1.0).abs() < 1e-2);
|
assert!((result.x.get(0, 0) - 1.0).abs() < 1e-2);
|
||||||
assert!((result.x.get(0, 1) - 1.0).abs() < 1e-2);
|
assert!((result.x.get(0, 1) - 1.0).abs() < 1e-2);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
use std::default::Default;
|
use std::default::Default;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::optimization::{F, DF};
|
use crate::math::num::FloatExt;
|
||||||
use crate::optimization::line_search::LineSearchMethod;
|
|
||||||
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
||||||
|
use crate::optimization::line_search::LineSearchMethod;
|
||||||
|
use crate::optimization::{DF, F};
|
||||||
|
|
||||||
pub struct LBFGS<T: FloatExt> {
|
pub struct LBFGS<T: FloatExt> {
|
||||||
pub max_iter: usize,
|
pub max_iter: usize,
|
||||||
@@ -16,7 +16,7 @@ pub struct LBFGS<T: FloatExt> {
|
|||||||
pub f_abstol: T,
|
pub f_abstol: T,
|
||||||
pub f_reltol: T,
|
pub f_reltol: T,
|
||||||
pub successive_f_tol: usize,
|
pub successive_f_tol: usize,
|
||||||
pub m: usize
|
pub m: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> Default for LBFGS<T> {
|
impl<T: FloatExt> Default for LBFGS<T> {
|
||||||
@@ -30,15 +30,13 @@ impl<T: FloatExt> Default for LBFGS<T> {
|
|||||||
f_abstol: T::zero(),
|
f_abstol: T::zero(),
|
||||||
f_reltol: T::zero(),
|
f_reltol: T::zero(),
|
||||||
successive_f_tol: 1,
|
successive_f_tol: 1,
|
||||||
m: 10
|
m: 10,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> LBFGS<T> {
|
impl<T: FloatExt> LBFGS<T> {
|
||||||
|
|
||||||
fn two_loops<X: Matrix<T>>(&self, state: &mut LBFGSState<T, X>) {
|
fn two_loops<X: Matrix<T>>(&self, state: &mut LBFGSState<T, X>) {
|
||||||
|
|
||||||
let lower = state.iteration.max(self.m) - self.m;
|
let lower = state.iteration.max(self.m) - self.m;
|
||||||
let upper = state.iteration;
|
let upper = state.iteration;
|
||||||
|
|
||||||
@@ -49,7 +47,9 @@ impl<T: FloatExt> LBFGS<T> {
|
|||||||
let dgi = &state.dg_history[i];
|
let dgi = &state.dg_history[i];
|
||||||
let dxi = &state.dx_history[i];
|
let dxi = &state.dx_history[i];
|
||||||
state.twoloop_alpha[i] = state.rho[i] * dxi.vector_dot(&state.twoloop_q);
|
state.twoloop_alpha[i] = state.rho[i] * dxi.vector_dot(&state.twoloop_q);
|
||||||
state.twoloop_q.sub_mut(&dgi.mul_scalar(state.twoloop_alpha[i]));
|
state
|
||||||
|
.twoloop_q
|
||||||
|
.sub_mut(&dgi.mul_scalar(state.twoloop_alpha[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
if state.iteration > 0 {
|
if state.iteration > 0 {
|
||||||
@@ -67,11 +67,12 @@ impl<T: FloatExt> LBFGS<T> {
|
|||||||
let dgi = &state.dg_history[i];
|
let dgi = &state.dg_history[i];
|
||||||
let dxi = &state.dx_history[i];
|
let dxi = &state.dx_history[i];
|
||||||
let beta = state.rho[i] * dgi.vector_dot(&state.s);
|
let beta = state.rho[i] * dgi.vector_dot(&state.s);
|
||||||
state.s.add_mut(&dxi.mul_scalar(state.twoloop_alpha[i] - beta));
|
state
|
||||||
|
.s
|
||||||
|
.add_mut(&dxi.mul_scalar(state.twoloop_alpha[i] - beta));
|
||||||
}
|
}
|
||||||
|
|
||||||
state.s.mul_scalar_mut(-T::one());
|
state.s.mul_scalar_mut(-T::one());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn init_state<X: Matrix<T>>(&self, x: &X) -> LBFGSState<T, X> {
|
fn init_state<X: Matrix<T>>(&self, x: &X) -> LBFGSState<T, X> {
|
||||||
@@ -93,11 +94,17 @@ impl<T: FloatExt> LBFGS<T> {
|
|||||||
iteration: 0,
|
iteration: 0,
|
||||||
counter_f_tol: 0,
|
counter_f_tol: 0,
|
||||||
s: x.clone(),
|
s: x.clone(),
|
||||||
alpha: T::one()
|
alpha: T::one(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_state<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(&self, f: &'a F<T, X>, df: &'a DF<X>, ls: &'a LS, state: &mut LBFGSState<T, X>) {
|
fn update_state<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(
|
||||||
|
&self,
|
||||||
|
f: &'a F<T, X>,
|
||||||
|
df: &'a DF<X>,
|
||||||
|
ls: &'a LS,
|
||||||
|
state: &mut LBFGSState<T, X>,
|
||||||
|
) {
|
||||||
self.two_loops(state);
|
self.two_loops(state);
|
||||||
|
|
||||||
df(&mut state.x_df_prev, &state.x);
|
df(&mut state.x_df_prev, &state.x);
|
||||||
@@ -127,7 +134,6 @@ impl<T: FloatExt> LBFGS<T> {
|
|||||||
state.x.add_mut(&state.dx);
|
state.x.add_mut(&state.dx);
|
||||||
state.x_f = f(&state.x);
|
state.x_f = f(&state.x);
|
||||||
df(&mut state.x_df, &state.x);
|
df(&mut state.x_df, &state.x);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn assess_convergence<X: Matrix<T>>(&self, state: &mut LBFGSState<T, X>) -> bool {
|
fn assess_convergence<X: Matrix<T>>(&self, state: &mut LBFGSState<T, X>) -> bool {
|
||||||
@@ -186,13 +192,17 @@ struct LBFGSState<T: FloatExt, X: Matrix<T>> {
|
|||||||
iteration: usize,
|
iteration: usize,
|
||||||
counter_f_tol: usize,
|
counter_f_tol: usize,
|
||||||
s: X,
|
s: X,
|
||||||
alpha: T
|
alpha: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> FirstOrderOptimizer<T> for LBFGS<T> {
|
impl<T: FloatExt> FirstOrderOptimizer<T> for LBFGS<T> {
|
||||||
|
fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(
|
||||||
fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(&self, f: &F<T, X>, df: &'a DF<X>, x0: &X, ls: &'a LS) -> OptimizerResult<T, X> {
|
&self,
|
||||||
|
f: &F<T, X>,
|
||||||
|
df: &'a DF<X>,
|
||||||
|
x0: &X,
|
||||||
|
ls: &'a LS,
|
||||||
|
) -> OptimizerResult<T, X> {
|
||||||
let mut state = self.init_state(x0);
|
let mut state = self.init_state(x0);
|
||||||
|
|
||||||
df(&mut state.x_df, &x0);
|
df(&mut state.x_df, &x0);
|
||||||
@@ -202,7 +212,6 @@ impl<T: FloatExt> FirstOrderOptimizer<T> for LBFGS<T> {
|
|||||||
let stopped = false;
|
let stopped = false;
|
||||||
|
|
||||||
while !converged && !stopped && state.iteration < self.max_iter {
|
while !converged && !stopped && state.iteration < self.max_iter {
|
||||||
|
|
||||||
self.update_state(f, df, ls, &mut state);
|
self.update_state(f, df, ls, &mut state);
|
||||||
|
|
||||||
converged = self.assess_convergence(&mut state);
|
converged = self.assess_convergence(&mut state);
|
||||||
@@ -212,17 +221,14 @@ impl<T: FloatExt> FirstOrderOptimizer<T> for LBFGS<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
state.iteration += 1;
|
state.iteration += 1;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
OptimizerResult {
|
OptimizerResult {
|
||||||
x: state.x,
|
x: state.x,
|
||||||
f_x: state.x_f,
|
f_x: state.x_f,
|
||||||
iterations: state.iteration
|
iterations: state.iteration,
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -240,7 +246,12 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let df = |g: &mut DenseMatrix<f64>, x: &DenseMatrix<f64>| {
|
let df = |g: &mut DenseMatrix<f64>, x: &DenseMatrix<f64>| {
|
||||||
g.set(0, 0, -2. * (1. - x.get(0, 0)) - 400. * (x.get(0, 1) - x.get(0, 0).powf(2.)) * x.get(0, 0));
|
g.set(
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
-2. * (1. - x.get(0, 0))
|
||||||
|
- 400. * (x.get(0, 1) - x.get(0, 0).powf(2.)) * x.get(0, 0),
|
||||||
|
);
|
||||||
g.set(0, 1, 200. * (x.get(0, 1) - x.get(0, 0).powf(2.)));
|
g.set(0, 1, 200. * (x.get(0, 1) - x.get(0, 0).powf(2.)));
|
||||||
};
|
};
|
||||||
let mut ls: Backtracking<f64> = Default::default();
|
let mut ls: Backtracking<f64> = Default::default();
|
||||||
|
|||||||
@@ -1,22 +1,27 @@
|
|||||||
pub mod lbfgs;
|
|
||||||
pub mod gradient_descent;
|
pub mod gradient_descent;
|
||||||
|
pub mod lbfgs;
|
||||||
|
|
||||||
use std::clone::Clone;
|
use std::clone::Clone;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
use crate::optimization::line_search::LineSearchMethod;
|
use crate::optimization::line_search::LineSearchMethod;
|
||||||
use crate::optimization::{F, DF};
|
use crate::optimization::{DF, F};
|
||||||
|
|
||||||
pub trait FirstOrderOptimizer<T: FloatExt> {
|
pub trait FirstOrderOptimizer<T: FloatExt> {
|
||||||
fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(&self, f: &F<T, X>, df: &'a DF<X>, x0: &X, ls: &'a LS) -> OptimizerResult<T, X>;
|
fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(
|
||||||
|
&self,
|
||||||
|
f: &F<T, X>,
|
||||||
|
df: &'a DF<X>,
|
||||||
|
x0: &X,
|
||||||
|
ls: &'a LS,
|
||||||
|
) -> OptimizerResult<T, X>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct OptimizerResult<T: FloatExt, X: Matrix<T>>
|
pub struct OptimizerResult<T: FloatExt, X: Matrix<T>> {
|
||||||
{
|
|
||||||
pub x: X,
|
pub x: X,
|
||||||
pub f_x: T,
|
pub f_x: T,
|
||||||
pub iterations: usize
|
pub iterations: usize,
|
||||||
}
|
}
|
||||||
@@ -1,14 +1,21 @@
|
|||||||
use num_traits::Float;
|
|
||||||
use crate::optimization::FunctionOrder;
|
use crate::optimization::FunctionOrder;
|
||||||
|
use num_traits::Float;
|
||||||
|
|
||||||
pub trait LineSearchMethod<T: Float> {
|
pub trait LineSearchMethod<T: Float> {
|
||||||
fn search<'a>(&self, f: &(dyn Fn(T) -> T), df: &(dyn Fn(T) -> T), alpha: T, f0: T, df0: T) -> LineSearchResult<T>;
|
fn search<'a>(
|
||||||
|
&self,
|
||||||
|
f: &(dyn Fn(T) -> T),
|
||||||
|
df: &(dyn Fn(T) -> T),
|
||||||
|
alpha: T,
|
||||||
|
f0: T,
|
||||||
|
df0: T,
|
||||||
|
) -> LineSearchResult<T>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct LineSearchResult<T: Float> {
|
pub struct LineSearchResult<T: Float> {
|
||||||
pub alpha: T,
|
pub alpha: T,
|
||||||
pub f_x: T
|
pub f_x: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Backtracking<T: Float> {
|
pub struct Backtracking<T: Float> {
|
||||||
@@ -17,7 +24,7 @@ pub struct Backtracking<T: Float> {
|
|||||||
pub max_infinity_iterations: usize,
|
pub max_infinity_iterations: usize,
|
||||||
pub phi: T,
|
pub phi: T,
|
||||||
pub plo: T,
|
pub plo: T,
|
||||||
pub order: FunctionOrder
|
pub order: FunctionOrder,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Float> Default for Backtracking<T> {
|
impl<T: Float> Default for Backtracking<T> {
|
||||||
@@ -28,15 +35,20 @@ impl<T: Float> Default for Backtracking<T> {
|
|||||||
max_infinity_iterations: (-T::epsilon().log2()).to_usize().unwrap(),
|
max_infinity_iterations: (-T::epsilon().log2()).to_usize().unwrap(),
|
||||||
phi: T::from(0.5).unwrap(),
|
phi: T::from(0.5).unwrap(),
|
||||||
plo: T::from(0.1).unwrap(),
|
plo: T::from(0.1).unwrap(),
|
||||||
order: FunctionOrder::SECOND
|
order: FunctionOrder::SECOND,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
|
impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
|
||||||
|
fn search<'a>(
|
||||||
fn search<'a>(&self, f: &(dyn Fn(T) -> T), _: &(dyn Fn(T) -> T), alpha: T, f0: T, df0: T) -> LineSearchResult<T> {
|
&self,
|
||||||
|
f: &(dyn Fn(T) -> T),
|
||||||
|
_: &(dyn Fn(T) -> T),
|
||||||
|
alpha: T,
|
||||||
|
f0: T,
|
||||||
|
df0: T,
|
||||||
|
) -> LineSearchResult<T> {
|
||||||
let two = T::from(2.).unwrap();
|
let two = T::from(2.).unwrap();
|
||||||
let three = T::from(3.).unwrap();
|
let three = T::from(3.).unwrap();
|
||||||
|
|
||||||
@@ -62,14 +74,15 @@ impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
|
|||||||
let a_tmp;
|
let a_tmp;
|
||||||
|
|
||||||
if self.order == FunctionOrder::SECOND || iteration == 0 {
|
if self.order == FunctionOrder::SECOND || iteration == 0 {
|
||||||
|
|
||||||
a_tmp = -(df0 * a2.powf(two)) / (two * (fx1 - f0 - df0 * a2))
|
a_tmp = -(df0 * a2.powf(two)) / (two * (fx1 - f0 - df0 * a2))
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
let div = T::one() / (a1.powf(two) * a2.powf(two) * (a2 - a1));
|
let div = T::one() / (a1.powf(two) * a2.powf(two) * (a2 - a1));
|
||||||
let a = (a1.powf(two) * (fx1 - f0 - df0*a2) - a2.powf(two)*(fx0 - f0 - df0*a1))*div;
|
let a = (a1.powf(two) * (fx1 - f0 - df0 * a2)
|
||||||
let b = (-a1.powf(three) * (fx1 - f0 - df0*a2) + a2.powf(three)*(fx0 - f0 - df0*a1))*div;
|
- a2.powf(two) * (fx0 - f0 - df0 * a1))
|
||||||
|
* div;
|
||||||
|
let b = (-a1.powf(three) * (fx1 - f0 - df0 * a2)
|
||||||
|
+ a2.powf(three) * (fx0 - f0 - df0 * a1))
|
||||||
|
* div;
|
||||||
|
|
||||||
if (a - T::zero()).powf(two).sqrt() <= T::epsilon() {
|
if (a - T::zero()).powf(two).sqrt() <= T::epsilon() {
|
||||||
a_tmp = df0 / (two * b);
|
a_tmp = df0 / (two * b);
|
||||||
@@ -90,9 +103,8 @@ impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
|
|||||||
|
|
||||||
LineSearchResult {
|
LineSearchResult {
|
||||||
alpha: a2,
|
alpha: a2,
|
||||||
f_x: fx1
|
f_x: fx1,
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,14 +114,9 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn backtracking() {
|
fn backtracking() {
|
||||||
|
let f = |x: f64| -> f64 { x.powf(2.) + x };
|
||||||
|
|
||||||
let f = |x: f64| -> f64 {
|
let df = |x: f64| -> f64 { 2. * x + 1. };
|
||||||
x.powf(2.) + x
|
|
||||||
};
|
|
||||||
|
|
||||||
let df = |x: f64| -> f64 {
|
|
||||||
2. * x + 1.
|
|
||||||
};
|
|
||||||
|
|
||||||
let ls: Backtracking<f64> = Default::default();
|
let ls: Backtracking<f64> = Default::default();
|
||||||
|
|
||||||
|
|||||||
@@ -8,5 +8,5 @@ pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a;
|
|||||||
pub enum FunctionOrder {
|
pub enum FunctionOrder {
|
||||||
FIRST,
|
FIRST,
|
||||||
SECOND,
|
SECOND,
|
||||||
THIRD
|
THIRD,
|
||||||
}
|
}
|
||||||
@@ -1,20 +1,20 @@
|
|||||||
|
use std::collections::LinkedList;
|
||||||
use std::default::Default;
|
use std::default::Default;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use std::collections::LinkedList;
|
|
||||||
|
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::Matrix;
|
|
||||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||||
|
use crate::linalg::Matrix;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct DecisionTreeClassifierParameters {
|
pub struct DecisionTreeClassifierParameters {
|
||||||
pub criterion: SplitCriterion,
|
pub criterion: SplitCriterion,
|
||||||
pub max_depth: Option<u16>,
|
pub max_depth: Option<u16>,
|
||||||
pub min_samples_leaf: usize,
|
pub min_samples_leaf: usize,
|
||||||
pub min_samples_split: usize
|
pub min_samples_split: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
@@ -23,14 +23,14 @@ pub struct DecisionTreeClassifier<T: FloatExt> {
|
|||||||
parameters: DecisionTreeClassifierParameters,
|
parameters: DecisionTreeClassifierParameters,
|
||||||
num_classes: usize,
|
num_classes: usize,
|
||||||
classes: Vec<T>,
|
classes: Vec<T>,
|
||||||
depth: u16
|
depth: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub enum SplitCriterion {
|
pub enum SplitCriterion {
|
||||||
Gini,
|
Gini,
|
||||||
Entropy,
|
Entropy,
|
||||||
ClassificationError
|
ClassificationError,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
@@ -46,36 +46,37 @@ pub struct Node<T: FloatExt> {
|
|||||||
|
|
||||||
impl<T: FloatExt> PartialEq for DecisionTreeClassifier<T> {
|
impl<T: FloatExt> PartialEq for DecisionTreeClassifier<T> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
if self.depth != other.depth ||
|
if self.depth != other.depth
|
||||||
self.num_classes != other.num_classes ||
|
|| self.num_classes != other.num_classes
|
||||||
self.nodes.len() != other.nodes.len(){
|
|| self.nodes.len() != other.nodes.len()
|
||||||
return false
|
{
|
||||||
|
return false;
|
||||||
} else {
|
} else {
|
||||||
for i in 0..self.classes.len() {
|
for i in 0..self.classes.len() {
|
||||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for i in 0..self.nodes.len() {
|
for i in 0..self.nodes.len() {
|
||||||
if self.nodes[i] != other.nodes[i] {
|
if self.nodes[i] != other.nodes[i] {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> PartialEq for Node<T> {
|
impl<T: FloatExt> PartialEq for Node<T> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.output == other.output &&
|
self.output == other.output
|
||||||
self.split_feature == other.split_feature &&
|
&& self.split_feature == other.split_feature
|
||||||
match (self.split_value, other.split_value) {
|
&& match (self.split_value, other.split_value) {
|
||||||
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
||||||
(None, None) => true,
|
(None, None) => true,
|
||||||
_ => false,
|
_ => false,
|
||||||
} &&
|
}
|
||||||
match (self.split_score, other.split_score) {
|
&& match (self.split_score, other.split_score) {
|
||||||
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
||||||
(None, None) => true,
|
(None, None) => true,
|
||||||
_ => false,
|
_ => false,
|
||||||
@@ -89,7 +90,7 @@ impl Default for DecisionTreeClassifierParameters {
|
|||||||
criterion: SplitCriterion::Gini,
|
criterion: SplitCriterion::Gini,
|
||||||
max_depth: None,
|
max_depth: None,
|
||||||
min_samples_leaf: 1,
|
min_samples_leaf: 1,
|
||||||
min_samples_split: 2
|
min_samples_split: 2,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -103,7 +104,7 @@ impl<T: FloatExt> Node<T> {
|
|||||||
split_value: Option::None,
|
split_value: Option::None,
|
||||||
split_score: Option::None,
|
split_score: Option::None,
|
||||||
true_child: Option::None,
|
true_child: Option::None,
|
||||||
false_child: Option::None
|
false_child: Option::None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -117,7 +118,7 @@ struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
|
|||||||
true_child_output: usize,
|
true_child_output: usize,
|
||||||
false_child_output: usize,
|
false_child_output: usize,
|
||||||
level: u16,
|
level: u16,
|
||||||
phantom: PhantomData<&'a T>
|
phantom: PhantomData<&'a T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn impurity<T: FloatExt>(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> T {
|
fn impurity<T: FloatExt>(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> T {
|
||||||
@@ -156,8 +157,14 @@ fn impurity<T: FloatExt>(criterion: &SplitCriterion, count: &Vec<usize>, n: usiz
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
||||||
|
fn new(
|
||||||
fn new(node_id: usize, samples: Vec<usize>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a Vec<usize>, level: u16) -> Self {
|
node_id: usize,
|
||||||
|
samples: Vec<usize>,
|
||||||
|
order: &'a Vec<Vec<usize>>,
|
||||||
|
x: &'a M,
|
||||||
|
y: &'a Vec<usize>,
|
||||||
|
level: u16,
|
||||||
|
) -> Self {
|
||||||
NodeVisitor {
|
NodeVisitor {
|
||||||
x: x,
|
x: x,
|
||||||
y: y,
|
y: y,
|
||||||
@@ -167,10 +174,9 @@ impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
|||||||
true_child_output: 0,
|
true_child_output: 0,
|
||||||
false_child_output: 0,
|
false_child_output: 0,
|
||||||
level: level,
|
level: level,
|
||||||
phantom: PhantomData
|
phantom: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
|
pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
|
||||||
@@ -188,14 +194,23 @@ pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> DecisionTreeClassifier<T> {
|
impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||||
|
pub fn fit<M: Matrix<T>>(
|
||||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier<T> {
|
x: &M,
|
||||||
|
y: &M::RowVector,
|
||||||
|
parameters: DecisionTreeClassifierParameters,
|
||||||
|
) -> DecisionTreeClassifier<T> {
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let samples = vec![1; x_nrows];
|
let samples = vec![1; x_nrows];
|
||||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fit_weak_learner<M: Matrix<T>>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier<T> {
|
pub fn fit_weak_learner<M: Matrix<T>>(
|
||||||
|
x: &M,
|
||||||
|
y: &M::RowVector,
|
||||||
|
samples: Vec<usize>,
|
||||||
|
mtry: usize,
|
||||||
|
parameters: DecisionTreeClassifierParameters,
|
||||||
|
) -> DecisionTreeClassifier<T> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let (_, y_ncols) = y_m.shape();
|
let (_, y_ncols) = y_m.shape();
|
||||||
let (_, num_attributes) = x.shape();
|
let (_, num_attributes) = x.shape();
|
||||||
@@ -232,7 +247,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
|||||||
parameters: parameters,
|
parameters: parameters,
|
||||||
num_classes: k,
|
num_classes: k,
|
||||||
classes: classes,
|
classes: classes,
|
||||||
depth: 0
|
depth: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &yi, 1);
|
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &yi, 1);
|
||||||
@@ -245,8 +260,8 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
|||||||
|
|
||||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||||
match visitor_queue.pop_front() {
|
match visitor_queue.pop_front() {
|
||||||
Some(node) => tree.split(node, mtry, &mut visitor_queue,),
|
Some(node) => tree.split(node, mtry, &mut visitor_queue),
|
||||||
None => break
|
None => break,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -284,17 +299,19 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
|||||||
queue.push_back(node.false_child.unwrap());
|
queue.push_back(node.false_child.unwrap());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
None => break
|
None => break,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_best_cutoff<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, mtry: usize) -> bool {
|
fn find_best_cutoff<M: Matrix<T>>(
|
||||||
|
&mut self,
|
||||||
|
visitor: &mut NodeVisitor<T, M>,
|
||||||
|
mtry: usize,
|
||||||
|
) -> bool {
|
||||||
let (n_rows, n_attr) = visitor.x.shape();
|
let (n_rows, n_attr) = visitor.x.shape();
|
||||||
|
|
||||||
let mut label = Option::None;
|
let mut label = Option::None;
|
||||||
@@ -336,15 +353,28 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for j in 0..mtry {
|
for j in 0..mtry {
|
||||||
self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]);
|
self.find_best_split(
|
||||||
|
visitor,
|
||||||
|
n,
|
||||||
|
&count,
|
||||||
|
&mut false_count,
|
||||||
|
parent_impurity,
|
||||||
|
variables[j],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.nodes[visitor.node].split_score != Option::None
|
self.nodes[visitor.node].split_score != Option::None
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_best_split<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, n: usize, count: &Vec<usize>, false_count: &mut Vec<usize>, parent_impurity: T, j: usize){
|
fn find_best_split<M: Matrix<T>>(
|
||||||
|
&mut self,
|
||||||
|
visitor: &mut NodeVisitor<T, M>,
|
||||||
|
n: usize,
|
||||||
|
count: &Vec<usize>,
|
||||||
|
false_count: &mut Vec<usize>,
|
||||||
|
parent_impurity: T,
|
||||||
|
j: usize,
|
||||||
|
) {
|
||||||
let mut true_count = vec![0; self.num_classes];
|
let mut true_count = vec![0; self.num_classes];
|
||||||
let mut prevx = T::nan();
|
let mut prevx = T::nan();
|
||||||
let mut prevy = 0;
|
let mut prevy = 0;
|
||||||
@@ -374,11 +404,18 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
|||||||
|
|
||||||
let true_label = which_max(&true_count);
|
let true_label = which_max(&true_count);
|
||||||
let false_label = which_max(false_count);
|
let false_label = which_max(false_count);
|
||||||
let gain = parent_impurity - T::from(tc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &true_count, tc) - T::from(fc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &false_count, fc);
|
let gain = parent_impurity
|
||||||
|
- T::from(tc).unwrap() / T::from(n).unwrap()
|
||||||
|
* impurity(&self.parameters.criterion, &true_count, tc)
|
||||||
|
- T::from(fc).unwrap() / T::from(n).unwrap()
|
||||||
|
* impurity(&self.parameters.criterion, &false_count, fc);
|
||||||
|
|
||||||
if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() {
|
if self.nodes[visitor.node].split_score == Option::None
|
||||||
|
|| gain > self.nodes[visitor.node].split_score.unwrap()
|
||||||
|
{
|
||||||
self.nodes[visitor.node].split_feature = j;
|
self.nodes[visitor.node].split_feature = j;
|
||||||
self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
|
self.nodes[visitor.node].split_value =
|
||||||
|
Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
|
||||||
self.nodes[visitor.node].split_score = Option::Some(gain);
|
self.nodes[visitor.node].split_score = Option::Some(gain);
|
||||||
visitor.true_child_output = true_label;
|
visitor.true_child_output = true_label;
|
||||||
visitor.false_child_output = false_label;
|
visitor.false_child_output = false_label;
|
||||||
@@ -389,10 +426,14 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
|||||||
true_count[visitor.y[*i]] += visitor.samples[*i];
|
true_count[visitor.y[*i]] += visitor.samples[*i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn split<'a, M: Matrix<T>>(&mut self, mut visitor: NodeVisitor<'a, T, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>) -> bool {
|
fn split<'a, M: Matrix<T>>(
|
||||||
|
&mut self,
|
||||||
|
mut visitor: NodeVisitor<'a, T, M>,
|
||||||
|
mtry: usize,
|
||||||
|
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
|
||||||
|
) -> bool {
|
||||||
let (n, _) = visitor.x.shape();
|
let (n, _) = visitor.x.shape();
|
||||||
let mut tc = 0;
|
let mut tc = 0;
|
||||||
let mut fc = 0;
|
let mut fc = 0;
|
||||||
@@ -400,7 +441,9 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
|||||||
|
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
if visitor.samples[i] > 0 {
|
if visitor.samples[i] > 0 {
|
||||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) {
|
if visitor.x.get(i, self.nodes[visitor.node].split_feature)
|
||||||
|
<= self.nodes[visitor.node].split_value.unwrap_or(T::nan())
|
||||||
|
{
|
||||||
true_samples[i] = visitor.samples[i];
|
true_samples[i] = visitor.samples[i];
|
||||||
tc += true_samples[i];
|
tc += true_samples[i];
|
||||||
visitor.samples[i] = 0;
|
visitor.samples[i] = 0;
|
||||||
@@ -418,22 +461,38 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let true_child_idx = self.nodes.len();
|
let true_child_idx = self.nodes.len();
|
||||||
self.nodes.push(Node::new(true_child_idx, visitor.true_child_output));
|
self.nodes
|
||||||
|
.push(Node::new(true_child_idx, visitor.true_child_output));
|
||||||
let false_child_idx = self.nodes.len();
|
let false_child_idx = self.nodes.len();
|
||||||
self.nodes.push(Node::new(false_child_idx, visitor.false_child_output));
|
self.nodes
|
||||||
|
.push(Node::new(false_child_idx, visitor.false_child_output));
|
||||||
|
|
||||||
self.nodes[visitor.node].true_child = Some(true_child_idx);
|
self.nodes[visitor.node].true_child = Some(true_child_idx);
|
||||||
self.nodes[visitor.node].false_child = Some(false_child_idx);
|
self.nodes[visitor.node].false_child = Some(false_child_idx);
|
||||||
|
|
||||||
self.depth = u16::max(self.depth, visitor.level + 1);
|
self.depth = u16::max(self.depth, visitor.level + 1);
|
||||||
|
|
||||||
let mut true_visitor = NodeVisitor::<T, M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
let mut true_visitor = NodeVisitor::<T, M>::new(
|
||||||
|
true_child_idx,
|
||||||
|
true_samples,
|
||||||
|
visitor.order,
|
||||||
|
visitor.x,
|
||||||
|
visitor.y,
|
||||||
|
visitor.level + 1,
|
||||||
|
);
|
||||||
|
|
||||||
if self.find_best_cutoff(&mut true_visitor, mtry) {
|
if self.find_best_cutoff(&mut true_visitor, mtry) {
|
||||||
visitor_queue.push_back(true_visitor);
|
visitor_queue.push_back(true_visitor);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut false_visitor = NodeVisitor::<T, M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
let mut false_visitor = NodeVisitor::<T, M>::new(
|
||||||
|
false_child_idx,
|
||||||
|
visitor.samples,
|
||||||
|
visitor.order,
|
||||||
|
visitor.x,
|
||||||
|
visitor.y,
|
||||||
|
visitor.level + 1,
|
||||||
|
);
|
||||||
|
|
||||||
if self.find_best_cutoff(&mut false_visitor, mtry) {
|
if self.find_best_cutoff(&mut false_visitor, mtry) {
|
||||||
visitor_queue.push_back(false_visitor);
|
visitor_queue.push_back(false_visitor);
|
||||||
@@ -441,7 +500,6 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
|||||||
|
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -451,14 +509,22 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn gini_impurity() {
|
fn gini_impurity() {
|
||||||
assert!((impurity::<f64>(&SplitCriterion::Gini, &vec![7, 3], 10) - 0.42).abs() < std::f64::EPSILON);
|
assert!(
|
||||||
assert!((impurity::<f64>(&SplitCriterion::Entropy, &vec![7, 3], 10) - 0.8812908992306927).abs() < std::f64::EPSILON);
|
(impurity::<f64>(&SplitCriterion::Gini, &vec![7, 3], 10) - 0.42).abs()
|
||||||
assert!((impurity::<f64>(&SplitCriterion::ClassificationError, &vec![7, 3], 10) - 0.3).abs() < std::f64::EPSILON);
|
< std::f64::EPSILON
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
(impurity::<f64>(&SplitCriterion::Entropy, &vec![7, 3], 10) - 0.8812908992306927).abs()
|
||||||
|
< std::f64::EPSILON
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
(impurity::<f64>(&SplitCriterion::ClassificationError, &vec![7, 3], 10) - 0.3).abs()
|
||||||
|
< std::f64::EPSILON
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_iris() {
|
fn fit_predict_iris() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
&[4.9, 3.0, 1.4, 0.2],
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
@@ -479,18 +545,35 @@ mod tests {
|
|||||||
&[6.3, 3.3, 4.7, 1.6],
|
&[6.3, 3.3, 4.7, 1.6],
|
||||||
&[4.9, 2.4, 3.3, 1.0],
|
&[4.9, 2.4, 3.3, 1.0],
|
||||||
&[6.6, 2.9, 4.6, 1.3],
|
&[6.6, 2.9, 4.6, 1.3],
|
||||||
&[5.2, 2.7, 3.9, 1.4]]);
|
&[5.2, 2.7, 3.9, 1.4],
|
||||||
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
]);
|
||||||
|
let y = vec![
|
||||||
|
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
|
];
|
||||||
|
|
||||||
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
|
assert_eq!(
|
||||||
|
y,
|
||||||
assert_eq!(3, DecisionTreeClassifier::fit(&x, &y, DecisionTreeClassifierParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1, min_samples_split: 2}).depth);
|
DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x)
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
3,
|
||||||
|
DecisionTreeClassifier::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
DecisionTreeClassifierParameters {
|
||||||
|
criterion: SplitCriterion::Entropy,
|
||||||
|
max_depth: Some(3),
|
||||||
|
min_samples_leaf: 1,
|
||||||
|
min_samples_split: 2
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.depth
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_baloons() {
|
fn fit_predict_baloons() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[1., 1., 1., 0.],
|
&[1., 1., 1., 0.],
|
||||||
&[1., 1., 1., 0.],
|
&[1., 1., 1., 0.],
|
||||||
@@ -511,11 +594,16 @@ mod tests {
|
|||||||
&[0., 0., 1., 0.],
|
&[0., 0., 1., 0.],
|
||||||
&[0., 0., 1., 1.],
|
&[0., 0., 1., 1.],
|
||||||
&[0., 0., 0., 0.],
|
&[0., 0., 0., 0.],
|
||||||
&[0.,0.,0.,1.]]);
|
&[0., 0., 0., 1.],
|
||||||
let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
|
]);
|
||||||
|
let y = vec![
|
||||||
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
|
1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
|
||||||
|
];
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
y,
|
||||||
|
DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -540,14 +628,17 @@ mod tests {
|
|||||||
&[0., 0., 1., 0.],
|
&[0., 0., 1., 0.],
|
||||||
&[0., 0., 1., 1.],
|
&[0., 0., 1., 1.],
|
||||||
&[0., 0., 0., 0.],
|
&[0., 0., 0., 0.],
|
||||||
&[0.,0.,0.,1.]]);
|
&[0., 0., 0., 1.],
|
||||||
let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
|
]);
|
||||||
|
let y = vec![
|
||||||
|
1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
|
||||||
|
];
|
||||||
|
|
||||||
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
|
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
|
||||||
|
|
||||||
let deserialized_tree: DecisionTreeClassifier<f64> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
let deserialized_tree: DecisionTreeClassifier<f64> =
|
||||||
|
bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
||||||
|
|
||||||
assert_eq!(tree, deserialized_tree);
|
assert_eq!(tree, deserialized_tree);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,25 +1,25 @@
|
|||||||
|
use std::collections::LinkedList;
|
||||||
use std::default::Default;
|
use std::default::Default;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::collections::LinkedList;
|
|
||||||
|
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
|
||||||
use crate::linalg::Matrix;
|
|
||||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||||
|
use crate::linalg::Matrix;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct DecisionTreeRegressorParameters {
|
pub struct DecisionTreeRegressorParameters {
|
||||||
pub max_depth: Option<u16>,
|
pub max_depth: Option<u16>,
|
||||||
pub min_samples_leaf: usize,
|
pub min_samples_leaf: usize,
|
||||||
pub min_samples_split: usize
|
pub min_samples_split: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct DecisionTreeRegressor<T: FloatExt> {
|
pub struct DecisionTreeRegressor<T: FloatExt> {
|
||||||
nodes: Vec<Node<T>>,
|
nodes: Vec<Node<T>>,
|
||||||
parameters: DecisionTreeRegressorParameters,
|
parameters: DecisionTreeRegressorParameters,
|
||||||
depth: u16
|
depth: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
@@ -33,13 +33,12 @@ pub struct Node<T: FloatExt> {
|
|||||||
false_child: Option<usize>,
|
false_child: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl Default for DecisionTreeRegressorParameters {
|
impl Default for DecisionTreeRegressorParameters {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
DecisionTreeRegressorParameters {
|
DecisionTreeRegressorParameters {
|
||||||
max_depth: None,
|
max_depth: None,
|
||||||
min_samples_leaf: 1,
|
min_samples_leaf: 1,
|
||||||
min_samples_split: 2
|
min_samples_split: 2,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -53,21 +52,21 @@ impl<T: FloatExt> Node<T> {
|
|||||||
split_value: Option::None,
|
split_value: Option::None,
|
||||||
split_score: Option::None,
|
split_score: Option::None,
|
||||||
true_child: Option::None,
|
true_child: Option::None,
|
||||||
false_child: Option::None
|
false_child: Option::None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> PartialEq for Node<T> {
|
impl<T: FloatExt> PartialEq for Node<T> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
(self.output - other.output).abs() < T::epsilon() &&
|
(self.output - other.output).abs() < T::epsilon()
|
||||||
self.split_feature == other.split_feature &&
|
&& self.split_feature == other.split_feature
|
||||||
match (self.split_value, other.split_value) {
|
&& match (self.split_value, other.split_value) {
|
||||||
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
||||||
(None, None) => true,
|
(None, None) => true,
|
||||||
_ => false,
|
_ => false,
|
||||||
} &&
|
}
|
||||||
match (self.split_score, other.split_score) {
|
&& match (self.split_score, other.split_score) {
|
||||||
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
||||||
(None, None) => true,
|
(None, None) => true,
|
||||||
_ => false,
|
_ => false,
|
||||||
@@ -78,14 +77,14 @@ impl<T: FloatExt> PartialEq for Node<T> {
|
|||||||
impl<T: FloatExt> PartialEq for DecisionTreeRegressor<T> {
|
impl<T: FloatExt> PartialEq for DecisionTreeRegressor<T> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
if self.depth != other.depth || self.nodes.len() != other.nodes.len() {
|
if self.depth != other.depth || self.nodes.len() != other.nodes.len() {
|
||||||
return false
|
return false;
|
||||||
} else {
|
} else {
|
||||||
for i in 0..self.nodes.len() {
|
for i in 0..self.nodes.len() {
|
||||||
if self.nodes[i] != other.nodes[i] {
|
if self.nodes[i] != other.nodes[i] {
|
||||||
return false
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -98,12 +97,18 @@ struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
|
|||||||
order: &'a Vec<Vec<usize>>,
|
order: &'a Vec<Vec<usize>>,
|
||||||
true_child_output: T,
|
true_child_output: T,
|
||||||
false_child_output: T,
|
false_child_output: T,
|
||||||
level: u16
|
level: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
||||||
|
fn new(
|
||||||
fn new(node_id: usize, samples: Vec<usize>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a M, level: u16) -> Self {
|
node_id: usize,
|
||||||
|
samples: Vec<usize>,
|
||||||
|
order: &'a Vec<Vec<usize>>,
|
||||||
|
x: &'a M,
|
||||||
|
y: &'a M,
|
||||||
|
level: u16,
|
||||||
|
) -> Self {
|
||||||
NodeVisitor {
|
NodeVisitor {
|
||||||
x: x,
|
x: x,
|
||||||
y: y,
|
y: y,
|
||||||
@@ -112,21 +117,29 @@ impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
|||||||
order: order,
|
order: order,
|
||||||
true_child_output: T::zero(),
|
true_child_output: T::zero(),
|
||||||
false_child_output: T::zero(),
|
false_child_output: T::zero(),
|
||||||
level: level
|
level: level,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt> DecisionTreeRegressor<T> {
|
impl<T: FloatExt> DecisionTreeRegressor<T> {
|
||||||
|
pub fn fit<M: Matrix<T>>(
|
||||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: DecisionTreeRegressorParameters) -> DecisionTreeRegressor<T> {
|
x: &M,
|
||||||
|
y: &M::RowVector,
|
||||||
|
parameters: DecisionTreeRegressorParameters,
|
||||||
|
) -> DecisionTreeRegressor<T> {
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let samples = vec![1; x_nrows];
|
let samples = vec![1; x_nrows];
|
||||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fit_weak_learner<M: Matrix<T>>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeRegressorParameters) -> DecisionTreeRegressor<T> {
|
pub fn fit_weak_learner<M: Matrix<T>>(
|
||||||
|
x: &M,
|
||||||
|
y: &M::RowVector,
|
||||||
|
samples: Vec<usize>,
|
||||||
|
mtry: usize,
|
||||||
|
parameters: DecisionTreeRegressorParameters,
|
||||||
|
) -> DecisionTreeRegressor<T> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
|
|
||||||
let (_, y_ncols) = y_m.shape();
|
let (_, y_ncols) = y_m.shape();
|
||||||
@@ -157,7 +170,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
|||||||
let mut tree = DecisionTreeRegressor {
|
let mut tree = DecisionTreeRegressor {
|
||||||
nodes: nodes,
|
nodes: nodes,
|
||||||
parameters: parameters,
|
parameters: parameters,
|
||||||
depth: 0
|
depth: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &y_m, 1);
|
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &y_m, 1);
|
||||||
@@ -171,7 +184,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
|||||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||||
match visitor_queue.pop_front() {
|
match visitor_queue.pop_front() {
|
||||||
Some(node) => tree.split(node, mtry, &mut visitor_queue),
|
Some(node) => tree.split(node, mtry, &mut visitor_queue),
|
||||||
None => break
|
None => break,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -209,17 +222,19 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
|||||||
queue.push_back(node.false_child.unwrap());
|
queue.push_back(node.false_child.unwrap());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
None => break
|
None => break,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_best_cutoff<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, mtry: usize) -> bool {
|
fn find_best_cutoff<M: Matrix<T>>(
|
||||||
|
&mut self,
|
||||||
|
visitor: &mut NodeVisitor<T, M>,
|
||||||
|
mtry: usize,
|
||||||
|
) -> bool {
|
||||||
let (_, n_attr) = visitor.x.shape();
|
let (_, n_attr) = visitor.x.shape();
|
||||||
|
|
||||||
let n: usize = visitor.samples.iter().sum();
|
let n: usize = visitor.samples.iter().sum();
|
||||||
@@ -235,18 +250,24 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
|||||||
variables[i] = i;
|
variables[i] = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
let parent_gain = T::from(n).unwrap() * self.nodes[visitor.node].output * self.nodes[visitor.node].output;
|
let parent_gain =
|
||||||
|
T::from(n).unwrap() * self.nodes[visitor.node].output * self.nodes[visitor.node].output;
|
||||||
|
|
||||||
for j in 0..mtry {
|
for j in 0..mtry {
|
||||||
self.find_best_split(visitor, n, sum, parent_gain, variables[j]);
|
self.find_best_split(visitor, n, sum, parent_gain, variables[j]);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.nodes[visitor.node].split_score != Option::None
|
self.nodes[visitor.node].split_score != Option::None
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_best_split<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, n: usize, sum: T, parent_gain: T, j: usize){
|
fn find_best_split<M: Matrix<T>>(
|
||||||
|
&mut self,
|
||||||
|
visitor: &mut NodeVisitor<T, M>,
|
||||||
|
n: usize,
|
||||||
|
sum: T,
|
||||||
|
parent_gain: T,
|
||||||
|
j: usize,
|
||||||
|
) {
|
||||||
let mut true_sum = T::zero();
|
let mut true_sum = T::zero();
|
||||||
let mut true_count = 0;
|
let mut true_count = 0;
|
||||||
let mut prevx = T::nan();
|
let mut prevx = T::nan();
|
||||||
@@ -256,27 +277,36 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
|||||||
if prevx.is_nan() || visitor.x.get(*i, j) == prevx {
|
if prevx.is_nan() || visitor.x.get(*i, j) == prevx {
|
||||||
prevx = visitor.x.get(*i, j);
|
prevx = visitor.x.get(*i, j);
|
||||||
true_count += visitor.samples[*i];
|
true_count += visitor.samples[*i];
|
||||||
true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
|
true_sum =
|
||||||
|
true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let false_count = n - true_count;
|
let false_count = n - true_count;
|
||||||
|
|
||||||
if true_count < self.parameters.min_samples_leaf || false_count < self.parameters.min_samples_leaf {
|
if true_count < self.parameters.min_samples_leaf
|
||||||
|
|| false_count < self.parameters.min_samples_leaf
|
||||||
|
{
|
||||||
prevx = visitor.x.get(*i, j);
|
prevx = visitor.x.get(*i, j);
|
||||||
true_count += visitor.samples[*i];
|
true_count += visitor.samples[*i];
|
||||||
true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
|
true_sum =
|
||||||
|
true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let true_mean = true_sum / T::from(true_count).unwrap();
|
let true_mean = true_sum / T::from(true_count).unwrap();
|
||||||
let false_mean = (sum - true_sum) / T::from(false_count).unwrap();
|
let false_mean = (sum - true_sum) / T::from(false_count).unwrap();
|
||||||
|
|
||||||
let gain = (T::from(true_count).unwrap() * true_mean * true_mean + T::from(false_count).unwrap() * false_mean * false_mean) - parent_gain;
|
let gain = (T::from(true_count).unwrap() * true_mean * true_mean
|
||||||
|
+ T::from(false_count).unwrap() * false_mean * false_mean)
|
||||||
|
- parent_gain;
|
||||||
|
|
||||||
if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() {
|
if self.nodes[visitor.node].split_score == Option::None
|
||||||
|
|| gain > self.nodes[visitor.node].split_score.unwrap()
|
||||||
|
{
|
||||||
self.nodes[visitor.node].split_feature = j;
|
self.nodes[visitor.node].split_feature = j;
|
||||||
self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
|
self.nodes[visitor.node].split_value =
|
||||||
|
Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
|
||||||
self.nodes[visitor.node].split_score = Option::Some(gain);
|
self.nodes[visitor.node].split_score = Option::Some(gain);
|
||||||
visitor.true_child_output = true_mean;
|
visitor.true_child_output = true_mean;
|
||||||
visitor.false_child_output = false_mean;
|
visitor.false_child_output = false_mean;
|
||||||
@@ -287,10 +317,14 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
|||||||
true_count += visitor.samples[*i];
|
true_count += visitor.samples[*i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn split<'a, M: Matrix<T>>(&mut self, mut visitor: NodeVisitor<'a, T, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>) -> bool {
|
fn split<'a, M: Matrix<T>>(
|
||||||
|
&mut self,
|
||||||
|
mut visitor: NodeVisitor<'a, T, M>,
|
||||||
|
mtry: usize,
|
||||||
|
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
|
||||||
|
) -> bool {
|
||||||
let (n, _) = visitor.x.shape();
|
let (n, _) = visitor.x.shape();
|
||||||
let mut tc = 0;
|
let mut tc = 0;
|
||||||
let mut fc = 0;
|
let mut fc = 0;
|
||||||
@@ -298,7 +332,9 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
|||||||
|
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
if visitor.samples[i] > 0 {
|
if visitor.samples[i] > 0 {
|
||||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) {
|
if visitor.x.get(i, self.nodes[visitor.node].split_feature)
|
||||||
|
<= self.nodes[visitor.node].split_value.unwrap_or(T::nan())
|
||||||
|
{
|
||||||
true_samples[i] = visitor.samples[i];
|
true_samples[i] = visitor.samples[i];
|
||||||
tc += true_samples[i];
|
tc += true_samples[i];
|
||||||
visitor.samples[i] = 0;
|
visitor.samples[i] = 0;
|
||||||
@@ -316,22 +352,38 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let true_child_idx = self.nodes.len();
|
let true_child_idx = self.nodes.len();
|
||||||
self.nodes.push(Node::new(true_child_idx, visitor.true_child_output));
|
self.nodes
|
||||||
|
.push(Node::new(true_child_idx, visitor.true_child_output));
|
||||||
let false_child_idx = self.nodes.len();
|
let false_child_idx = self.nodes.len();
|
||||||
self.nodes.push(Node::new(false_child_idx, visitor.false_child_output));
|
self.nodes
|
||||||
|
.push(Node::new(false_child_idx, visitor.false_child_output));
|
||||||
|
|
||||||
self.nodes[visitor.node].true_child = Some(true_child_idx);
|
self.nodes[visitor.node].true_child = Some(true_child_idx);
|
||||||
self.nodes[visitor.node].false_child = Some(false_child_idx);
|
self.nodes[visitor.node].false_child = Some(false_child_idx);
|
||||||
|
|
||||||
self.depth = u16::max(self.depth, visitor.level + 1);
|
self.depth = u16::max(self.depth, visitor.level + 1);
|
||||||
|
|
||||||
let mut true_visitor = NodeVisitor::<T, M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
let mut true_visitor = NodeVisitor::<T, M>::new(
|
||||||
|
true_child_idx,
|
||||||
|
true_samples,
|
||||||
|
visitor.order,
|
||||||
|
visitor.x,
|
||||||
|
visitor.y,
|
||||||
|
visitor.level + 1,
|
||||||
|
);
|
||||||
|
|
||||||
if self.find_best_cutoff(&mut true_visitor, mtry) {
|
if self.find_best_cutoff(&mut true_visitor, mtry) {
|
||||||
visitor_queue.push_back(true_visitor);
|
visitor_queue.push_back(true_visitor);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut false_visitor = NodeVisitor::<T, M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
let mut false_visitor = NodeVisitor::<T, M>::new(
|
||||||
|
false_child_idx,
|
||||||
|
visitor.samples,
|
||||||
|
visitor.order,
|
||||||
|
visitor.x,
|
||||||
|
visitor.y,
|
||||||
|
visitor.level + 1,
|
||||||
|
);
|
||||||
|
|
||||||
if self.find_best_cutoff(&mut false_visitor, mtry) {
|
if self.find_best_cutoff(&mut false_visitor, mtry) {
|
||||||
visitor_queue.push_back(false_visitor);
|
visitor_queue.push_back(false_visitor);
|
||||||
@@ -339,7 +391,6 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
|||||||
|
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -349,7 +400,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_longley() {
|
fn fit_longley() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
@@ -366,8 +416,12 @@ mod tests {
|
|||||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||||
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||||
let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
]);
|
||||||
|
let y: Vec<f64> = vec![
|
||||||
|
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||||
|
114.2, 115.7, 116.9,
|
||||||
|
];
|
||||||
|
|
||||||
let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x);
|
let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x);
|
||||||
|
|
||||||
@@ -375,20 +429,43 @@ mod tests {
|
|||||||
assert!((y_hat[i] - y[i]).abs() < 0.1);
|
assert!((y_hat[i] - y[i]).abs() < 0.1);
|
||||||
}
|
}
|
||||||
|
|
||||||
let expected_y = vec![87.3, 87.3, 87.3, 87.3, 98.9, 98.9, 98.9, 98.9, 98.9, 107.9, 107.9, 107.9, 114.85, 114.85, 114.85, 114.85];
|
let expected_y = vec![
|
||||||
let y_hat = DecisionTreeRegressor::fit(&x, &y, DecisionTreeRegressorParameters{max_depth: Option::None, min_samples_leaf: 2, min_samples_split: 6}).predict(&x);
|
87.3, 87.3, 87.3, 87.3, 98.9, 98.9, 98.9, 98.9, 98.9, 107.9, 107.9, 107.9, 114.85,
|
||||||
|
114.85, 114.85, 114.85,
|
||||||
|
];
|
||||||
|
let y_hat = DecisionTreeRegressor::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
DecisionTreeRegressorParameters {
|
||||||
|
max_depth: Option::None,
|
||||||
|
min_samples_leaf: 2,
|
||||||
|
min_samples_split: 6,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.predict(&x);
|
||||||
|
|
||||||
for i in 0..y_hat.len() {
|
for i in 0..y_hat.len() {
|
||||||
assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
|
assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
|
||||||
}
|
}
|
||||||
|
|
||||||
let expected_y = vec![83.0, 88.35, 88.35, 89.5, 97.15, 97.15, 99.5, 99.5, 101.2, 104.6, 109.6, 109.6, 113.4, 113.4, 116.30, 116.30];
|
let expected_y = vec![
|
||||||
let y_hat = DecisionTreeRegressor::fit(&x, &y, DecisionTreeRegressorParameters{max_depth: Option::None, min_samples_leaf: 1, min_samples_split: 3}).predict(&x);
|
83.0, 88.35, 88.35, 89.5, 97.15, 97.15, 99.5, 99.5, 101.2, 104.6, 109.6, 109.6, 113.4,
|
||||||
|
113.4, 116.30, 116.30,
|
||||||
|
];
|
||||||
|
let y_hat = DecisionTreeRegressor::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
DecisionTreeRegressorParameters {
|
||||||
|
max_depth: Option::None,
|
||||||
|
min_samples_leaf: 1,
|
||||||
|
min_samples_split: 3,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.predict(&x);
|
||||||
|
|
||||||
for i in 0..y_hat.len() {
|
for i in 0..y_hat.len() {
|
||||||
assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
|
assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -409,15 +486,18 @@ mod tests {
|
|||||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||||
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||||
let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
]);
|
||||||
|
let y: Vec<f64> = vec![
|
||||||
|
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||||
|
114.2, 115.7, 116.9,
|
||||||
|
];
|
||||||
|
|
||||||
let tree = DecisionTreeRegressor::fit(&x, &y, Default::default());
|
let tree = DecisionTreeRegressor::fit(&x, &y, Default::default());
|
||||||
|
|
||||||
let deserialized_tree: DecisionTreeRegressor<f64> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
let deserialized_tree: DecisionTreeRegressor<f64> =
|
||||||
|
bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
||||||
|
|
||||||
assert_eq!(tree, deserialized_tree);
|
assert_eq!(tree, deserialized_tree);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
+1
-1
@@ -1,2 +1,2 @@
|
|||||||
pub mod decision_tree_regressor;
|
|
||||||
pub mod decision_tree_classifier;
|
pub mod decision_tree_classifier;
|
||||||
|
pub mod decision_tree_regressor;
|
||||||
|
|||||||
Reference in New Issue
Block a user