fix: cargo fmt

This commit is contained in:
Volodymyr Orlov
2020-06-05 17:52:03 -07:00
parent 685be04488
commit a2784d6345
52 changed files with 3342 additions and 2829 deletions
+5 -3
View File
@@ -2,15 +2,17 @@
extern crate criterion; extern crate criterion;
extern crate smartcore; extern crate smartcore;
use criterion::Criterion;
use criterion::black_box; use criterion::black_box;
use criterion::Criterion;
use smartcore::math::distance::*; use smartcore::math::distance::*;
/// Registers the "Euclidean Distance" benchmark, which repeatedly measures the
/// distance between a three-element vector and itself.
fn criterion_benchmark(c: &mut Criterion) {
    let point = vec![1., 2., 3.];
    c.bench_function("Euclidean Distance", move |bencher| {
        bencher.iter(|| {
            // The metric is rebuilt on every iteration, so its construction is
            // part of what gets timed.
            let metric = Distances::euclidian();
            metric.distance(black_box(&point), black_box(&point))
        })
    });
}
criterion_group!(benches, criterion_benchmark); criterion_group!(benches, criterion_benchmark);
criterion_main!(benches); criterion_main!(benches);
+1 -1
View File
@@ -1,2 +1,2 @@
pub mod neighbour;
pub mod sort; pub mod sort;
pub mod neighbour;
+95 -50
View File
@@ -1,18 +1,18 @@
use std::fmt::Debug; use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::distance::euclidian::*; use crate::math::distance::euclidian::*;
use crate::math::num::FloatExt;
#[derive(Debug)] #[derive(Debug)]
pub struct BBDTree<T: FloatExt> { pub struct BBDTree<T: FloatExt> {
nodes: Vec<BBDTreeNode<T>>, nodes: Vec<BBDTreeNode<T>>,
index: Vec<usize>, index: Vec<usize>,
root: usize root: usize,
} }
#[derive(Debug)] #[derive(Debug)]
struct BBDTreeNode<T: FloatExt> { struct BBDTreeNode<T: FloatExt> {
count: usize, count: usize,
index: usize, index: usize,
center: Vec<T>, center: Vec<T>,
@@ -20,7 +20,7 @@ struct BBDTreeNode<T: FloatExt> {
sum: Vec<T>, sum: Vec<T>,
cost: T, cost: T,
lower: Option<usize>, lower: Option<usize>,
upper: Option<usize> upper: Option<usize>,
} }
impl<T: FloatExt> BBDTreeNode<T> { impl<T: FloatExt> BBDTreeNode<T> {
@@ -33,7 +33,7 @@ impl<T: FloatExt> BBDTreeNode<T> {
sum: vec![T::zero(); d], sum: vec![T::zero(); d],
cost: T::zero(), cost: T::zero(),
lower: Option::None, lower: Option::None,
upper: Option::None upper: Option::None,
} }
} }
} }
@@ -49,10 +49,10 @@ impl<T: FloatExt> BBDTree<T> {
index[i] = i; index[i] = i;
} }
let mut tree = BBDTree{ let mut tree = BBDTree {
nodes: nodes, nodes: nodes,
index: index, index: index,
root: 0 root: 0,
}; };
let root = tree.build_node(data, 0, n); let root = tree.build_node(data, 0, n);
@@ -60,29 +60,54 @@ impl<T: FloatExt> BBDTree<T> {
tree.root = root; tree.root = root;
tree tree
} }
/// Assigns every point stored in the tree to one of `centroids`.
///
/// All entries of `counts` and the first `centroids.len()` rows of `sums` are
/// reset to zero before the tree is traversed; `filter` then fills `sums`,
/// `counts` and `membership`. The value returned is whatever `filter`
/// accumulates for the root node.
///
/// Panics if `sums` has fewer rows than there are centroids.
pub(in crate) fn clustering(
    &self,
    centroids: &Vec<Vec<T>>,
    sums: &mut Vec<Vec<T>>,
    counts: &mut Vec<usize>,
    membership: &mut Vec<usize>,
) -> T {
    let k = centroids.len();

    for count in counts.iter_mut() {
        *count = 0;
    }

    for sum in sums[..k].iter_mut() {
        for component in sum.iter_mut() {
            *component = T::zero();
        }
    }

    // Initially every centroid is a candidate for the root node.
    let candidates: Vec<usize> = (0..k).collect();

    self.filter(self.root, centroids, &candidates, k, sums, counts, membership)
}
fn filter(&self, node: usize, centroids: &Vec<Vec<T>>, candidates: &Vec<usize>, k: usize, sums: &mut Vec<Vec<T>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> T{ fn filter(
&self,
node: usize,
centroids: &Vec<Vec<T>>,
candidates: &Vec<usize>,
k: usize,
sums: &mut Vec<Vec<T>>,
counts: &mut Vec<usize>,
membership: &mut Vec<usize>,
) -> T {
let d = centroids[0].len(); let d = centroids[0].len();
// Determine which mean the node mean is closest to // Determine which mean the node mean is closest to
let mut min_dist = Euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[0]]); let mut min_dist =
Euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[0]]);
let mut closest = candidates[0]; let mut closest = candidates[0];
for i in 1..k { for i in 1..k {
let dist = Euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[i]]); let dist =
Euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[i]]);
if dist < min_dist { if dist < min_dist {
min_dist = dist; min_dist = dist;
closest = candidates[i]; closest = candidates[i];
@@ -92,11 +117,17 @@ impl<T: FloatExt> BBDTree<T> {
// If this is a non-leaf node, recurse if necessary // If this is a non-leaf node, recurse if necessary
if !self.nodes[node].lower.is_none() { if !self.nodes[node].lower.is_none() {
// Build the new list of candidates // Build the new list of candidates
let mut new_candidates = vec![0;k]; let mut new_candidates = vec![0; k];
let mut newk = 0; let mut newk = 0;
for i in 0..k { for i in 0..k {
if !BBDTree::prune(&self.nodes[node].center, &self.nodes[node].radius, &centroids, closest, candidates[i]) { if !BBDTree::prune(
&self.nodes[node].center,
&self.nodes[node].radius,
&centroids,
closest,
candidates[i],
) {
new_candidates[newk] = candidates[i]; new_candidates[newk] = candidates[i];
newk += 1; newk += 1;
} }
@@ -104,8 +135,23 @@ impl<T: FloatExt> BBDTree<T> {
// Recurse if there's at least two // Recurse if there's at least two
if newk > 1 { if newk > 1 {
let result = self.filter(self.nodes[node].lower.unwrap(), centroids, &mut new_candidates, newk, sums, counts, membership) + let result = self.filter(
self.filter(self.nodes[node].upper.unwrap(), centroids, &mut new_candidates, newk, sums, counts, membership); self.nodes[node].lower.unwrap(),
centroids,
&mut new_candidates,
newk,
sums,
counts,
membership,
) + self.filter(
self.nodes[node].upper.unwrap(),
centroids,
&mut new_candidates,
newk,
sums,
counts,
membership,
);
return result; return result;
} }
} }
@@ -116,17 +162,22 @@ impl<T: FloatExt> BBDTree<T> {
} }
counts[closest] += self.nodes[node].count; counts[closest] += self.nodes[node].count;
let last = self.nodes[node].index + self.nodes[node].count; let last = self.nodes[node].index + self.nodes[node].count;
for i in self.nodes[node].index..last { for i in self.nodes[node].index..last {
membership[self.index[i]] = closest; membership[self.index[i]] = closest;
} }
BBDTree::node_cost(&self.nodes[node], &centroids[closest]) BBDTree::node_cost(&self.nodes[node], &centroids[closest])
} }
fn prune(center: &Vec<T>, radius: &Vec<T>, centroids: &Vec<Vec<T>>, best_index: usize, test_index: usize) -> bool { fn prune(
center: &Vec<T>,
radius: &Vec<T>,
centroids: &Vec<Vec<T>>,
best_index: usize,
test_index: usize,
) -> bool {
if best_index == test_index { if best_index == test_index {
return false; return false;
} }
@@ -148,7 +199,7 @@ impl<T: FloatExt> BBDTree<T> {
} }
return lhs >= T::two() * rhs; return lhs >= T::two() * rhs;
} }
fn build_node<M: Matrix<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize { fn build_node<M: Matrix<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize {
let (_, d) = data.shape(); let (_, d) = data.shape();
@@ -165,8 +216,8 @@ impl<T: FloatExt> BBDTree<T> {
let mut upper_bound = vec![T::zero(); d]; let mut upper_bound = vec![T::zero(); d];
for i in 0..d { for i in 0..d {
lower_bound[i] = data.get(self.index[begin],i); lower_bound[i] = data.get(self.index[begin], i);
upper_bound[i] = data.get(self.index[begin],i); upper_bound[i] = data.get(self.index[begin], i);
} }
for i in begin..end { for i in begin..end {
@@ -200,7 +251,7 @@ impl<T: FloatExt> BBDTree<T> {
for i in 0..d { for i in 0..d {
node.sum[i] = data.get(self.index[begin], i); node.sum[i] = data.get(self.index[begin], i);
} }
if end > begin + 1 { if end > begin + 1 {
let len = end - begin; let len = end - begin;
for i in 0..d { for i in 0..d {
@@ -247,7 +298,8 @@ impl<T: FloatExt> BBDTree<T> {
// Calculate the new sum and opt cost // Calculate the new sum and opt cost
for i in 0..d { for i in 0..d {
node.sum[i] = self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i]; node.sum[i] =
self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i];
} }
let mut mean = vec![T::zero(); d]; let mut mean = vec![T::zero(); d];
@@ -255,7 +307,8 @@ impl<T: FloatExt> BBDTree<T> {
mean[i] = node.sum[i] / T::from(node.count).unwrap(); mean[i] = node.sum[i] / T::from(node.count).unwrap();
} }
node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean) + BBDTree::node_cost(&self.nodes[node.upper.unwrap()], &mean); node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean)
+ BBDTree::node_cost(&self.nodes[node.upper.unwrap()], &mean);
self.add_node(node) self.add_node(node)
} }
@@ -270,7 +323,7 @@ impl<T: FloatExt> BBDTree<T> {
node.cost + T::from(node.count).unwrap() * scatter node.cost + T::from(node.count).unwrap() * scatter
} }
fn add_node(&mut self, new_node: BBDTreeNode<T>) -> usize{ fn add_node(&mut self, new_node: BBDTreeNode<T>) -> usize {
let idx = self.nodes.len(); let idx = self.nodes.len();
self.nodes.push(new_node); self.nodes.push(new_node);
idx idx
@@ -279,12 +332,11 @@ impl<T: FloatExt> BBDTree<T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test] #[test]
fn fit_predict_iris() { fn fit_predict_iris() {
let data = DenseMatrix::from_array(&[ let data = DenseMatrix::from_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2],
@@ -305,30 +357,23 @@ mod tests {
&[6.3, 3.3, 4.7, 1.6], &[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0], &[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3], &[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4]]); &[5.2, 2.7, 3.9, 1.4],
]);
let tree = BBDTree::new(&data); let tree = BBDTree::new(&data);
let centroids = vec![ let centroids = vec![vec![4.86, 3.22, 1.61, 0.29], vec![6.23, 2.92, 4.48, 1.42]];
vec![4.86, 3.22, 1.61, 0.29],
vec![6.23, 2.92, 4.48, 1.42]
];
let mut sums = vec![ let mut sums = vec![vec![0f64; 4], vec![0f64; 4]];
vec![0f64; 4],
vec![0f64; 4]
];
let mut counts = vec![11, 9]; let mut counts = vec![11, 9];
let mut membership = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1]; let mut membership = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1];
let dist = tree.clustering(&centroids, &mut sums, &mut counts, &mut membership); let dist = tree.clustering(&centroids, &mut sums, &mut counts, &mut membership);
assert!((dist - 10.68).abs() < 1e-2); assert!((dist - 10.68).abs() < 1e-2);
assert!((sums[0][0] - 48.6).abs() < 1e-2); assert!((sums[0][0] - 48.6).abs() < 1e-2);
assert!((sums[1][3] - 13.8).abs() < 1e-2); assert!((sums[1][3] - 13.8).abs() < 1e-2);
assert_eq!(membership[17], 1); assert_eq!(membership[17], 1);
} }
}
}
+170 -113
View File
@@ -1,116 +1,126 @@
use std::collections::{HashMap, HashSet};
use std::iter::FromIterator;
use std::fmt::Debug;
use core::hash::{Hash, Hasher}; use core::hash::{Hash, Hasher};
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::iter::FromIterator;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt;
use crate::math::distance::Distance;
use crate::algorithm::sort::heap_select::HeapSelect; use crate::algorithm::sort::heap_select::HeapSelect;
use crate::math::distance::Distance;
use crate::math::num::FloatExt;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct CoverTree<T, F: FloatExt, D: Distance<T, F>> pub struct CoverTree<T, F: FloatExt, D: Distance<T, F>> {
{
base: F, base: F,
max_level: i8, max_level: i8,
min_level: i8, min_level: i8,
distance: D, distance: D,
nodes: Vec<Node<T>> nodes: Vec<Node<T>>,
} }
impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D> impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D> {
{
pub fn new(mut data: Vec<T>, distance: D) -> CoverTree<T, F, D> { pub fn new(mut data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
let mut tree = CoverTree { let mut tree = CoverTree {
base: F::two(), base: F::two(),
max_level: 100, max_level: 100,
min_level: 100, min_level: 100,
distance: distance, distance: distance,
nodes: Vec::new() nodes: Vec::new(),
}; };
let p = tree.new_node(None, data.remove(0)); let p = tree.new_node(None, data.remove(0));
tree.construct(p, data, Vec::new(), 10); tree.construct(p, data, Vec::new(), 10);
tree tree
} }
pub fn insert(&mut self, p: T) { pub fn insert(&mut self, p: T) {
if self.nodes.is_empty(){ if self.nodes.is_empty() {
self.new_node(None, p); self.new_node(None, p);
} else { } else {
let mut parent: Option<NodeId> = Option::None; let mut parent: Option<NodeId> = Option::None;
let mut p_i = 0; let mut p_i = 0;
let mut qi_p_ds = vec!((self.root(), self.distance.distance(&p, &self.root().data))); let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))];
let mut i = self.max_level; let mut i = self.max_level;
loop { loop {
let i_d = self.base.powf(F::from(i).unwrap()); let i_d = self.base.powf(F::from(i).unwrap());
let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i); let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
let d_p_q = self.min_by_distance(&q_p_ds); let d_p_q = self.min_by_distance(&q_p_ds);
if d_p_q < F::epsilon() { if d_p_q < F::epsilon() {
return return;
} else if d_p_q > i_d { } else if d_p_q > i_d {
break; break;
} }
if self.min_by_distance(&qi_p_ds) <= self.base.powf(F::from(i).unwrap()){ if self.min_by_distance(&qi_p_ds) <= self.base.powf(F::from(i).unwrap()) {
parent = q_p_ds.iter().find(|(_, d)| d <= &i_d).map(|(n, _)| n.index); parent = q_p_ds.iter().find(|(_, d)| d <= &i_d).map(|(n, _)| n.index);
p_i = i; p_i = i;
} }
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &i_d).collect(); qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &i_d).collect();
i -= 1; i -= 1;
} }
let new_node = self.new_node(parent, p); let new_node = self.new_node(parent, p);
self.add_child(parent.unwrap(), new_node, p_i); self.add_child(parent.unwrap(), new_node, p_i);
self.min_level = i8::min(self.min_level, p_i-1); self.min_level = i8::min(self.min_level, p_i - 1);
} }
} }
pub fn new_node(&mut self, parent: Option<NodeId>, data: T) -> NodeId { pub fn new_node(&mut self, parent: Option<NodeId>, data: T) -> NodeId {
let next_index = self.nodes.len(); let next_index = self.nodes.len();
let node_id = NodeId { index: next_index }; let node_id = NodeId { index: next_index };
self.nodes.push( self.nodes.push(Node {
Node { index: node_id,
index: node_id, data: data,
data: data, parent: parent,
parent: parent, children: HashMap::new(),
children: HashMap::new() });
});
node_id node_id
}
pub fn find(&self, p: &T, k: usize) -> Vec<usize>{
let mut qi_p_ds = vec!((self.root(), self.distance.distance(&p, &self.root().data)));
for i in (self.min_level..self.max_level+1).rev() {
let i_d = self.base.powf(F::from(i).unwrap());
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &(d_p_q + i_d)).collect();
}
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
qi_p_ds[..usize::min(qi_p_ds.len(), k)].iter().map(|(n, _)| n.index.index).collect()
} }
fn split(&self, p_id: NodeId, r: F, s1: &mut Vec<T>, s2: Option<&mut Vec<T>>) -> (Vec<T>, Vec<T>){ pub fn find(&self, p: &T, k: usize) -> Vec<usize> {
let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))];
let mut my_near = (Vec::new(), Vec::new()); for i in (self.min_level..self.max_level + 1).rev() {
let i_d = self.base.powf(F::from(i).unwrap());
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
qi_p_ds = q_p_ds
.into_iter()
.filter(|(_, d)| d <= &(d_p_q + i_d))
.collect();
}
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
qi_p_ds[..usize::min(qi_p_ds.len(), k)]
.iter()
.map(|(n, _)| n.index.index)
.collect()
}
fn split(
&self,
p_id: NodeId,
r: F,
s1: &mut Vec<T>,
s2: Option<&mut Vec<T>>,
) -> (Vec<T>, Vec<T>) {
let mut my_near = (Vec::new(), Vec::new());
my_near = self.split_remove_s(p_id, r, s1, my_near); my_near = self.split_remove_s(p_id, r, s1, my_near);
for s in s2 { for s in s2 {
my_near = self.split_remove_s(p_id, r, s, my_near); my_near = self.split_remove_s(p_id, r, s, my_near);
} }
return my_near return my_near;
} }
fn split_remove_s(&self, p_id: NodeId, r: F, s: &mut Vec<T>, mut my_near: (Vec<T>, Vec<T>)) -> (Vec<T>, Vec<T>){ fn split_remove_s(
&self,
p_id: NodeId,
r: F,
s: &mut Vec<T>,
mut my_near: (Vec<T>, Vec<T>),
) -> (Vec<T>, Vec<T>) {
if s.len() > 0 { if s.len() > 0 {
let p = &self.nodes.get(p_id.index).unwrap().data; let p = &self.nodes.get(p_id.index).unwrap().data;
let mut i = 0; let mut i = 0;
@@ -118,61 +128,84 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
let d = self.distance.distance(p, &s[i]); let d = self.distance.distance(p, &s[i]);
if d <= r { if d <= r {
my_near.0.push(s.remove(i)); my_near.0.push(s.remove(i));
} else if d > r && d <= F::two() * r{ } else if d > r && d <= F::two() * r {
my_near.1.push(s.remove(i)); my_near.1.push(s.remove(i));
} else { } else {
i += 1; i += 1;
} }
} }
} }
return my_near return my_near;
} }
fn construct<'b>(&mut self, p: NodeId, mut near: Vec<T>, mut far: Vec<T>, i: i8) -> (NodeId, Vec<T>) { fn construct<'b>(
&mut self,
if near.len() < 1{ p: NodeId,
self.min_level = std::cmp::min(self.min_level, i); mut near: Vec<T>,
return (p, far); mut far: Vec<T>,
i: i8,
) -> (NodeId, Vec<T>) {
if near.len() < 1 {
self.min_level = std::cmp::min(self.min_level, i);
return (p, far);
} else { } else {
let (my, n) = self.split(p, self.base.powf(F::from(i-1).unwrap()), &mut near, None); let (my, n) = self.split(p, self.base.powf(F::from(i - 1).unwrap()), &mut near, None);
let (pi, mut near) = self.construct(p, my, n, i-1); let (pi, mut near) = self.construct(p, my, n, i - 1);
while near.len() > 0 { while near.len() > 0 {
let q_data = near.remove(0); let q_data = near.remove(0);
let nn = self.new_node(Some(p), q_data); let nn = self.new_node(Some(p), q_data);
let (my, n) = self.split(nn, self.base.powf(F::from(i-1).unwrap()), &mut near, Some(&mut far)); let (my, n) = self.split(
let (child, mut unused) = self.construct(nn, my, n, i-1); nn,
self.base.powf(F::from(i - 1).unwrap()),
&mut near,
Some(&mut far),
);
let (child, mut unused) = self.construct(nn, my, n, i - 1);
self.add_child(pi, child, i); self.add_child(pi, child, i);
let new_near_far = self.split(p, self.base.powf(F::from(i).unwrap()), &mut unused, None); let new_near_far =
self.split(p, self.base.powf(F::from(i).unwrap()), &mut unused, None);
near.extend(new_near_far.0); near.extend(new_near_far.0);
far.extend(new_near_far.1); far.extend(new_near_far.1);
} }
self.min_level = std::cmp::min(self.min_level, i); self.min_level = std::cmp::min(self.min_level, i);
return (pi, far); return (pi, far);
} }
} }
/// Records `node` as the child of `parent` at level `i`.
///
/// Children are keyed by level, so an existing child at the same level is
/// replaced. Panics if `parent` does not refer to a stored node.
fn add_child(&mut self, parent: NodeId, node: NodeId, i: i8) {
    let parent_node = self.nodes.get_mut(parent.index).unwrap();
    parent_node.children.insert(i, node);
}
/// Returns the root, i.e. the first stored node. Panics if the tree is empty.
fn root(&self) -> &Node<T> {
    let first = self.nodes.first();
    first.unwrap()
}
fn get_children_dist<'b>(&'b self, p: &T, qi_p_ds: &Vec<(&'b Node<T>, F)>, i: i8) -> Vec<(&'b Node<T>, F)> { fn get_children_dist<'b>(
&'b self,
p: &T,
qi_p_ds: &Vec<(&'b Node<T>, F)>,
i: i8,
) -> Vec<(&'b Node<T>, F)> {
let mut children = Vec::<(&'b Node<T>, F)>::new(); let mut children = Vec::<(&'b Node<T>, F)>::new();
children.extend(qi_p_ds.iter().cloned()); children.extend(qi_p_ds.iter().cloned());
let q: Vec<&Node<T>> = qi_p_ds.iter().flat_map(|(n, _)| self.get_child(n, i)).collect(); let q: Vec<&Node<T>> = qi_p_ds
.iter()
.flat_map(|(n, _)| self.get_child(n, i))
.collect();
children.extend(q.into_iter().map(|n| (n, self.distance.distance(&n.data, &p)))); children.extend(
q.into_iter()
.map(|n| (n, self.distance.distance(&n.data, &p))),
);
children children
} }
fn min_k_by_distance(&self, q_p_ds: &mut Vec<(&Node<T>, F)>, k: usize) -> F { fn min_k_by_distance(&self, q_p_ds: &mut Vec<(&Node<T>, F)>, k: usize) -> F {
@@ -185,18 +218,27 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
} }
/// Returns the smallest distance found in `q_p_ds`.
///
/// Panics if the list is empty or if two distances cannot be ordered.
fn min_by_distance(&self, q_p_ds: &Vec<(&Node<T>, F)>) -> F {
    q_p_ds
        .iter()
        .map(|(_, d)| *d)
        .min_by(|a, b| a.partial_cmp(b).unwrap())
        .unwrap()
}
/// Looks up the child of `node` registered at level `i`, if there is one.
fn get_child(&self, node: &Node<T>, i: i8) -> Option<&Node<T>> {
    let child_id = node.children.get(&i)?;
    self.nodes.get(child_id.index)
}
#[allow(dead_code)] #[allow(dead_code)]
fn check_invariant(&self, invariant: fn(&CoverTree<T, F, D>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) { fn check_invariant(
&self,
invariant: fn(&CoverTree<T, F, D>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> (),
) {
let mut current_nodes: Vec<&Node<T>> = Vec::new(); let mut current_nodes: Vec<&Node<T>> = Vec::new();
current_nodes.push(self.root()); current_nodes.push(self.root());
for i in (self.min_level..self.max_level+1).rev() { for i in (self.min_level..self.max_level + 1).rev() {
let mut next_nodes: Vec<&Node<T>> = Vec::new(); let mut next_nodes: Vec<&Node<T>> = Vec::new();
next_nodes.extend(current_nodes.iter()); next_nodes.extend(current_nodes.iter());
next_nodes.extend(current_nodes.iter().flat_map(|n| self.get_child(n, i))); next_nodes.extend(current_nodes.iter().flat_map(|n| self.get_child(n, i)));
@@ -206,39 +248,55 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
} }
#[allow(dead_code)] #[allow(dead_code)]
fn nesting_invariant(_: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) { fn nesting_invariant(
_: &CoverTree<T, F, D>,
nodes: &Vec<&Node<T>>,
next_nodes: &Vec<&Node<T>>,
_: i8,
) {
let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n)); let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n));
let next_nodes_set: HashSet<&Node<T>> = HashSet::from_iter(next_nodes.into_iter().map(|n| *n)); let next_nodes_set: HashSet<&Node<T>> =
HashSet::from_iter(next_nodes.into_iter().map(|n| *n));
for n in nodes_set.iter() { for n in nodes_set.iter() {
assert!(next_nodes_set.contains(n), "Nesting invariant of the cover tree is not satisfied. Set of nodes [{:?}] is not a subset of [{:?}]", nodes_set, next_nodes_set); assert!(next_nodes_set.contains(n), "Nesting invariant of the cover tree is not satisfied. Set of nodes [{:?}] is not a subset of [{:?}]", nodes_set, next_nodes_set);
} }
} }
#[allow(dead_code)] #[allow(dead_code)]
fn covering_tree(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) { fn covering_tree(
tree: &CoverTree<T, F, D>,
nodes: &Vec<&Node<T>>,
next_nodes: &Vec<&Node<T>>,
i: i8,
) {
let mut p_selected: Vec<&Node<T>> = Vec::new(); let mut p_selected: Vec<&Node<T>> = Vec::new();
for p in next_nodes { for p in next_nodes {
for q in nodes { for q in nodes {
if tree.distance.distance(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) { if tree.distance.distance(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) {
p_selected.push(*p); p_selected.push(*p);
} }
} }
let c = p_selected.iter().filter(|q| p.parent.map(|p| q.index == p).unwrap_or(false)).count(); let c = p_selected
.iter()
.filter(|q| p.parent.map(|p| q.index == p).unwrap_or(false))
.count();
assert!(c <= 1); assert!(c <= 1);
} }
} }
#[allow(dead_code)] #[allow(dead_code)]
fn separation(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) { fn separation(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
for p in nodes { for p in nodes {
for q in nodes { for q in nodes {
if p != q { if p != q {
assert!(tree.distance.distance(&p.data, &q.data) > tree.base.powf(F::from(i).unwrap())); assert!(
} tree.distance.distance(&p.data, &q.data)
} > tree.base.powf(F::from(i).unwrap())
} );
}
}
}
} }
} }
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
@@ -251,7 +309,7 @@ struct Node<T> {
index: NodeId, index: NodeId,
data: T, data: T,
children: HashMap<i8, NodeId>, children: HashMap<i8, NodeId>,
parent: Option<NodeId> parent: Option<NodeId>,
} }
impl<T> PartialEq for Node<T> { impl<T> PartialEq for Node<T> {
@@ -277,22 +335,22 @@ mod tests {
use super::*; use super::*;
struct SimpleDistance{} struct SimpleDistance {}
impl Distance<i32, f64> for SimpleDistance { impl Distance<i32, f64> for SimpleDistance {
fn distance(&self, a: &i32, b: &i32) -> f64 { fn distance(&self, a: &i32, b: &i32) -> f64 {
(a - b).abs() as f64 (a - b).abs() as f64
} }
} }
#[test] #[test]
fn cover_tree_test() { fn cover_tree_test() {
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9); let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
let mut tree = CoverTree::new(data, SimpleDistance{}); let mut tree = CoverTree::new(data, SimpleDistance {});
for d in vec!(10, 11, 12, 13, 14, 15, 16, 17, 18, 19) { for d in vec![10, 11, 12, 13, 14, 15, 16, 17, 18, 19] {
tree.insert(d); tree.insert(d);
} }
let mut nearest_3_to_5 = tree.find(&5, 3); let mut nearest_3_to_5 = tree.find(&5, 3);
nearest_3_to_5.sort(); nearest_3_to_5.sort();
@@ -307,13 +365,12 @@ mod tests {
} }
/// Builds a tree over the integers 1 through 9 and runs all three invariant
/// checks against it.
#[test]
fn test_invariants() {
    let data: Vec<i32> = (1..=9).collect();
    let tree = CoverTree::new(data, SimpleDistance {});

    tree.check_invariant(CoverTree::nesting_invariant);
    tree.check_invariant(CoverTree::covering_tree);
    tree.check_invariant(CoverTree::separation);
}
}
}
+40 -35
View File
@@ -1,53 +1,52 @@
use serde::{Deserialize, Serialize};
use std::cmp::{Ordering, PartialOrd}; use std::cmp::{Ordering, PartialOrd};
use std::marker::PhantomData; use std::marker::PhantomData;
use serde::{Serialize, Deserialize};
use crate::math::num::FloatExt;
use crate::math::distance::Distance;
use crate::algorithm::sort::heap_select::HeapSelect; use crate::algorithm::sort::heap_select::HeapSelect;
use crate::math::distance::Distance;
use crate::math::num::FloatExt;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct LinearKNNSearch<T, F: FloatExt, D: Distance<T, F>> { pub struct LinearKNNSearch<T, F: FloatExt, D: Distance<T, F>> {
distance: D, distance: D,
data: Vec<T>, data: Vec<T>,
f: PhantomData<F> f: PhantomData<F>,
} }
impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> { impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
pub fn new(data: Vec<T>, distance: D) -> LinearKNNSearch<T, F, D>{ pub fn new(data: Vec<T>, distance: D) -> LinearKNNSearch<T, F, D> {
LinearKNNSearch{ LinearKNNSearch {
data: data, data: data,
distance: distance, distance: distance,
f: PhantomData f: PhantomData,
} }
} }
pub fn find(&self, from: &T, k: usize) -> Vec<usize> { pub fn find(&self, from: &T, k: usize) -> Vec<usize> {
if k < 1 || k > self.data.len() { if k < 1 || k > self.data.len() {
panic!("k should be >= 1 and <= length(data)"); panic!("k should be >= 1 and <= length(data)");
} }
let mut heap = HeapSelect::<KNNPoint<F>>::with_capacity(k); let mut heap = HeapSelect::<KNNPoint<F>>::with_capacity(k);
for _ in 0..k { for _ in 0..k {
heap.add(KNNPoint{ heap.add(KNNPoint {
distance: F::infinity(), distance: F::infinity(),
index: None index: None,
}); });
} }
for i in 0..self.data.len() { for i in 0..self.data.len() {
let d = self.distance.distance(&from, &self.data[i]); let d = self.distance.distance(&from, &self.data[i]);
let datum = heap.peek_mut(); let datum = heap.peek_mut();
if d < datum.distance { if d < datum.distance {
datum.distance = d; datum.distance = d;
datum.index = Some(i); datum.index = Some(i);
heap.heapify(); heap.heapify();
} }
} }
heap.sort(); heap.sort();
heap.get().into_iter().flat_map(|x| x.index).collect() heap.get().into_iter().flat_map(|x| x.index).collect()
} }
@@ -56,7 +55,7 @@ impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
#[derive(Debug)] #[derive(Debug)]
struct KNNPoint<F: FloatExt> { struct KNNPoint<F: FloatExt> {
distance: F, distance: F,
index: Option<usize> index: Option<usize>,
} }
impl<F: FloatExt> PartialOrd for KNNPoint<F> { impl<F: FloatExt> PartialOrd for KNNPoint<F> {
@@ -74,27 +73,33 @@ impl<F: FloatExt> PartialEq for KNNPoint<F> {
impl<F: FloatExt> Eq for KNNPoint<F> {} impl<F: FloatExt> Eq for KNNPoint<F> {}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::math::distance::Distances; use crate::math::distance::Distances;
struct SimpleDistance{} struct SimpleDistance {}
impl Distance<i32, f64> for SimpleDistance { impl Distance<i32, f64> for SimpleDistance {
fn distance(&self, a: &i32, b: &i32) -> f64 { fn distance(&self, a: &i32, b: &i32) -> f64 {
(a - b).abs() as f64 (a - b).abs() as f64
} }
} }
#[test] #[test]
fn knn_find() { fn knn_find() {
let data1 = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance{}); let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3)); assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));
let data2 = vec!(vec![1., 1.], vec![2., 2.], vec![3., 3.], vec![4., 4.], vec![5., 5.]); let data2 = vec![
vec![1., 1.],
vec![2., 2.],
vec![3., 3.],
vec![4., 4.],
vec![5., 5.],
];
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian()); let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
@@ -103,29 +108,29 @@ mod tests {
/// Checks that `KNNPoint` ordering and equality look only at the distance,
/// not at the index.
#[test]
fn knn_point_eq() {
    let make = |distance: f64, index: usize| KNNPoint {
        distance,
        index: Some(index),
    };

    let point1 = make(10., 0);
    let point2 = make(100., 1);
    let point3 = make(10., 2);
    let point_inf = make(std::f64::INFINITY, 3);

    assert!(point2 > point1);
    assert_eq!(point3, point1);
    assert_ne!(point3, point2);
    assert!(point_inf > point3 && point_inf > point2 && point_inf > point1);
}
} }
+1 -1
View File
@@ -1,3 +1,3 @@
pub mod bbd_tree;
pub mod cover_tree; pub mod cover_tree;
pub mod linear_search; pub mod linear_search;
pub mod bbd_tree;
+31 -35
View File
@@ -5,21 +5,20 @@ pub struct HeapSelect<T: PartialOrd> {
k: usize, k: usize,
n: usize, n: usize,
sorted: bool, sorted: bool,
heap: Vec<T> heap: Vec<T>,
} }
impl<'a, T: PartialOrd> HeapSelect<T> { impl<'a, T: PartialOrd> HeapSelect<T> {
/// Creates an empty selector that keeps at most `k` elements.
///
/// The backing vector starts out empty; nothing is reserved up front.
pub fn with_capacity(k: usize) -> HeapSelect<T> {
    HeapSelect {
        k,
        n: 0,
        sorted: false,
        heap: Vec::new(),
    }
}
pub fn add(&mut self, element: T) { pub fn add(&mut self, element: T) {
self.sorted = false; self.sorted = false;
if self.n < self.k { if self.n < self.k {
self.heap.push(element); self.heap.push(element);
@@ -30,23 +29,23 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
} else { } else {
self.n += 1; self.n += 1;
if element.partial_cmp(&self.heap[0]) == Some(Ordering::Less) { if element.partial_cmp(&self.heap[0]) == Some(Ordering::Less) {
self.heap[0] = element; self.heap[0] = element;
} }
} }
} }
pub fn heapify(&mut self) { pub fn heapify(&mut self) {
let n = self.heap.len(); let n = self.heap.len();
for i in (0..=(n / 2 - 1)).rev() { for i in (0..=(n / 2 - 1)).rev() {
self.sift_down(i, n-1); self.sift_down(i, n - 1);
} }
} }
pub fn peek(&self) -> &T { pub fn peek(&self) -> &T {
return &self.heap[0]; return &self.heap[0];
} }
pub fn peek_mut(&mut self) -> &mut T { pub fn peek_mut(&mut self) -> &mut T {
return &mut self.heap[0]; return &mut self.heap[0];
} }
@@ -59,11 +58,10 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
} }
if self.heap[k] >= self.heap[j] { if self.heap[k] >= self.heap[j] {
break; break;
} }
self.heap.swap(k, j); self.heap.swap(k, j);
k = j; k = j;
} }
} }
pub fn get(self) -> Vec<T> { pub fn get(self) -> Vec<T> {
@@ -71,7 +69,7 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
} }
pub fn sort(&mut self) { pub fn sort(&mut self) {
HeapSelect::shuffle_sort(&mut self.heap, std::cmp::min(self.k,self.n)); HeapSelect::shuffle_sort(&mut self.heap, std::cmp::min(self.k, self.n));
} }
pub fn shuffle_sort(vec: &mut Vec<T>, n: usize) { pub fn shuffle_sort(vec: &mut Vec<T>, n: usize) {
@@ -80,10 +78,10 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
inc *= 3; inc *= 3;
inc += 1 inc += 1
} }
let len = n; let len = n;
while inc >= 1 { while inc >= 1 {
let mut i = inc; let mut i = inc;
while i < len { while i < len {
let mut j = i; let mut j = i;
while j >= inc && vec[j - inc] > vec[j] { while j >= inc && vec[j - inc] > vec[j] {
@@ -95,60 +93,58 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
inc /= 3 inc /= 3
} }
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn with_capacity() { fn with_capacity() {
let heap = HeapSelect::<i32>::with_capacity(3); let heap = HeapSelect::<i32>::with_capacity(3);
assert_eq!(3, heap.k); assert_eq!(3, heap.k);
} }
#[test] #[test]
fn test_add() { fn test_add() {
let mut heap = HeapSelect::with_capacity(3); let mut heap = HeapSelect::with_capacity(3);
heap.add(333); heap.add(333);
heap.add(2); heap.add(2);
heap.add(13); heap.add(13);
heap.add(10); heap.add(10);
heap.add(40); heap.add(40);
heap.add(30); heap.add(30);
assert_eq!(6, heap.n); assert_eq!(6, heap.n);
assert_eq!(&10, heap.peek()); assert_eq!(&10, heap.peek());
assert_eq!(&10, heap.peek_mut()); assert_eq!(&10, heap.peek_mut());
} }
#[test] #[test]
fn test_add_ordered() { fn test_add_ordered() {
let mut heap = HeapSelect::with_capacity(3); let mut heap = HeapSelect::with_capacity(3);
heap.add(1.); heap.add(1.);
heap.add(2.); heap.add(2.);
heap.add(3.); heap.add(3.);
heap.add(4.); heap.add(4.);
heap.add(5.); heap.add(5.);
heap.add(6.); heap.add(6.);
let result = heap.get(); let result = heap.get();
assert_eq!(vec![2., 3., 1.], result); assert_eq!(vec![2., 3., 1.], result);
} }
#[test] #[test]
fn test_shuffle_sort() { fn test_shuffle_sort() {
let mut v1 = vec![10, 33, 22, 105, 12]; let mut v1 = vec![10, 33, 22, 105, 12];
let n = v1.len(); let n = v1.len();
HeapSelect::shuffle_sort(&mut v1, n); HeapSelect::shuffle_sort(&mut v1, n);
assert_eq!(vec![10, 12, 22, 33, 105], v1); assert_eq!(vec![10, 12, 22, 33, 105], v1);
let mut v2 = vec![10, 33, 22, 105, 12]; let mut v2 = vec![10, 33, 22, 105, 12];
HeapSelect::shuffle_sort(&mut v2, 3); HeapSelect::shuffle_sort(&mut v2, 3);
assert_eq!(vec![10, 22, 33, 105, 12], v2); assert_eq!(vec![10, 22, 33, 105, 12], v2);
let mut v3 = vec![4, 5, 3, 2, 1]; let mut v3 = vec![4, 5, 3, 2, 1];
HeapSelect::shuffle_sort(&mut v3, 3); HeapSelect::shuffle_sort(&mut v3, 3);
assert_eq!(vec![3, 4, 5, 2, 1], v3); assert_eq!(vec![3, 4, 5, 2, 1], v3);
} }
}
}
+1 -1
View File
@@ -1,2 +1,2 @@
pub mod heap_select; pub mod heap_select;
pub mod quick_sort; pub mod quick_sort;
+27 -22
View File
@@ -5,13 +5,12 @@ pub trait QuickArgSort {
} }
impl<T: Float> QuickArgSort for Vec<T> { impl<T: Float> QuickArgSort for Vec<T> {
fn quick_argsort(&mut self) -> Vec<usize> { fn quick_argsort(&mut self) -> Vec<usize> {
let stack_size = 64; let stack_size = 64;
let mut jstack = -1; let mut jstack = -1;
let mut l = 0; let mut l = 0;
let mut istack = vec![0; stack_size]; let mut istack = vec![0; stack_size];
let mut ir = self.len() - 1; let mut ir = self.len() - 1;
let mut index: Vec<usize> = (0..self.len()).collect(); let mut index: Vec<usize> = (0..self.len()).collect();
loop { loop {
@@ -19,21 +18,21 @@ impl<T: Float> QuickArgSort for Vec<T> {
for j in l + 1..=ir { for j in l + 1..=ir {
let a = self[j]; let a = self[j];
let b = index[j]; let b = index[j];
let mut i: i32 = (j - 1) as i32; let mut i: i32 = (j - 1) as i32;
while i >= l as i32 { while i >= l as i32 {
if self[i as usize] <= a { if self[i as usize] <= a {
break; break;
} }
self[(i + 1) as usize] = self[i as usize]; self[(i + 1) as usize] = self[i as usize];
index[(i + 1) as usize] = index[i as usize]; index[(i + 1) as usize] = index[i as usize];
i -= 1; i -= 1;
} }
self[(i + 1) as usize] = a; self[(i + 1) as usize] = a;
index[(i + 1) as usize] = b; index[(i + 1) as usize] = b;
} }
if jstack < 0 { if jstack < 0 {
break; break;
} }
ir = istack[jstack as usize]; ir = istack[jstack as usize];
jstack -= 1; jstack -= 1;
l = istack[jstack as usize]; l = istack[jstack as usize];
@@ -66,7 +65,7 @@ impl<T: Float> QuickArgSort for Vec<T> {
} }
} }
loop { loop {
j -=1; j -= 1;
if self[j] <= a { if self[j] <= a {
break; break;
} }
@@ -81,7 +80,7 @@ impl<T: Float> QuickArgSort for Vec<T> {
self[j] = a; self[j] = a;
index[l + 1] = index[j]; index[l + 1] = index[j];
index[j] = b; index[j] = b;
jstack += 2; jstack += 2;
if jstack >= 64 { if jstack >= 64 {
panic!("stack size is too small."); panic!("stack size is too small.");
@@ -95,7 +94,7 @@ impl<T: Float> QuickArgSort for Vec<T> {
istack[jstack as usize] = j - 1; istack[jstack as usize] = j - 1;
istack[jstack as usize - 1] = l; istack[jstack as usize - 1] = l;
l = i; l = i;
} }
} }
} }
@@ -104,15 +103,21 @@ impl<T: Float> QuickArgSort for Vec<T> {
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn with_capacity() { fn with_capacity() {
let mut arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8]; let mut arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
assert_eq!(vec![1, 2, 0, 3, 5, 7, 6, 8, 4], arr1.quick_argsort()); assert_eq!(vec![1, 2, 0, 3, 5, 7, 6, 8, 4], arr1.quick_argsort());
let mut arr2 = vec![0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6, 1.0, 1.3, 1.4]; let mut arr2 = vec![
assert_eq!(vec![9, 7, 1, 8, 0, 2, 4, 3, 6, 5, 17, 18, 15, 13, 19, 10, 14, 11, 12, 16], arr2.quick_argsort()); 0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6,
1.0, 1.3, 1.4,
];
assert_eq!(
vec![9, 7, 1, 8, 0, 2, 4, 3, 6, 5, 17, 18, 15, 13, 19, 10, 14, 11, 12, 16],
arr2.quick_argsort()
);
} }
} }
+54 -57
View File
@@ -1,40 +1,41 @@
extern crate rand; extern crate rand;
use rand::Rng; use rand::Rng;
use std::iter::Sum;
use std::fmt::Debug; use std::fmt::Debug;
use std::iter::Sum;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt; use crate::algorithm::neighbour::bbd_tree::BBDTree;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::distance::euclidian::*; use crate::math::distance::euclidian::*;
use crate::algorithm::neighbour::bbd_tree::BBDTree; use crate::math::num::FloatExt;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct KMeans<T: FloatExt> { pub struct KMeans<T: FloatExt> {
k: usize, k: usize,
y: Vec<usize>, y: Vec<usize>,
size: Vec<usize>, size: Vec<usize>,
distortion: T, distortion: T,
centroids: Vec<Vec<T>> centroids: Vec<Vec<T>>,
} }
impl<T: FloatExt> PartialEq for KMeans<T> { impl<T: FloatExt> PartialEq for KMeans<T> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.k != other.k || if self.k != other.k
self.size != other.size || || self.size != other.size
self.centroids.len() != other.centroids.len() { || self.centroids.len() != other.centroids.len()
{
false false
} else { } else {
let n_centroids = self.centroids.len(); let n_centroids = self.centroids.len();
for i in 0..n_centroids{ for i in 0..n_centroids {
if self.centroids[i].len() != other.centroids[i].len(){ if self.centroids[i].len() != other.centroids[i].len() {
return false return false;
} }
for j in 0..self.centroids[i].len() { for j in 0..self.centroids[i].len() {
if (self.centroids[i][j] - other.centroids[i][j]).abs() > T::epsilon() { if (self.centroids[i][j] - other.centroids[i][j]).abs() > T::epsilon() {
return false return false;
} }
} }
} }
@@ -44,21 +45,18 @@ impl<T: FloatExt> PartialEq for KMeans<T> {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct KMeansParameters { pub struct KMeansParameters {
pub max_iter: usize pub max_iter: usize,
} }
impl Default for KMeansParameters { impl Default for KMeansParameters {
fn default() -> Self { fn default() -> Self {
KMeansParameters { KMeansParameters { max_iter: 100 }
max_iter: 100 }
}
}
} }
impl<T: FloatExt + Sum> KMeans<T>{ impl<T: FloatExt + Sum> KMeans<T> {
pub fn new<M: Matrix<T>>(data: &M, k: usize, parameters: KMeansParameters) -> KMeans<T> { pub fn new<M: Matrix<T>>(data: &M, k: usize, parameters: KMeansParameters) -> KMeans<T> {
let bbd = BBDTree::new(data); let bbd = BBDTree::new(data);
if k < 2 { if k < 2 {
@@ -66,11 +64,14 @@ impl<T: FloatExt + Sum> KMeans<T>{
} }
if parameters.max_iter <= 0 { if parameters.max_iter <= 0 {
panic!("Invalid maximum number of iterations: {}", parameters.max_iter); panic!(
"Invalid maximum number of iterations: {}",
parameters.max_iter
);
} }
let (n, d) = data.shape(); let (n, d) = data.shape();
let mut distortion = T::max_value(); let mut distortion = T::max_value();
let mut y = KMeans::kmeans_plus_plus(data, k); let mut y = KMeans::kmeans_plus_plus(data, k);
let mut size = vec![0; k]; let mut size = vec![0; k];
@@ -90,10 +91,10 @@ impl<T: FloatExt + Sum> KMeans<T>{
for j in 0..d { for j in 0..d {
centroids[i][j] = centroids[i][j] / T::from(size[i]).unwrap(); centroids[i][j] = centroids[i][j] / T::from(size[i]).unwrap();
} }
} }
let mut sums = vec![vec![T::zero(); d]; k]; let mut sums = vec![vec![T::zero(); d]; k];
for _ in 1..= parameters.max_iter { for _ in 1..=parameters.max_iter {
let dist = bbd.clustering(&centroids, &mut sums, &mut size, &mut y); let dist = bbd.clustering(&centroids, &mut sums, &mut size, &mut y);
for i in 0..k { for i in 0..k {
if size[i] > 0 { if size[i] > 0 {
@@ -108,48 +109,46 @@ impl<T: FloatExt + Sum> KMeans<T>{
} else { } else {
distortion = dist; distortion = dist;
} }
}
}
KMeans{ KMeans {
k: k, k: k,
y: y, y: y,
size: size, size: size,
distortion: distortion, distortion: distortion,
centroids: centroids centroids: centroids,
} }
} }
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let (n, _) = x.shape(); let (n, _) = x.shape();
let mut result = M::zeros(1, n); let mut result = M::zeros(1, n);
for i in 0..n { for i in 0..n {
let mut min_dist = T::max_value(); let mut min_dist = T::max_value();
let mut best_cluster = 0; let mut best_cluster = 0;
for j in 0..self.k { for j in 0..self.k {
let dist = Euclidian::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]); let dist = Euclidian::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]);
if dist < min_dist { if dist < min_dist {
min_dist = dist; min_dist = dist;
best_cluster = j; best_cluster = j;
} }
} }
result.set(0, i, T::from(best_cluster).unwrap()); result.set(0, i, T::from(best_cluster).unwrap());
} }
result.to_row_vector() result.to_row_vector()
} }
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize>{ fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let (n, _) = data.shape(); let (n, _) = data.shape();
let mut y = vec![0; n]; let mut y = vec![0; n];
let mut centroid = data.get_row_as_vec(rng.gen_range(0, n)); let mut centroid = data.get_row_as_vec(rng.gen_range(0, n));
let mut d = vec![T::max_value(); n]; let mut d = vec![T::max_value(); n];
// pick the next center // pick the next center
for j in 1..k { for j in 1..k {
// Loop over the samples and compare them to the most recent center. Store // Loop over the samples and compare them to the most recent center. Store
@@ -157,7 +156,7 @@ impl<T: FloatExt + Sum> KMeans<T>{
for i in 0..n { for i in 0..n {
// compute the distance between this sample and the current center // compute the distance between this sample and the current center
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), &centroid); let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), &centroid);
if dist < d[i] { if dist < d[i] {
d[i] = dist; d[i] = dist;
y[i] = j - 1; y[i] = j - 1;
@@ -165,7 +164,7 @@ impl<T: FloatExt + Sum> KMeans<T>{
} }
let mut sum: T = T::zero(); let mut sum: T = T::zero();
for i in d.iter(){ for i in d.iter() {
sum = sum + *i; sum = sum + *i;
} }
let cutoff = T::from(rng.gen::<f64>()).unwrap() * sum; let cutoff = T::from(rng.gen::<f64>()).unwrap() * sum;
@@ -183,8 +182,8 @@ impl<T: FloatExt + Sum> KMeans<T>{
for i in 0..n { for i in 0..n {
// compute the distance between this sample and the current center // compute the distance between this sample and the current center
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), &centroid); let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), &centroid);
if dist < d[i] { if dist < d[i] {
d[i] = dist; d[i] = dist;
y[i] = k - 1; y[i] = k - 1;
@@ -193,17 +192,15 @@ impl<T: FloatExt + Sum> KMeans<T>{
y y
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test] #[test]
fn fit_predict_iris() { fn fit_predict_iris() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2],
@@ -224,7 +221,8 @@ mod tests {
&[6.3, 3.3, 4.7, 1.6], &[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0], &[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3], &[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4]]); &[5.2, 2.7, 3.9, 1.4],
]);
let kmeans = KMeans::new(&x, 2, Default::default()); let kmeans = KMeans::new(&x, 2, Default::default());
@@ -232,12 +230,11 @@ mod tests {
for i in 0..y.len() { for i in 0..y.len() {
assert_eq!(y[i] as usize, kmeans.y[i]); assert_eq!(y[i] as usize, kmeans.y[i]);
} }
} }
#[test] #[test]
fn serde() { fn serde() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2],
@@ -258,14 +255,14 @@ mod tests {
&[6.3, 3.3, 4.7, 1.6], &[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0], &[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3], &[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4]]); &[5.2, 2.7, 3.9, 1.4],
]);
let kmeans = KMeans::new(&x, 2, Default::default()); let kmeans = KMeans::new(&x, 2, Default::default());
let deserialized_kmeans: KMeans<f64> = serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap(); let deserialized_kmeans: KMeans<f64> =
serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();
assert_eq!(kmeans, deserialized_kmeans); assert_eq!(kmeans, deserialized_kmeans);
} }
}
}
+1 -1
View File
@@ -1 +1 @@
pub mod kmeans; pub mod kmeans;
+12 -7
View File
@@ -1,13 +1,18 @@
use num_traits::{Num, ToPrimitive, FromPrimitive, Zero, One}; use ndarray::ScalarOperand;
use ndarray::{ScalarOperand}; use num_traits::{FromPrimitive, Num, One, ToPrimitive, Zero};
use std::hash::Hash;
use std::fmt::Debug; use std::fmt::Debug;
use std::hash::Hash;
pub trait AnyNumber: Num + ScalarOperand + ToPrimitive + FromPrimitive{} pub trait AnyNumber: Num + ScalarOperand + ToPrimitive + FromPrimitive {}
pub trait Nominal: PartialEq + Zero + One + Eq + Hash + ToPrimitive + FromPrimitive + Debug + 'static + Clone{}
pub trait Nominal:
PartialEq + Zero + One + Eq + Hash + ToPrimitive + FromPrimitive + Debug + 'static + Clone
{
}
impl<T> AnyNumber for T where T: Num + ScalarOperand + ToPrimitive + FromPrimitive {} impl<T> AnyNumber for T where T: Num + ScalarOperand + ToPrimitive + FromPrimitive {}
impl<T> Nominal for T where T: PartialEq + Zero + One + Eq + Hash + ToPrimitive + Debug + FromPrimitive + 'static + Clone {} impl<T> Nominal for T where
T: PartialEq + Zero + One + Eq + Hash + ToPrimitive + Debug + FromPrimitive + 'static + Clone
{
}
+1 -1
View File
@@ -1 +1 @@
pub mod pca; pub mod pca;
+246 -184
View File
@@ -1,52 +1,51 @@
use std::fmt::Debug; use std::fmt::Debug;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::linalg::Matrix;
use crate::math::num::FloatExt; use crate::math::num::FloatExt;
use crate::linalg::{Matrix};
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct PCA<T: FloatExt, M: Matrix<T>> { pub struct PCA<T: FloatExt, M: Matrix<T>> {
eigenvectors: M, eigenvectors: M,
eigenvalues: Vec<T>, eigenvalues: Vec<T>,
projection: M, projection: M,
mu: Vec<T>, mu: Vec<T>,
pmu: Vec<T> pmu: Vec<T>,
} }
impl<T: FloatExt, M: Matrix<T>> PartialEq for PCA<T, M> { impl<T: FloatExt, M: Matrix<T>> PartialEq for PCA<T, M> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.eigenvectors != other.eigenvectors || if self.eigenvectors != other.eigenvectors
self.eigenvalues.len() != other.eigenvalues.len() { || self.eigenvalues.len() != other.eigenvalues.len()
return false {
return false;
} else { } else {
for i in 0..self.eigenvalues.len() { for i in 0..self.eigenvalues.len() {
if (self.eigenvalues[i] - other.eigenvalues[i]).abs() > T::epsilon() { if (self.eigenvalues[i] - other.eigenvalues[i]).abs() > T::epsilon() {
return false return false;
} }
} }
return true return true;
} }
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct PCAParameters { pub struct PCAParameters {
use_correlation_matrix: bool use_correlation_matrix: bool,
} }
impl Default for PCAParameters { impl Default for PCAParameters {
fn default() -> Self { fn default() -> Self {
PCAParameters { PCAParameters {
use_correlation_matrix: false use_correlation_matrix: false,
} }
} }
} }
impl<T: FloatExt, M: Matrix<T>> PCA<T, M> { impl<T: FloatExt, M: Matrix<T>> PCA<T, M> {
pub fn new(data: &M, n_components: usize, parameters: PCAParameters) -> PCA<T, M> { pub fn new(data: &M, n_components: usize, parameters: PCAParameters) -> PCA<T, M> {
let (m, n) = data.shape(); let (m, n) = data.shape();
let mu = data.column_mean(); let mu = data.column_mean();
@@ -62,8 +61,7 @@ impl<T: FloatExt, M: Matrix<T>> PCA<T, M> {
let mut eigenvalues; let mut eigenvalues;
let mut eigenvectors; let mut eigenvectors;
if m > n && !parameters.use_correlation_matrix{ if m > n && !parameters.use_correlation_matrix {
let svd = x.svd(); let svd = x.svd();
eigenvalues = svd.s; eigenvalues = svd.s;
for i in 0..eigenvalues.len() { for i in 0..eigenvalues.len() {
@@ -84,11 +82,11 @@ impl<T: FloatExt, M: Matrix<T>> PCA<T, M> {
for i in 0..n { for i in 0..n {
for j in 0..=i { for j in 0..=i {
cov.div_element_mut(i, j, T::from(m).unwrap()); cov.div_element_mut(i, j, T::from(m).unwrap());
cov.set(j, i, cov.get(i, j)); cov.set(j, i, cov.get(i, j));
} }
} }
if parameters.use_correlation_matrix { if parameters.use_correlation_matrix {
let mut sd = vec![T::zero(); n]; let mut sd = vec![T::zero(); n];
for i in 0..n { for i in 0..n {
@@ -114,16 +112,14 @@ impl<T: FloatExt, M: Matrix<T>> PCA<T, M> {
} }
} }
} else { } else {
let evd = cov.evd(true); let evd = cov.evd(true);
eigenvalues = evd.d; eigenvalues = evd.d;
eigenvectors = evd.V; eigenvectors = evd.V;
} }
} }
let mut projection = M::zeros(n_components, n); let mut projection = M::zeros(n_components, n);
for i in 0..n { for i in 0..n {
for j in 0..n_components { for j in 0..n_components {
@@ -143,7 +139,7 @@ impl<T: FloatExt, M: Matrix<T>> PCA<T, M> {
eigenvalues: eigenvalues, eigenvalues: eigenvalues,
projection: projection.transpose(), projection: projection.transpose(),
mu: mu, mu: mu,
pmu: pmu pmu: pmu,
} }
} }
@@ -151,7 +147,11 @@ impl<T: FloatExt, M: Matrix<T>> PCA<T, M> {
let (nrows, ncols) = x.shape(); let (nrows, ncols) = x.shape();
let (_, n_components) = self.projection.shape(); let (_, n_components) = self.projection.shape();
if ncols != self.mu.len() { if ncols != self.mu.len() {
panic!("Invalid input vector size: {}, expected: {}", ncols, self.mu.len()); panic!(
"Invalid input vector size: {}, expected: {}",
ncols,
self.mu.len()
);
} }
let mut x_transformed = x.dot(&self.projection); let mut x_transformed = x.dot(&self.projection);
@@ -162,12 +162,11 @@ impl<T: FloatExt, M: Matrix<T>> PCA<T, M> {
} }
x_transformed x_transformed
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
fn us_arrests_data() -> DenseMatrix<f64> { fn us_arrests_data() -> DenseMatrix<f64> {
@@ -221,173 +220,236 @@ mod tests {
&[4.0, 145.0, 73.0, 26.2], &[4.0, 145.0, 73.0, 26.2],
&[5.7, 81.0, 39.0, 9.3], &[5.7, 81.0, 39.0, 9.3],
&[2.6, 53.0, 66.0, 10.8], &[2.6, 53.0, 66.0, 10.8],
&[6.8, 161.0, 60.0, 15.6]]) &[6.8, 161.0, 60.0, 15.6],
])
} }
#[test] #[test]
fn decompose_covariance() { fn decompose_covariance() {
let us_arrests = us_arrests_data();
let us_arrests = us_arrests_data();
let expected_eigenvectors = DenseMatrix::from_array(&[
&[-0.0417043206282872, -0.0448216562696701, -0.0798906594208108, -0.994921731246978],
&[-0.995221281426497, -0.058760027857223, 0.0675697350838043, 0.0389382976351601],
&[-0.0463357461197108, 0.97685747990989, 0.200546287353866, -0.0581691430589319],
&[-0.075155500585547, 0.200718066450337, -0.974080592182491, 0.0723250196376097]
]);
let expected_projection = DenseMatrix::from_array(&[ let expected_eigenvectors = DenseMatrix::from_array(&[
&[-64.8022, -11.448, 2.4949, -2.4079], &[
&[-92.8275, -17.9829, -20.1266, 4.094], -0.0417043206282872,
&[-124.0682, 8.8304, 1.6874, 4.3537], -0.0448216562696701,
&[-18.34, -16.7039, -0.2102, 0.521], -0.0798906594208108,
&[-107.423, 22.5201, -6.7459, 2.8118], -0.994921731246978,
&[-34.976, 13.7196, -12.2794, 1.7215], ],
&[60.8873, 12.9325, 8.4207, 0.6999], &[
&[-66.731, 1.3538, 11.281, 3.728], -0.995221281426497,
&[-165.2444, 6.2747, 2.9979, -1.2477], -0.058760027857223,
&[-40.5352, -7.2902, -3.6095, -7.3437], 0.0675697350838043,
&[123.5361, 24.2912, -3.7244, -3.4728], 0.0389382976351601,
&[51.797, -9.4692, 1.5201, 3.3478], ],
&[-78.9921, 12.8971, 5.8833, -0.3676], &[
&[57.551, 2.8463, -3.7382, -1.6494], -0.0463357461197108,
&[115.5868, -3.3421, 0.654, 0.8695], 0.97685747990989,
&[55.7897, 3.1572, -0.3844, -0.6528], 0.200546287353866,
&[62.3832, -10.6733, -2.2371, -3.8762], -0.0581691430589319,
&[-78.2776, -4.2949, 3.8279, -4.4836], ],
&[89.261, -11.4878, 4.6924, 2.1162], &[
&[-129.3301, -5.007, 2.3472, 1.9283], -0.075155500585547,
&[21.2663, 19.4502, 7.5071, 1.0348], 0.200718066450337,
&[-85.4515, 5.9046, -6.4643, -0.499], -0.974080592182491,
&[98.9548, 5.2096, -0.0066, 0.7319], 0.0723250196376097,
&[-86.8564, -27.4284, 5.0034, -3.8798], ],
&[-7.9863, 5.2756, -5.5006, -0.6794], ]);
&[62.4836, -9.5105, -1.8384, -0.2459],
&[69.0965, -0.2112, -0.468, 0.6566],
&[-83.6136, 15.1022, -15.8887, -0.3342],
&[114.7774, -4.7346, 2.2824, 0.9359],
&[10.8157, 23.1373, 6.3102, -1.6124],
&[-114.8682, -0.3365, -2.2613, 1.3812],
&[-84.2942, 15.924, 4.7213, -0.892],
&[-164.3255, -31.0966, 11.6962, 2.1112],
&[127.4956, -16.135, 1.3118, 2.301],
&[50.0868, 12.2793, -1.6573, -2.0291],
&[19.6937, 3.3701, 0.4531, 0.1803],
&[11.1502, 3.8661, -8.13, 2.914],
&[64.6891, 8.9115, 3.2065, -1.8749],
&[-3.064, 18.374, 17.47, 2.3083],
&[-107.2811, -23.5361, 2.0328, -1.2517],
&[86.1067, -16.5979, -1.3144, 1.2523],
&[-17.5063, -6.5066, -6.1001, -3.9229],
&[-31.2911, 12.985, 0.3934, -4.242],
&[49.9134, 17.6485, -1.7882, 1.8677],
&[124.7145, -27.3136, -4.8028, 2.005],
&[14.8174, -1.7526, -1.0454, -1.1738],
&[25.0758, 9.968, -4.7811, 2.6911],
&[91.5446, -22.9529, 0.402, -0.7369],
&[118.1763, 5.5076, 2.7113, -0.205],
&[10.4345, -5.9245, 3.7944, 0.5179]
]);
let expected_eigenvalues: Vec<f64> = vec![343544.6277001563, 9897.625949808047, 2063.519887011604, 302.04806302399646]; let expected_projection = DenseMatrix::from_array(&[
&[-64.8022, -11.448, 2.4949, -2.4079],
let pca = PCA::new(&us_arrests, 4, Default::default()); &[-92.8275, -17.9829, -20.1266, 4.094],
&[-124.0682, 8.8304, 1.6874, 4.3537],
&[-18.34, -16.7039, -0.2102, 0.521],
&[-107.423, 22.5201, -6.7459, 2.8118],
&[-34.976, 13.7196, -12.2794, 1.7215],
&[60.8873, 12.9325, 8.4207, 0.6999],
&[-66.731, 1.3538, 11.281, 3.728],
&[-165.2444, 6.2747, 2.9979, -1.2477],
&[-40.5352, -7.2902, -3.6095, -7.3437],
&[123.5361, 24.2912, -3.7244, -3.4728],
&[51.797, -9.4692, 1.5201, 3.3478],
&[-78.9921, 12.8971, 5.8833, -0.3676],
&[57.551, 2.8463, -3.7382, -1.6494],
&[115.5868, -3.3421, 0.654, 0.8695],
&[55.7897, 3.1572, -0.3844, -0.6528],
&[62.3832, -10.6733, -2.2371, -3.8762],
&[-78.2776, -4.2949, 3.8279, -4.4836],
&[89.261, -11.4878, 4.6924, 2.1162],
&[-129.3301, -5.007, 2.3472, 1.9283],
&[21.2663, 19.4502, 7.5071, 1.0348],
&[-85.4515, 5.9046, -6.4643, -0.499],
&[98.9548, 5.2096, -0.0066, 0.7319],
&[-86.8564, -27.4284, 5.0034, -3.8798],
&[-7.9863, 5.2756, -5.5006, -0.6794],
&[62.4836, -9.5105, -1.8384, -0.2459],
&[69.0965, -0.2112, -0.468, 0.6566],
&[-83.6136, 15.1022, -15.8887, -0.3342],
&[114.7774, -4.7346, 2.2824, 0.9359],
&[10.8157, 23.1373, 6.3102, -1.6124],
&[-114.8682, -0.3365, -2.2613, 1.3812],
&[-84.2942, 15.924, 4.7213, -0.892],
&[-164.3255, -31.0966, 11.6962, 2.1112],
&[127.4956, -16.135, 1.3118, 2.301],
&[50.0868, 12.2793, -1.6573, -2.0291],
&[19.6937, 3.3701, 0.4531, 0.1803],
&[11.1502, 3.8661, -8.13, 2.914],
&[64.6891, 8.9115, 3.2065, -1.8749],
&[-3.064, 18.374, 17.47, 2.3083],
&[-107.2811, -23.5361, 2.0328, -1.2517],
&[86.1067, -16.5979, -1.3144, 1.2523],
&[-17.5063, -6.5066, -6.1001, -3.9229],
&[-31.2911, 12.985, 0.3934, -4.242],
&[49.9134, 17.6485, -1.7882, 1.8677],
&[124.7145, -27.3136, -4.8028, 2.005],
&[14.8174, -1.7526, -1.0454, -1.1738],
&[25.0758, 9.968, -4.7811, 2.6911],
&[91.5446, -22.9529, 0.402, -0.7369],
&[118.1763, 5.5076, 2.7113, -0.205],
&[10.4345, -5.9245, 3.7944, 0.5179],
]);
assert!(pca.eigenvectors.abs().approximate_eq(&expected_eigenvectors.abs(), 1e-4)); let expected_eigenvalues: Vec<f64> = vec![
343544.6277001563,
for i in 0..pca.eigenvalues.len() { 9897.625949808047,
assert_eq!(pca.eigenvalues[i].abs(), expected_eigenvalues[i].abs()); 2063.519887011604,
} 302.04806302399646,
];
let us_arrests_t = pca.transform(&us_arrests); let pca = PCA::new(&us_arrests, 4, Default::default());
assert!(us_arrests_t.abs().approximate_eq(&expected_projection.abs(), 1e-4)); assert!(pca
.eigenvectors
.abs()
.approximate_eq(&expected_eigenvectors.abs(), 1e-4));
for i in 0..pca.eigenvalues.len() {
assert_eq!(pca.eigenvalues[i].abs(), expected_eigenvalues[i].abs());
}
let us_arrests_t = pca.transform(&us_arrests);
assert!(us_arrests_t
.abs()
.approximate_eq(&expected_projection.abs(), 1e-4));
} }
#[test] #[test]
fn decompose_correlation() { fn decompose_correlation() {
let us_arrests = us_arrests_data();
let us_arrests = us_arrests_data();
let expected_eigenvectors = DenseMatrix::from_array(&[
&[0.124288601688222, -0.0969866877028367, 0.0791404742697482, -0.150572299008293],
&[0.00706888610512014, -0.00227861130898090, 0.00325028101296307, 0.00901099154845273],
&[0.0194141494466002, 0.060910660326921, 0.0263806464184195, -0.0093429458365566],
&[0.0586084532558777, 0.0180450999787168, -0.0881962972508558, -0.0096011588898465]
]);
let expected_projection = DenseMatrix::from_array(&[ let expected_eigenvectors = DenseMatrix::from_array(&[
&[0.9856, -1.1334, 0.4443, -0.1563], &[
&[1.9501, -1.0732, -2.04, 0.4386], 0.124288601688222,
&[1.7632, 0.746, -0.0548, 0.8347], -0.0969866877028367,
&[-0.1414, -1.1198, -0.1146, 0.1828], 0.0791404742697482,
&[2.524, 1.5429, -0.5986, 0.342], -0.150572299008293,
&[1.5146, 0.9876, -1.095, -0.0015], ],
&[-1.3586, 1.0889, 0.6433, 0.1185], &[
&[0.0477, 0.3254, 0.7186, 0.882], 0.00706888610512014,
&[3.013, -0.0392, 0.5768, 0.0963], -0.00227861130898090,
&[1.6393, -1.2789, 0.3425, -1.0768], 0.00325028101296307,
&[-0.9127, 1.5705, -0.0508, -0.9028], 0.00901099154845273,
&[-1.6398, -0.211, -0.2598, 0.4991], ],
&[1.3789, 0.6818, 0.6775, 0.122], &[
&[-0.5055, 0.1516, -0.2281, -0.4247], 0.0194141494466002,
&[-2.2536, 0.1041, -0.1646, -0.0176], 0.060910660326921,
&[-0.7969, 0.2702, -0.0256, -0.2065], 0.0263806464184195,
&[-0.7509, -0.9584, 0.0284, -0.6706], -0.0093429458365566,
&[1.5648, -0.8711, 0.7835, -0.4547], ],
&[-2.3968, -0.3764, 0.0657, 0.3305], &[
&[1.7634, -0.4277, 0.1573, 0.5591], 0.0586084532558777,
&[-0.4862, 1.4745, 0.6095, 0.1796], 0.0180450999787168,
&[2.1084, 0.1554, -0.3849, -0.1024], -0.0881962972508558,
&[-1.6927, 0.6323, -0.1531, -0.0673], -0.0096011588898465,
&[0.9965, -2.3938, 0.7408, -0.2155], ],
&[0.6968, 0.2634, -0.3774, -0.2258], ]);
&[-1.1855, -0.5369, -0.2469, -0.1237],
&[-1.2656, 0.194, -0.1756, -0.0159],
&[2.8744, 0.7756, -1.1634, -0.3145],
&[-2.3839, 0.0181, -0.0369, 0.0331],
&[0.1816, 1.4495, 0.7645, -0.2434],
&[1.98, -0.1428, -0.1837, 0.3395],
&[1.6826, 0.8232, 0.6431, 0.0135],
&[1.1234, -2.228, 0.8636, 0.9544],
&[-2.9922, -0.5991, -0.3013, 0.254],
&[-0.226, 0.7422, 0.0311, -0.4739],
&[-0.3118, 0.2879, 0.0153, -0.0103],
&[0.0591, 0.5414, -0.9398, 0.2378],
&[-0.8884, 0.5711, 0.4006, -0.3591],
&[-0.8638, 1.492, 1.3699, 0.6136],
&[1.3207, -1.9334, 0.3005, 0.1315],
&[-1.9878, -0.8233, -0.3893, 0.1096],
&[0.9997, -0.8603, -0.1881, -0.6529],
&[1.3551, 0.4125, 0.4921, -0.6432],
&[-0.5506, 1.4715, -0.2937, 0.0823],
&[-2.8014, -1.4023, -0.8413, 0.1449],
&[-0.0963, -0.1997, -0.0117, -0.2114],
&[-0.2169, 0.9701, -0.6249, 0.2208],
&[-2.1086, -1.4248, -0.1048, -0.1319],
&[-2.0797, 0.6113, 0.1389, -0.1841],
&[-0.6294, -0.321, 0.2407, 0.1667]
]);
let expected_eigenvalues: Vec<f64> = vec![2.480241579149493, 0.9897651525398419, 0.35656318058083064, 0.1734300877298357]; let expected_projection = DenseMatrix::from_array(&[
&[0.9856, -1.1334, 0.4443, -0.1563],
let pca = PCA::new(&us_arrests, 4, PCAParameters{use_correlation_matrix: true}); &[1.9501, -1.0732, -2.04, 0.4386],
&[1.7632, 0.746, -0.0548, 0.8347],
&[-0.1414, -1.1198, -0.1146, 0.1828],
&[2.524, 1.5429, -0.5986, 0.342],
&[1.5146, 0.9876, -1.095, -0.0015],
&[-1.3586, 1.0889, 0.6433, 0.1185],
&[0.0477, 0.3254, 0.7186, 0.882],
&[3.013, -0.0392, 0.5768, 0.0963],
&[1.6393, -1.2789, 0.3425, -1.0768],
&[-0.9127, 1.5705, -0.0508, -0.9028],
&[-1.6398, -0.211, -0.2598, 0.4991],
&[1.3789, 0.6818, 0.6775, 0.122],
&[-0.5055, 0.1516, -0.2281, -0.4247],
&[-2.2536, 0.1041, -0.1646, -0.0176],
&[-0.7969, 0.2702, -0.0256, -0.2065],
&[-0.7509, -0.9584, 0.0284, -0.6706],
&[1.5648, -0.8711, 0.7835, -0.4547],
&[-2.3968, -0.3764, 0.0657, 0.3305],
&[1.7634, -0.4277, 0.1573, 0.5591],
&[-0.4862, 1.4745, 0.6095, 0.1796],
&[2.1084, 0.1554, -0.3849, -0.1024],
&[-1.6927, 0.6323, -0.1531, -0.0673],
&[0.9965, -2.3938, 0.7408, -0.2155],
&[0.6968, 0.2634, -0.3774, -0.2258],
&[-1.1855, -0.5369, -0.2469, -0.1237],
&[-1.2656, 0.194, -0.1756, -0.0159],
&[2.8744, 0.7756, -1.1634, -0.3145],
&[-2.3839, 0.0181, -0.0369, 0.0331],
&[0.1816, 1.4495, 0.7645, -0.2434],
&[1.98, -0.1428, -0.1837, 0.3395],
&[1.6826, 0.8232, 0.6431, 0.0135],
&[1.1234, -2.228, 0.8636, 0.9544],
&[-2.9922, -0.5991, -0.3013, 0.254],
&[-0.226, 0.7422, 0.0311, -0.4739],
&[-0.3118, 0.2879, 0.0153, -0.0103],
&[0.0591, 0.5414, -0.9398, 0.2378],
&[-0.8884, 0.5711, 0.4006, -0.3591],
&[-0.8638, 1.492, 1.3699, 0.6136],
&[1.3207, -1.9334, 0.3005, 0.1315],
&[-1.9878, -0.8233, -0.3893, 0.1096],
&[0.9997, -0.8603, -0.1881, -0.6529],
&[1.3551, 0.4125, 0.4921, -0.6432],
&[-0.5506, 1.4715, -0.2937, 0.0823],
&[-2.8014, -1.4023, -0.8413, 0.1449],
&[-0.0963, -0.1997, -0.0117, -0.2114],
&[-0.2169, 0.9701, -0.6249, 0.2208],
&[-2.1086, -1.4248, -0.1048, -0.1319],
&[-2.0797, 0.6113, 0.1389, -0.1841],
&[-0.6294, -0.321, 0.2407, 0.1667],
]);
assert!(pca.eigenvectors.abs().approximate_eq(&expected_eigenvectors.abs(), 1e-4)); let expected_eigenvalues: Vec<f64> = vec![
2.480241579149493,
for i in 0..pca.eigenvalues.len() { 0.9897651525398419,
assert_eq!(pca.eigenvalues[i].abs(), expected_eigenvalues[i].abs()); 0.35656318058083064,
} 0.1734300877298357,
];
let us_arrests_t = pca.transform(&us_arrests); let pca = PCA::new(
&us_arrests,
4,
PCAParameters {
use_correlation_matrix: true,
},
);
assert!(us_arrests_t.abs().approximate_eq(&expected_projection.abs(), 1e-4)); assert!(pca
.eigenvectors
.abs()
.approximate_eq(&expected_eigenvectors.abs(), 1e-4));
for i in 0..pca.eigenvalues.len() {
assert_eq!(pca.eigenvalues[i].abs(), expected_eigenvalues[i].abs());
}
let us_arrests_t = pca.transform(&us_arrests);
assert!(us_arrests_t
.abs()
.approximate_eq(&expected_projection.abs(), 1e-4));
} }
#[test] #[test]
fn serde() { fn serde() {
let iris = DenseMatrix::from_array(&[ let iris = DenseMatrix::from_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2],
@@ -408,14 +470,14 @@ mod tests {
&[6.3, 3.3, 4.7, 1.6], &[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0], &[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3], &[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4]]); &[5.2, 2.7, 3.9, 1.4],
]);
let pca = PCA::new(&iris, 4, Default::default()); let pca = PCA::new(&iris, 4, Default::default());
let deserialized_pca: PCA<f64, DenseMatrix<f64>> = serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap(); let deserialized_pca: PCA<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
assert_eq!(pca, deserialized_pca); assert_eq!(pca, deserialized_pca);
} }
}
}
+1 -1
View File
@@ -1,2 +1,2 @@
pub mod random_forest_classifier; pub mod random_forest_classifier;
pub mod random_forest_regressor; pub mod random_forest_regressor;
+84 -69
View File
@@ -4,43 +4,44 @@ use std::default::Default;
use std::fmt::Debug; use std::fmt::Debug;
use rand::Rng; use rand::Rng;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max}; use crate::math::num::FloatExt;
use crate::tree::decision_tree_classifier::{
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
};
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct RandomForestClassifierParameters { pub struct RandomForestClassifierParameters {
pub criterion: SplitCriterion, pub criterion: SplitCriterion,
pub max_depth: Option<u16>, pub max_depth: Option<u16>,
pub min_samples_leaf: usize, pub min_samples_leaf: usize,
pub min_samples_split: usize, pub min_samples_split: usize,
pub n_trees: u16, pub n_trees: u16,
pub mtry: Option<usize> pub mtry: Option<usize>,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct RandomForestClassifier<T: FloatExt> { pub struct RandomForestClassifier<T: FloatExt> {
parameters: RandomForestClassifierParameters, parameters: RandomForestClassifierParameters,
trees: Vec<DecisionTreeClassifier<T>>, trees: Vec<DecisionTreeClassifier<T>>,
classes: Vec<T> classes: Vec<T>,
} }
impl<T: FloatExt> PartialEq for RandomForestClassifier<T> { impl<T: FloatExt> PartialEq for RandomForestClassifier<T> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.classes.len() != other.classes.len() || if self.classes.len() != other.classes.len() || self.trees.len() != other.trees.len() {
self.trees.len() != other.trees.len() { return false;
return false
} else { } else {
for i in 0..self.classes.len() { for i in 0..self.classes.len() {
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() { if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
return false return false;
} }
} }
for i in 0..self.trees.len() { for i in 0..self.trees.len() {
if self.trees[i] != other.trees[i] { if self.trees[i] != other.trees[i] {
return false return false;
} }
} }
true true
@@ -49,45 +50,54 @@ impl<T: FloatExt> PartialEq for RandomForestClassifier<T> {
} }
impl Default for RandomForestClassifierParameters { impl Default for RandomForestClassifierParameters {
fn default() -> Self { fn default() -> Self {
RandomForestClassifierParameters { RandomForestClassifierParameters {
criterion: SplitCriterion::Gini, criterion: SplitCriterion::Gini,
max_depth: None, max_depth: None,
min_samples_leaf: 1, min_samples_leaf: 1,
min_samples_split: 2, min_samples_split: 2,
n_trees: 100, n_trees: 100,
mtry: Option::None mtry: Option::None,
} }
} }
} }
impl<T: FloatExt> RandomForestClassifier<T> { impl<T: FloatExt> RandomForestClassifier<T> {
pub fn fit<M: Matrix<T>>(
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: RandomForestClassifierParameters) -> RandomForestClassifier<T> { x: &M,
y: &M::RowVector,
parameters: RandomForestClassifierParameters,
) -> RandomForestClassifier<T> {
let (_, num_attributes) = x.shape(); let (_, num_attributes) = x.shape();
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape(); let (_, y_ncols) = y_m.shape();
let mut yi: Vec<usize> = vec![0; y_ncols]; let mut yi: Vec<usize> = vec![0; y_ncols];
let classes = y_m.unique(); let classes = y_m.unique();
for i in 0..y_ncols { for i in 0..y_ncols {
let yc = y_m.get(0, i); let yc = y_m.get(0, i);
yi[i] = classes.iter().position(|c| yc == *c).unwrap(); yi[i] = classes.iter().position(|c| yc == *c).unwrap();
} }
let mtry = parameters.mtry.unwrap_or((T::from(num_attributes).unwrap()).sqrt().floor().to_usize().unwrap());
let classes = y_m.unique(); let mtry = parameters.mtry.unwrap_or(
let k = classes.len(); (T::from(num_attributes).unwrap())
.sqrt()
.floor()
.to_usize()
.unwrap(),
);
let classes = y_m.unique();
let k = classes.len();
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new(); let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
for _ in 0..parameters.n_trees { for _ in 0..parameters.n_trees {
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k); let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k);
let params = DecisionTreeClassifierParameters{ let params = DecisionTreeClassifierParameters {
criterion: parameters.criterion.clone(), criterion: parameters.criterion.clone(),
max_depth: parameters.max_depth, max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf, min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split min_samples_split: parameters.min_samples_split,
}; };
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params); let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params);
trees.push(tree); trees.push(tree);
@@ -96,13 +106,13 @@ impl<T: FloatExt> RandomForestClassifier<T> {
RandomForestClassifier { RandomForestClassifier {
parameters: parameters, parameters: parameters,
trees: trees, trees: trees,
classes classes,
} }
} }
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let mut result = M::zeros(1, x.shape().0); let mut result = M::zeros(1, x.shape().0);
let (n, _) = x.shape(); let (n, _) = x.shape();
for i in 0..n { for i in 0..n {
@@ -110,20 +120,19 @@ impl<T: FloatExt> RandomForestClassifier<T> {
} }
result.to_row_vector() result.to_row_vector()
} }
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize { fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
let mut result = vec![0; self.classes.len()]; let mut result = vec![0; self.classes.len()];
for tree in self.trees.iter() { for tree in self.trees.iter() {
result[tree.predict_for_row(x, row)] += 1; result[tree.predict_for_row(x, row)] += 1;
} }
return which_max(&result) return which_max(&result);
}
}
fn sample_with_replacement(y: &Vec<usize>, num_classes: usize) -> Vec<usize> {
fn sample_with_replacement(y: &Vec<usize>, num_classes: usize) -> Vec<usize>{
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let class_weight = vec![1.; num_classes]; let class_weight = vec![1.; num_classes];
let nrows = y.len(); let nrows = y.len();
@@ -137,8 +146,8 @@ impl<T: FloatExt> RandomForestClassifier<T> {
nj += 1; nj += 1;
} }
} }
let size = ((nj as f64) / class_weight[l]) as usize; let size = ((nj as f64) / class_weight[l]) as usize;
for _ in 0..size { for _ in 0..size {
let xi: usize = rng.gen_range(0, nj); let xi: usize = rng.gen_range(0, nj);
samples[cj[xi]] += 1; samples[cj[xi]] += 1;
@@ -146,17 +155,15 @@ impl<T: FloatExt> RandomForestClassifier<T> {
} }
samples samples
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test] #[test]
fn fit_predict_iris() { fn fit_predict_iris() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2],
@@ -177,24 +184,30 @@ mod tests {
&[6.3, 3.3, 4.7, 1.6], &[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0], &[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3], &[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4]]); &[5.2, 2.7, 3.9, 1.4],
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]; ]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let classifier = RandomForestClassifier::fit(&x, &y, RandomForestClassifierParameters{ let classifier = RandomForestClassifier::fit(
criterion: SplitCriterion::Gini, &x,
max_depth: None, &y,
min_samples_leaf: 1, RandomForestClassifierParameters {
min_samples_split: 2, criterion: SplitCriterion::Gini,
n_trees: 1000, max_depth: None,
mtry: Option::None min_samples_leaf: 1,
}); min_samples_split: 2,
n_trees: 1000,
mtry: Option::None,
},
);
assert_eq!(y, classifier.predict(&x)); assert_eq!(y, classifier.predict(&x));
} }
#[test] #[test]
fn serde() { fn serde() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2],
@@ -215,15 +228,17 @@ mod tests {
&[6.3, 3.3, 4.7, 1.6], &[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0], &[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3], &[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4]]); &[5.2, 2.7, 3.9, 1.4],
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]; ]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let forest = RandomForestClassifier::fit(&x, &y, Default::default()); let forest = RandomForestClassifier::fit(&x, &y, Default::default());
let deserialized_forest: RandomForestClassifier<f64> = bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap(); let deserialized_forest: RandomForestClassifier<f64> =
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
assert_eq!(forest, deserialized_forest); assert_eq!(forest, deserialized_forest);
} }
}
}
+138 -111
View File
@@ -4,47 +4,49 @@ use std::default::Default;
use std::fmt::Debug; use std::fmt::Debug;
use rand::Rng; use rand::Rng;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::tree::decision_tree_regressor::{DecisionTreeRegressor, DecisionTreeRegressorParameters}; use crate::math::num::FloatExt;
use crate::tree::decision_tree_regressor::{
DecisionTreeRegressor, DecisionTreeRegressorParameters,
};
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct RandomForestRegressorParameters { pub struct RandomForestRegressorParameters {
pub max_depth: Option<u16>, pub max_depth: Option<u16>,
pub min_samples_leaf: usize, pub min_samples_leaf: usize,
pub min_samples_split: usize, pub min_samples_split: usize,
pub n_trees: usize, pub n_trees: usize,
pub mtry: Option<usize> pub mtry: Option<usize>,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct RandomForestRegressor<T: FloatExt> { pub struct RandomForestRegressor<T: FloatExt> {
parameters: RandomForestRegressorParameters, parameters: RandomForestRegressorParameters,
trees: Vec<DecisionTreeRegressor<T>> trees: Vec<DecisionTreeRegressor<T>>,
} }
impl Default for RandomForestRegressorParameters { impl Default for RandomForestRegressorParameters {
fn default() -> Self { fn default() -> Self {
RandomForestRegressorParameters { RandomForestRegressorParameters {
max_depth: None, max_depth: None,
min_samples_leaf: 1, min_samples_leaf: 1,
min_samples_split: 2, min_samples_split: 2,
n_trees: 10, n_trees: 10,
mtry: Option::None mtry: Option::None,
} }
} }
} }
impl<T: FloatExt> PartialEq for RandomForestRegressor<T> { impl<T: FloatExt> PartialEq for RandomForestRegressor<T> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.trees.len() != other.trees.len() { if self.trees.len() != other.trees.len() {
return false return false;
} else { } else {
for i in 0..self.trees.len() { for i in 0..self.trees.len() {
if self.trees[i] != other.trees[i] { if self.trees[i] != other.trees[i] {
return false return false;
} }
} }
true true
@@ -53,20 +55,25 @@ impl<T: FloatExt> PartialEq for RandomForestRegressor<T> {
} }
impl<T: FloatExt> RandomForestRegressor<T> { impl<T: FloatExt> RandomForestRegressor<T> {
pub fn fit<M: Matrix<T>>(
x: &M,
y: &M::RowVector,
parameters: RandomForestRegressorParameters,
) -> RandomForestRegressor<T> {
let (n_rows, num_attributes) = x.shape();
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: RandomForestRegressorParameters) -> RandomForestRegressor<T> { let mtry = parameters
let (n_rows, num_attributes) = x.shape(); .mtry
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
let mtry = parameters.mtry.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new(); let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
for _ in 0..parameters.n_trees { for _ in 0..parameters.n_trees {
let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows); let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows);
let params = DecisionTreeRegressorParameters{ let params = DecisionTreeRegressorParameters {
max_depth: parameters.max_depth, max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf, min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split min_samples_split: parameters.min_samples_split,
}; };
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params); let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params);
trees.push(tree); trees.push(tree);
@@ -74,13 +81,13 @@ impl<T: FloatExt> RandomForestRegressor<T> {
RandomForestRegressor { RandomForestRegressor {
parameters: parameters, parameters: parameters,
trees: trees trees: trees,
} }
} }
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let mut result = M::zeros(1, x.shape().0); let mut result = M::zeros(1, x.shape().0);
let (n, _) = x.shape(); let (n, _) = x.shape();
for i in 0..n { for i in 0..n {
@@ -88,23 +95,21 @@ impl<T: FloatExt> RandomForestRegressor<T> {
} }
result.to_row_vector() result.to_row_vector()
} }
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T { fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
let n_trees = self.trees.len(); let n_trees = self.trees.len();
let mut result = T::zero(); let mut result = T::zero();
for tree in self.trees.iter() { for tree in self.trees.iter() {
result = result + tree.predict_for_row(x, row); result = result + tree.predict_for_row(x, row);
} }
result / T::from(n_trees).unwrap() result / T::from(n_trees).unwrap()
}
}
fn sample_with_replacement(nrows: usize) -> Vec<usize> {
fn sample_with_replacement(nrows: usize) -> Vec<usize>{
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let mut samples = vec![0; nrows]; let mut samples = vec![0; nrows];
for _ in 0..nrows { for _ in 0..nrows {
@@ -113,116 +118,138 @@ impl<T: FloatExt> RandomForestRegressor<T> {
} }
samples samples
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
use ndarray::{arr1, arr2}; use ndarray::{arr1, arr2};
#[test] #[test]
fn fit_longley() { fn fit_longley() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[ 234.289, 235.6, 159., 107.608, 1947., 60.323], &[234.289, 235.6, 159., 107.608, 1947., 60.323],
&[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122], &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171], &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[ 284.599, 335.1, 165., 110.929, 1950., 61.187], &[284.599, 335.1, 165., 110.929, 1950., 61.187],
&[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221], &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639], &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
&[ 365.385, 187., 354.7, 115.094, 1953., 64.989], &[365.385, 187., 354.7, 115.094, 1953., 64.989],
&[ 363.112, 357.8, 335., 116.219, 1954., 63.761], &[363.112, 357.8, 335., 116.219, 1954., 63.761],
&[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019], &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857], &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
&[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169], &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513], &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
&[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655], &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564], &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331], &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
let y = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]; ]);
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
let expected_y: Vec<f64> = vec![85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.]; let expected_y: Vec<f64> = vec![
85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.,
];
let y_hat = RandomForestRegressor::fit(&x, &y, let y_hat = RandomForestRegressor::fit(
RandomForestRegressorParameters{max_depth: None, &x,
&y,
RandomForestRegressorParameters {
max_depth: None,
min_samples_leaf: 1, min_samples_leaf: 1,
min_samples_split: 2, min_samples_split: 2,
n_trees: 1000, n_trees: 1000,
mtry: Option::None}).predict(&x); mtry: Option::None,
},
)
.predict(&x);
for i in 0..y_hat.len() { for i in 0..y_hat.len() {
assert!((y_hat[i] - expected_y[i]).abs() < 1.0); assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
} }
} }
#[test] #[test]
fn my_fit_longley_ndarray() { fn my_fit_longley_ndarray() {
let x = arr2(&[ let x = arr2(&[
[ 234.289, 235.6, 159., 107.608, 1947., 60.323], [234.289, 235.6, 159., 107.608, 1947., 60.323],
[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122], [259.426, 232.5, 145.6, 108.632, 1948., 61.122],
[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171], [258.054, 368.2, 161.6, 109.773, 1949., 60.171],
[ 284.599, 335.1, 165., 110.929, 1950., 61.187], [284.599, 335.1, 165., 110.929, 1950., 61.187],
[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221], [328.975, 209.9, 309.9, 112.075, 1951., 63.221],
[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639], [346.999, 193.2, 359.4, 113.27, 1952., 63.639],
[ 365.385, 187., 354.7, 115.094, 1953., 64.989], [365.385, 187., 354.7, 115.094, 1953., 64.989],
[ 363.112, 357.8, 335., 116.219, 1954., 63.761], [363.112, 357.8, 335., 116.219, 1954., 63.761],
[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019], [397.469, 290.4, 304.8, 117.388, 1955., 66.019],
[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857], [419.18, 282.2, 285.7, 118.734, 1956., 67.857],
[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169], [442.769, 293.6, 279.8, 120.445, 1957., 68.169],
[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513], [444.546, 468.1, 263.7, 121.95, 1958., 66.513],
[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655], [482.704, 381.3, 255.2, 123.366, 1959., 68.655],
[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564], [502.601, 393.1, 251.4, 125.368, 1960., 69.564],
[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331], [518.173, 480.6, 257.2, 127.852, 1961., 69.331],
[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); [554.894, 400.7, 282.7, 130.081, 1962., 70.551],
let y = arr1(&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]); ]);
let y = arr1(&[
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
]);
let expected_y: Vec<f64> = vec![85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.]; let expected_y: Vec<f64> = vec![
85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.,
];
let y_hat = RandomForestRegressor::fit(&x, &y, let y_hat = RandomForestRegressor::fit(
RandomForestRegressorParameters{max_depth: None, &x,
&y,
RandomForestRegressorParameters {
max_depth: None,
min_samples_leaf: 1, min_samples_leaf: 1,
min_samples_split: 2, min_samples_split: 2,
n_trees: 1000, n_trees: 1000,
mtry: Option::None}).predict(&x); mtry: Option::None,
},
)
.predict(&x);
for i in 0..y_hat.len() { for i in 0..y_hat.len() {
assert!((y_hat[i] - expected_y[i]).abs() < 1.0); assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
} }
} }
#[test] #[test]
fn serde() { fn serde() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[ 234.289, 235.6, 159., 107.608, 1947., 60.323], &[234.289, 235.6, 159., 107.608, 1947., 60.323],
&[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122], &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171], &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[ 284.599, 335.1, 165., 110.929, 1950., 61.187], &[284.599, 335.1, 165., 110.929, 1950., 61.187],
&[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221], &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639], &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
&[ 365.385, 187., 354.7, 115.094, 1953., 64.989], &[365.385, 187., 354.7, 115.094, 1953., 64.989],
&[ 363.112, 357.8, 335., 116.219, 1954., 63.761], &[363.112, 357.8, 335., 116.219, 1954., 63.761],
&[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019], &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857], &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
&[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169], &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513], &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
&[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655], &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564], &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331], &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
let y = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]; ]);
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
let forest = RandomForestRegressor::fit(&x, &y, Default::default()); let forest = RandomForestRegressor::fit(&x, &y, Default::default());
let deserialized_forest: RandomForestRegressor<f64> = bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap(); let deserialized_forest: RandomForestRegressor<f64> =
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
assert_eq!(forest, deserialized_forest); assert_eq!(forest, deserialized_forest);
} }
}
}
+9 -9
View File
@@ -1,12 +1,12 @@
pub mod linear;
pub mod neighbors;
pub mod ensemble;
pub mod tree;
pub mod cluster;
pub mod decomposition;
pub mod linalg;
pub mod math;
pub mod algorithm; pub mod algorithm;
pub mod cluster;
pub mod common; pub mod common;
pub mod decomposition;
pub mod ensemble;
pub mod linalg;
pub mod linear;
pub mod math;
pub mod metrics;
pub mod neighbors;
pub mod optimization; pub mod optimization;
pub mod metrics; pub mod tree;
+110 -108
View File
@@ -1,75 +1,64 @@
#![allow(non_snake_case)] #![allow(non_snake_case)]
use num::complex::Complex;
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
use crate::math::num::FloatExt; use crate::math::num::FloatExt;
use num::complex::Complex;
use std::fmt::Debug; use std::fmt::Debug;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct EVD<T: FloatExt, M: BaseMatrix<T>> { pub struct EVD<T: FloatExt, M: BaseMatrix<T>> {
pub d: Vec<T>, pub d: Vec<T>,
pub e: Vec<T>, pub e: Vec<T>,
pub V: M pub V: M,
} }
impl<T: FloatExt, M: BaseMatrix<T>> EVD<T, M> { impl<T: FloatExt, M: BaseMatrix<T>> EVD<T, M> {
pub fn new(V: M, d: Vec<T>, e: Vec<T>) -> EVD<T, M> { pub fn new(V: M, d: Vec<T>, e: Vec<T>) -> EVD<T, M> {
EVD { EVD { d: d, e: e, V: V }
d: d,
e: e,
V: V
}
} }
} }
pub trait EVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> { pub trait EVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
fn evd(&self, symmetric: bool) -> EVD<T, Self> {
fn evd(&self, symmetric: bool) -> EVD<T, Self>{
self.clone().evd_mut(symmetric) self.clone().evd_mut(symmetric)
} }
fn evd_mut(mut self, symmetric: bool) -> EVD<T, Self>{ fn evd_mut(mut self, symmetric: bool) -> EVD<T, Self> {
let(nrows, ncols) = self.shape(); let (nrows, ncols) = self.shape();
if ncols != nrows { if ncols != nrows {
panic!("Matrix is not square: {} x {}", nrows, ncols); panic!("Matrix is not square: {} x {}", nrows, ncols);
} }
let n = nrows; let n = nrows;
let mut d = vec![T::zero(); n]; let mut d = vec![T::zero(); n];
let mut e = vec![T::zero(); n]; let mut e = vec![T::zero(); n];
let mut V; let mut V;
if symmetric { if symmetric {
V = self; V = self;
// Tridiagonalize. // Tridiagonalize.
tred2(&mut V, &mut d, &mut e); tred2(&mut V, &mut d, &mut e);
// Diagonalize. // Diagonalize.
tql2(&mut V, &mut d, &mut e); tql2(&mut V, &mut d, &mut e);
} else { } else {
let scale = balance(&mut self); let scale = balance(&mut self);
let perm = elmhes(&mut self);
V = Self::eye(n); let perm = elmhes(&mut self);
eltran(&self, &mut V, &perm); V = Self::eye(n);
eltran(&self, &mut V, &perm);
hqr2(&mut self, &mut V, &mut d, &mut e); hqr2(&mut self, &mut V, &mut d, &mut e);
balbak(&mut V, &scale); balbak(&mut V, &scale);
sort(&mut d, &mut e, &mut V); sort(&mut d, &mut e, &mut V);
} }
EVD { EVD { V: V, d: d, e: e }
V: V,
d: d,
e: e
}
} }
} }
fn tred2<T: FloatExt, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) { fn tred2<T: FloatExt, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
let (n, _) = V.shape(); let (n, _) = V.shape();
for i in 0..n { for i in 0..n {
d[i] = V.get(n - 1, i); d[i] = V.get(n - 1, i);
@@ -131,7 +120,7 @@ fn tred2<T: FloatExt, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T
for j in 0..i { for j in 0..i {
f = d[j]; f = d[j];
g = e[j]; g = e[j];
for k in j..=i-1 { for k in j..=i - 1 {
V.sub_element_mut(k, j, f * e[k] + g * d[k]); V.sub_element_mut(k, j, f * e[k] + g * d[k]);
} }
d[j] = V.get(i - 1, j); d[j] = V.get(i - 1, j);
@@ -139,10 +128,10 @@ fn tred2<T: FloatExt, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T
} }
} }
d[i] = h; d[i] = h;
} }
// Accumulate transformations. // Accumulate transformations.
for i in 0..n-1 { for i in 0..n - 1 {
V.set(n - 1, i, V.get(i, i)); V.set(n - 1, i, V.get(i, i));
V.set(i, i, T::one()); V.set(i, i, T::one());
let h = d[i + 1]; let h = d[i + 1];
@@ -156,7 +145,7 @@ fn tred2<T: FloatExt, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T
g = g + V.get(k, i + 1) * V.get(k, j); g = g + V.get(k, i + 1) * V.get(k, j);
} }
for k in 0..=i { for k in 0..=i {
V.sub_element_mut(k, j, g * d[k]); V.sub_element_mut(k, j, g * d[k]);
} }
} }
} }
@@ -193,7 +182,7 @@ fn tql2<T: FloatExt, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>
break; break;
} }
m += 1; m += 1;
} else { } else {
break; break;
} }
} }
@@ -219,7 +208,7 @@ fn tql2<T: FloatExt, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>
d[l + 1] = e[l] * (p + r); d[l + 1] = e[l] * (p + r);
let dl1 = d[l + 1]; let dl1 = d[l + 1];
let mut h = g - d[l]; let mut h = g - d[l];
for i in l+2..n { for i in l + 2..n {
d[i] = d[i] - h; d[i] = d[i] - h;
} }
f = f + h; f = f + h;
@@ -249,7 +238,7 @@ fn tql2<T: FloatExt, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>
for k in 0..n { for k in 0..n {
h = V.get(k, i + 1); h = V.get(k, i + 1);
V.set(k, i + 1, s * V.get(k, i) + c * h); V.set(k, i + 1, s * V.get(k, i) + c * h);
V.set(k, i, c * V.get(k, i) - s * h); V.set(k, i, c * V.get(k, i) - s * h);
} }
} }
p = -s * s2 * c3 * el1 * e[l] / dl1; p = -s * s2 * c3 * el1 * e[l] / dl1;
@@ -267,7 +256,7 @@ fn tql2<T: FloatExt, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>
} }
// Sort eigenvalues and corresponding vectors. // Sort eigenvalues and corresponding vectors.
for i in 0..n-1 { for i in 0..n - 1 {
let mut k = i; let mut k = i;
let mut p = d[i]; let mut p = d[i];
for j in i + 1..n { for j in i + 1..n {
@@ -294,8 +283,8 @@ fn balance<T: FloatExt, M: BaseMatrix<T>>(A: &mut M) -> Vec<T> {
let (n, _) = A.shape(); let (n, _) = A.shape();
let mut scale = vec![T::one(); n]; let mut scale = vec![T::one(); n];
let t = T::from(0.95).unwrap(); let t = T::from(0.95).unwrap();
let mut done = false; let mut done = false;
@@ -345,7 +334,7 @@ fn elmhes<T: FloatExt, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
let (n, _) = A.shape(); let (n, _) = A.shape();
let mut perm = vec![0; n]; let mut perm = vec![0; n];
for m in 1..n-1 { for m in 1..n - 1 {
let mut x = T::zero(); let mut x = T::zero();
let mut i = m; let mut i = m;
for j in m..n { for j in m..n {
@@ -353,10 +342,10 @@ fn elmhes<T: FloatExt, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
x = A.get(j, m - 1); x = A.get(j, m - 1);
i = j; i = j;
} }
} }
perm[m] = i; perm[m] = i;
if i != m { if i != m {
for j in (m-1)..n { for j in (m - 1)..n {
let swap = A.get(i, j); let swap = A.get(i, j);
A.set(i, j, A.get(m, j)); A.set(i, j, A.get(m, j));
A.set(m, j, swap); A.set(m, j, swap);
@@ -366,7 +355,7 @@ fn elmhes<T: FloatExt, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
A.set(j, i, A.get(j, m)); A.set(j, i, A.get(j, m));
A.set(j, m, swap); A.set(j, m, swap);
} }
} }
if x != T::zero() { if x != T::zero() {
for i in (m + 1)..n { for i in (m + 1)..n {
let mut y = A.get(i, m - 1); let mut y = A.get(i, m - 1);
@@ -381,11 +370,11 @@ fn elmhes<T: FloatExt, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
} }
} }
} }
} }
} }
return perm; return perm;
} }
fn eltran<T: FloatExt, M: BaseMatrix<T>>(A: &M, V: &mut M, perm: &Vec<usize>) { fn eltran<T: FloatExt, M: BaseMatrix<T>>(A: &M, V: &mut M, perm: &Vec<usize>) {
let (n, _) = A.shape(); let (n, _) = A.shape();
@@ -405,27 +394,27 @@ fn eltran<T: FloatExt, M: BaseMatrix<T>>(A: &M, V: &mut M, perm: &Vec<usize>) {
} }
fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) { fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
let (n, _) = A.shape(); let (n, _) = A.shape();
let mut z = T::zero(); let mut z = T::zero();
let mut s = T::zero(); let mut s = T::zero();
let mut r = T::zero(); let mut r = T::zero();
let mut q = T::zero(); let mut q = T::zero();
let mut p = T::zero(); let mut p = T::zero();
let mut anorm = T::zero(); let mut anorm = T::zero();
for i in 0..n { for i in 0..n {
for j in i32::max(i as i32 - 1, 0)..n as i32 { for j in i32::max(i as i32 - 1, 0)..n as i32 {
anorm = anorm + A.get(i, j as usize).abs(); anorm = anorm + A.get(i, j as usize).abs();
} }
} }
let mut nn = n - 1; let mut nn = n - 1;
let mut t = T::zero(); let mut t = T::zero();
'outer: loop { 'outer: loop {
let mut its = 0; let mut its = 0;
loop { loop {
let mut l = nn; let mut l = nn;
while l > 0 { while l > 0 {
s = A.get(l - 1, l - 1).abs() + A.get(l, l).abs(); s = A.get(l - 1, l - 1).abs() + A.get(l, l).abs();
if s == T::zero() { if s == T::zero() {
s = anorm; s = anorm;
@@ -433,8 +422,8 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
if A.get(l, l - 1).abs() <= T::epsilon() * s { if A.get(l, l - 1).abs() <= T::epsilon() * s {
A.set(l, l - 1, T::zero()); A.set(l, l - 1, T::zero());
break; break;
} }
l -= 1; l -= 1;
} }
let mut x = A.get(nn, nn); let mut x = A.get(nn, nn);
if l == nn { if l == nn {
@@ -444,7 +433,7 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
break 'outer; break 'outer;
} else { } else {
nn -= 1; nn -= 1;
} }
} else { } else {
let mut y = A.get(nn - 1, nn - 1); let mut y = A.get(nn - 1, nn - 1);
let mut w = A.get(nn, nn - 1) * A.get(nn - 1, nn); let mut w = A.get(nn, nn - 1) * A.get(nn - 1, nn);
@@ -453,7 +442,7 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
q = p * p + w; q = p * p + w;
z = q.abs().sqrt(); z = q.abs().sqrt();
x = x + t; x = x + t;
A.set(nn, nn, x ); A.set(nn, nn, x);
A.set(nn - 1, nn - 1, y + t); A.set(nn - 1, nn - 1, y + t);
if q >= T::zero() { if q >= T::zero() {
z = p + z.copysign(p); z = p + z.copysign(p);
@@ -469,7 +458,7 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
r = (p * p + q * q).sqrt(); r = (p * p + q * q).sqrt();
p = p / r; p = p / r;
q = q / r; q = q / r;
for j in nn-1..n { for j in nn - 1..n {
z = A.get(nn - 1, j); z = A.get(nn - 1, j);
A.set(nn - 1, j, q * z + p * A.get(nn, j)); A.set(nn - 1, j, q * z + p * A.get(nn, j));
A.set(nn, j, q * A.get(nn, j) - p * z); A.set(nn, j, q * A.get(nn, j) - p * z);
@@ -490,19 +479,19 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
d[nn - 1] = d[nn]; d[nn - 1] = d[nn];
e[nn - 1] = -e[nn]; e[nn - 1] = -e[nn];
} }
if nn <= 1 { if nn <= 1 {
break 'outer; break 'outer;
} else { } else {
nn -= 2; nn -= 2;
} }
} else { } else {
if its == 30 { if its == 30 {
panic!("Too many iterations in hqr"); panic!("Too many iterations in hqr");
} }
if its == 10 || its == 20 { if its == 10 || its == 20 {
t = t + x; t = t + x;
for i in 0..nn+1 { for i in 0..nn + 1 {
A.sub_element_mut(i, i, x); A.sub_element_mut(i, i, x);
} }
s = A.get(nn, nn - 1).abs() + A.get(nn - 1, nn - 2).abs(); s = A.get(nn, nn - 1).abs() + A.get(nn - 1, nn - 2).abs();
@@ -527,14 +516,15 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
break; break;
} }
let u = A.get(m, m - 1).abs() * (q.abs() + r.abs()); let u = A.get(m, m - 1).abs() * (q.abs() + r.abs());
let v = p.abs() * (A.get(m - 1, m - 1).abs() + z.abs() + A.get(m + 1, m + 1).abs()); let v = p.abs()
* (A.get(m - 1, m - 1).abs() + z.abs() + A.get(m + 1, m + 1).abs());
if u <= T::epsilon() * v { if u <= T::epsilon() * v {
break; break;
} }
m -= 1; m -= 1;
} }
for i in m..nn-1 { for i in m..nn - 1 {
A.set(i + 2, i , T::zero()); A.set(i + 2, i, T::zero());
if i != m { if i != m {
A.set(i + 2, i - 1, T::zero()); A.set(i + 2, i - 1, T::zero());
} }
@@ -547,7 +537,7 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
if k + 1 != nn { if k + 1 != nn {
r = A.get(k + 2, k - 1); r = A.get(k + 2, k - 1);
} }
x = p.abs() + q.abs() +r.abs(); x = p.abs() + q.abs() + r.abs();
if x != T::zero() { if x != T::zero() {
p = p / x; p = p / x;
q = q / x; q = q / x;
@@ -583,8 +573,8 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
mmin = nn; mmin = nn;
} else { } else {
mmin = k + 3; mmin = k + 3;
} }
for i in 0..mmin+1 { for i in 0..mmin + 1 {
p = x * A.get(i, k) + y * A.get(i, k + 1); p = x * A.get(i, k) + y * A.get(i, k + 1);
if k + 1 != nn { if k + 1 != nn {
p = p + z * A.get(i, k + 2); p = p + z * A.get(i, k + 2);
@@ -609,7 +599,7 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
if l + 1 >= nn { if l + 1 >= nn {
break; break;
} }
}; }
} }
if anorm != T::zero() { if anorm != T::zero() {
@@ -659,7 +649,7 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
} }
} }
} }
if i == 0{ if i == 0 {
break; break;
} else { } else {
i -= 1; i -= 1;
@@ -672,13 +662,14 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
A.set(na, na, q / A.get(nn, na)); A.set(na, na, q / A.get(nn, na));
A.set(na, nn, -(A.get(nn, nn) - p) / A.get(nn, na)); A.set(na, nn, -(A.get(nn, nn) - p) / A.get(nn, na));
} else { } else {
let temp = Complex::new(T::zero(), -A.get(na, nn)) / Complex::new(A.get(na, na) - p, q); let temp = Complex::new(T::zero(), -A.get(na, nn))
/ Complex::new(A.get(na, na) - p, q);
A.set(na, na, temp.re); A.set(na, na, temp.re);
A.set(na, nn, temp.im); A.set(na, nn, temp.im);
} }
A.set(nn, na, T::zero()); A.set(nn, na, T::zero());
A.set(nn, nn, T::one()); A.set(nn, nn, T::one());
if nn >= 2 { if nn >= 2 {
for i in (0..nn - 1).rev() { for i in (0..nn - 1).rev() {
let w = A.get(i, i) - p; let w = A.get(i, i) - p;
let mut ra = T::zero(); let mut ra = T::zero();
@@ -694,25 +685,40 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
} else { } else {
m = i; m = i;
if e[i] == T::zero() { if e[i] == T::zero() {
let temp = Complex::new(-ra, -sa) / Complex::new(w, q); let temp = Complex::new(-ra, -sa) / Complex::new(w, q);
A.set(i, na, temp.re); A.set(i, na, temp.re);
A.set(i, nn, temp.im); A.set(i, nn, temp.im);
} else { } else {
let x = A.get(i, i + 1); let x = A.get(i, i + 1);
let y = A.get(i + 1, i); let y = A.get(i + 1, i);
let mut vr = (d[i] - p).powf(T::two()) + (e[i]).powf(T::two()) - q * q; let mut vr =
(d[i] - p).powf(T::two()) + (e[i]).powf(T::two()) - q * q;
let vi = T::two() * q * (d[i] - p); let vi = T::two() * q * (d[i] - p);
if vr == T::zero() && vi == T::zero() { if vr == T::zero() && vi == T::zero() {
vr = T::epsilon() * anorm * (w.abs() + q.abs() + x.abs() + y.abs() + z.abs()); vr = T::epsilon()
* anorm
* (w.abs() + q.abs() + x.abs() + y.abs() + z.abs());
} }
let temp = Complex::new(x * r - z * ra + q * sa, x * s - z * sa - q * ra) / Complex::new(vr, vi); let temp =
Complex::new(x * r - z * ra + q * sa, x * s - z * sa - q * ra)
/ Complex::new(vr, vi);
A.set(i, na, temp.re); A.set(i, na, temp.re);
A.set(i, nn, temp.im); A.set(i, nn, temp.im);
if x.abs() > z.abs() + q.abs() { if x.abs() > z.abs() + q.abs() {
A.set(i + 1, na, (-ra - w * A.get(i, na) + q * A.get(i, nn)) / x); A.set(
A.set(i + 1, nn, (-sa - w * A.get(i, nn) - q * A.get(i, na)) / x); i + 1,
na,
(-ra - w * A.get(i, na) + q * A.get(i, nn)) / x,
);
A.set(
i + 1,
nn,
(-sa - w * A.get(i, nn) - q * A.get(i, na)) / x,
);
} else { } else {
let temp = Complex::new(-r - y * A.get(i, na), -s - y * A.get(i, nn)) / Complex::new(z, q); let temp =
Complex::new(-r - y * A.get(i, na), -s - y * A.get(i, nn))
/ Complex::new(z, q);
A.set(i + 1, na, temp.re); A.set(i + 1, na, temp.re);
A.set(i + 1, nn, temp.im); A.set(i + 1, nn, temp.im);
} }
@@ -728,7 +734,7 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
} }
} }
} }
} }
for j in (0..n).rev() { for j in (0..n).rev() {
for i in 0..n { for i in 0..n {
@@ -739,7 +745,7 @@ fn hqr2<T: FloatExt, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e:
V.set(i, j, z); V.set(i, j, z);
} }
} }
} }
} }
fn balbak<T: FloatExt, M: BaseMatrix<T>>(V: &mut M, scale: &Vec<T>) { fn balbak<T: FloatExt, M: BaseMatrix<T>>(V: &mut M, scale: &Vec<T>) {
@@ -751,7 +757,7 @@ fn balbak<T: FloatExt, M: BaseMatrix<T>>(V: &mut M, scale: &Vec<T>) {
} }
} }
fn sort<T: FloatExt, M: BaseMatrix<T>>(d: &mut Vec<T>, e: &mut Vec<T>, V: &mut M) { fn sort<T: FloatExt, M: BaseMatrix<T>>(d: &mut Vec<T>, e: &mut Vec<T>, V: &mut M) {
let n = d.len(); let n = d.len();
let mut temp = vec![T::zero(); n]; let mut temp = vec![T::zero(); n];
for j in 1..n { for j in 1..n {
@@ -781,74 +787,72 @@ fn sort<T: FloatExt, M: BaseMatrix<T>>(d: &mut Vec<T>, e: &mut Vec<T>, V: &mut M
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test] #[test]
fn decompose_symmetric() { fn decompose_symmetric() {
let A = DenseMatrix::from_array(&[ let A = DenseMatrix::from_array(&[
&[0.9000, 0.4000, 0.7000], &[0.9000, 0.4000, 0.7000],
&[0.4000, 0.5000, 0.3000], &[0.4000, 0.5000, 0.3000],
&[0.7000, 0.3000, 0.8000]]); &[0.7000, 0.3000, 0.8000],
]);
let eigen_values: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834]; let eigen_values: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
let eigen_vectors = DenseMatrix::from_array(&[ let eigen_vectors = DenseMatrix::from_array(&[
&[0.6881997, -0.07121225, 0.7220180], &[0.6881997, -0.07121225, 0.7220180],
&[0.3700456, 0.89044952, -0.2648886], &[0.3700456, 0.89044952, -0.2648886],
&[0.6240573, -0.44947578, -0.6391588] &[0.6240573, -0.44947578, -0.6391588],
]); ]);
let evd = A.evd(true); let evd = A.evd(true);
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4)); assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values.len() { for i in 0..eigen_values.len() {
assert!((eigen_values[i] - evd.d[i]).abs() < 1e-4); assert!((eigen_values[i] - evd.d[i]).abs() < 1e-4);
} }
for i in 0..eigen_values.len() { for i in 0..eigen_values.len() {
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON); assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
} }
} }
#[test] #[test]
fn decompose_asymmetric() { fn decompose_asymmetric() {
let A = DenseMatrix::from_array(&[ let A = DenseMatrix::from_array(&[
&[0.9000, 0.4000, 0.7000], &[0.9000, 0.4000, 0.7000],
&[0.4000, 0.5000, 0.3000], &[0.4000, 0.5000, 0.3000],
&[0.8000, 0.3000, 0.8000]]); &[0.8000, 0.3000, 0.8000],
]);
let eigen_values: Vec<f64> = vec![1.79171122, 0.31908143, 0.08920735]; let eigen_values: Vec<f64> = vec![1.79171122, 0.31908143, 0.08920735];
let eigen_vectors = DenseMatrix::from_array(&[ let eigen_vectors = DenseMatrix::from_array(&[
&[0.7178958, 0.05322098, 0.6812010], &[0.7178958, 0.05322098, 0.6812010],
&[0.3837711, -0.84702111, -0.1494582], &[0.3837711, -0.84702111, -0.1494582],
&[0.6952105, 0.43984484, -0.7036135] &[0.6952105, 0.43984484, -0.7036135],
]); ]);
let evd = A.evd(false); let evd = A.evd(false);
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4)); assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values.len() { for i in 0..eigen_values.len() {
assert!((eigen_values[i] - evd.d[i]).abs() < 1e-4); assert!((eigen_values[i] - evd.d[i]).abs() < 1e-4);
} }
for i in 0..eigen_values.len() { for i in 0..eigen_values.len() {
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON); assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
} }
} }
#[test] #[test]
fn decompose_complex() { fn decompose_complex() {
let A = DenseMatrix::from_array(&[ let A = DenseMatrix::from_array(&[
&[3.0, -2.0, 1.0, 1.0], &[3.0, -2.0, 1.0, 1.0],
&[4.0, -1.0, 1.0, 1.0], &[4.0, -1.0, 1.0, 1.0],
&[1.0, 1.0, 3.0, -2.0], &[1.0, 1.0, 3.0, -2.0],
&[1.0, 1.0, 4.0, -1.0]]); &[1.0, 1.0, 4.0, -1.0],
]);
let eigen_values_d: Vec<f64> = vec![0.0, 2.0, 2.0, 0.0]; let eigen_values_d: Vec<f64> = vec![0.0, 2.0, 2.0, 0.0];
let eigen_values_e: Vec<f64> = vec![2.2361, 0.9999, -0.9999, -2.2361]; let eigen_values_e: Vec<f64> = vec![2.2361, 0.9999, -0.9999, -2.2361];
@@ -857,19 +861,17 @@ mod tests {
&[-0.9159, -0.1378, 0.3816, -0.0806], &[-0.9159, -0.1378, 0.3816, -0.0806],
&[-0.6707, 0.1059, 0.901, 0.6289], &[-0.6707, 0.1059, 0.901, 0.6289],
&[0.9159, -0.1378, 0.3816, 0.0806], &[0.9159, -0.1378, 0.3816, 0.0806],
&[0.6707, 0.1059, 0.901, -0.6289] &[0.6707, 0.1059, 0.901, -0.6289],
]); ]);
let evd = A.evd(false); let evd = A.evd(false);
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4)); assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values_d.len() { for i in 0..eigen_values_d.len() {
assert!((eigen_values_d[i] - evd.d[i]).abs() < 1e-4); assert!((eigen_values_d[i] - evd.d[i]).abs() < 1e-4);
} }
for i in 0..eigen_values_e.len() { for i in 0..eigen_values_e.len() {
assert!((eigen_values_e[i] - evd.e[i]).abs() < 1e-4); assert!((eigen_values_e[i] - evd.e[i]).abs() < 1e-4);
} }
} }
}
}
+46 -52
View File
@@ -3,22 +3,21 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
use crate::math::num::FloatExt;
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
use crate::math::num::FloatExt;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct LU<T: FloatExt, M: BaseMatrix<T>> { pub struct LU<T: FloatExt, M: BaseMatrix<T>> {
LU: M, LU: M,
pivot: Vec<usize>, pivot: Vec<usize>,
pivot_sign: i8, pivot_sign: i8,
singular: bool, singular: bool,
phantom: PhantomData<T> phantom: PhantomData<T>,
} }
impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> { impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
pub fn new(LU: M, pivot: Vec<usize>, pivot_sign: i8) -> LU<T, M> { pub fn new(LU: M, pivot: Vec<usize>, pivot_sign: i8) -> LU<T, M> {
let (_, n) = LU.shape();
let (_, n) = LU.shape();
let mut singular = false; let mut singular = false;
for j in 0..n { for j in 0..n {
@@ -33,7 +32,7 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
pivot: pivot, pivot: pivot,
pivot_sign: pivot_sign, pivot_sign: pivot_sign,
singular: singular, singular: singular,
phantom: PhantomData phantom: PhantomData,
} }
} }
@@ -63,24 +62,24 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
for i in 0..n_rows { for i in 0..n_rows {
for j in 0..n_cols { for j in 0..n_cols {
if i <= j { if i <= j {
U.set(i, j, self.LU.get(i, j)); U.set(i, j, self.LU.get(i, j));
} else { } else {
U.set(i, j, T::zero()); U.set(i, j, T::zero());
} }
} }
} }
U U
} }
pub fn pivot(&self) -> M { pub fn pivot(&self) -> M {
let (_, n) = self.LU.shape(); let (_, n) = self.LU.shape();
let mut piv = M::zeros(n, n); let mut piv = M::zeros(n, n);
for i in 0..n { for i in 0..n {
piv.set(i, self.pivot[i], T::one()); piv.set(i, self.pivot[i], T::one());
} }
piv piv
} }
@@ -92,7 +91,7 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
} }
let mut inv = M::zeros(n, n); let mut inv = M::zeros(n, n);
for i in 0..n { for i in 0..n {
inv.set(i, i, T::one()); inv.set(i, i, T::one());
} }
@@ -106,7 +105,10 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
let (b_m, b_n) = b.shape(); let (b_m, b_n) = b.shape();
if b_m != m { if b_m != m {
panic!("Row dimensions do not agree: A is {} x {}, but B is {} x {}", m, n, b_m, b_n); panic!(
"Row dimensions do not agree: A is {} x {}, but B is {} x {}",
m, n, b_m, b_n
);
} }
if self.singular { if self.singular {
@@ -120,9 +122,9 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
X.set(i, j, b.get(self.pivot[i], j)); X.set(i, j, b.get(self.pivot[i], j));
} }
} }
for k in 0..n { for k in 0..n {
for i in k+1..n { for i in k + 1..n {
for j in 0..b_n { for j in 0..b_n {
X.sub_element_mut(i, j, X.get(k, j) * self.LU.get(i, k)); X.sub_element_mut(i, j, X.get(k, j) * self.LU.get(i, k));
} }
@@ -140,7 +142,7 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
} }
} }
} }
for j in 0..b_n { for j in 0..b_n {
for i in 0..m { for i in 0..m {
b.set(i, j, X.get(i, j)); b.set(i, j, X.get(i, j));
@@ -148,20 +150,16 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
} }
b b
} }
} }
pub trait LUDecomposableMatrix<T: FloatExt>: BaseMatrix<T> { pub trait LUDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
fn lu(&self) -> LU<T, Self> { fn lu(&self) -> LU<T, Self> {
self.clone().lu_mut() self.clone().lu_mut()
} }
fn lu_mut(mut self) -> LU<T, Self> { fn lu_mut(mut self) -> LU<T, Self> {
let (m, n) = self.shape();
let (m, n) = self.shape();
let mut piv = vec![0; m]; let mut piv = vec![0; m];
for i in 0..m { for i in 0..m {
@@ -172,7 +170,6 @@ pub trait LUDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
let mut LUcolj = vec![T::zero(); m]; let mut LUcolj = vec![T::zero(); m];
for j in 0..n { for j in 0..n {
for i in 0..m { for i in 0..m {
LUcolj[i] = self.get(i, j); LUcolj[i] = self.get(i, j);
} }
@@ -189,7 +186,7 @@ pub trait LUDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
} }
let mut p = j; let mut p = j;
for i in j+1..m { for i in j + 1..m {
if LUcolj[i].abs() > LUcolj[p].abs() { if LUcolj[i].abs() > LUcolj[p].abs() {
p = i; p = i;
} }
@@ -205,50 +202,47 @@ pub trait LUDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
piv[j] = k; piv[j] = k;
pivsign = -pivsign; pivsign = -pivsign;
} }
if j < m && self.get(j, j) != T::zero() { if j < m && self.get(j, j) != T::zero() {
for i in j+1..m { for i in j + 1..m {
self.div_element_mut(i, j, self.get(j, j)); self.div_element_mut(i, j, self.get(j, j));
} }
} }
} }
LU::new(self, piv, pivsign) LU::new(self, piv, pivsign)
} }
fn lu_solve_mut(self, b: Self) -> Self { fn lu_solve_mut(self, b: Self) -> Self {
self.lu_mut().solve(b)
self.lu_mut().solve(b) }
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
#[test] #[test]
fn decompose() { fn decompose() {
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]); let expected_L = DenseMatrix::from_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0.2, 0.8, 1.]]);
let expected_L = DenseMatrix::from_array(&[&[1. , 0. , 0. ], &[0. , 1. , 0. ], &[0.2, 0.8, 1. ]]); let expected_U = DenseMatrix::from_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
let expected_U = DenseMatrix::from_array(&[&[ 5., 6., 0.], &[ 0., 1., 5.], &[ 0., 0., -1.]]); let expected_pivot =
let expected_pivot = DenseMatrix::from_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]); DenseMatrix::from_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
let lu = a.lu(); let lu = a.lu();
assert!(lu.L().approximate_eq(&expected_L, 1e-4)); assert!(lu.L().approximate_eq(&expected_L, 1e-4));
assert!(lu.U().approximate_eq(&expected_U, 1e-4)); assert!(lu.U().approximate_eq(&expected_U, 1e-4));
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4)); assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
} }
#[test] #[test]
fn inverse() { fn inverse() {
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]); let expected =
let expected = DenseMatrix::from_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]); DenseMatrix::from_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
let a_inv = a.lu().inverse(); let a_inv = a.lu().inverse();
println!("{}", a_inv); println!("{}", a_inv);
assert!(a_inv.approximate_eq(&expected, 1e-4)); assert!(a_inv.approximate_eq(&expected, 1e-4));
} }
} }
+46 -42
View File
@@ -1,45 +1,43 @@
pub mod naive;
pub mod qr;
pub mod svd;
pub mod evd; pub mod evd;
pub mod lu; pub mod lu;
pub mod ndarray_bindings; pub mod naive;
pub mod nalgebra_bindings; pub mod nalgebra_bindings;
pub mod ndarray_bindings;
pub mod qr;
pub mod svd;
use std::ops::Range;
use std::fmt::{Debug, Display}; use std::fmt::{Debug, Display};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::ops::Range;
use crate::math::num::FloatExt; use crate::math::num::FloatExt;
use svd::SVDDecomposableMatrix;
use evd::EVDDecomposableMatrix; use evd::EVDDecomposableMatrix;
use qr::QRDecomposableMatrix;
use lu::LUDecomposableMatrix; use lu::LUDecomposableMatrix;
use qr::QRDecomposableMatrix;
use svd::SVDDecomposableMatrix;
pub trait BaseVector<T: FloatExt>: Clone + Debug { pub trait BaseVector<T: FloatExt>: Clone + Debug {
fn get(&self, i: usize) -> T;
fn get(&self, i: usize) -> T;
fn set(&mut self, i: usize, x: T); fn set(&mut self, i: usize, x: T);
fn len(&self) -> usize; fn len(&self) -> usize;
} }
pub trait BaseMatrix<T: FloatExt>: Clone + Debug { pub trait BaseMatrix<T: FloatExt>: Clone + Debug {
type RowVector: BaseVector<T> + Clone + Debug;
type RowVector: BaseVector<T> + Clone + Debug;
fn from_row_vector(vec: Self::RowVector) -> Self; fn from_row_vector(vec: Self::RowVector) -> Self;
fn to_row_vector(self) -> Self::RowVector; fn to_row_vector(self) -> Self::RowVector;
fn get(&self, row: usize, col: usize) -> T; fn get(&self, row: usize, col: usize) -> T;
fn get_row_as_vec(&self, row: usize) -> Vec<T>; fn get_row_as_vec(&self, row: usize) -> Vec<T>;
fn get_col_as_vec(&self, col: usize) -> Vec<T>; fn get_col_as_vec(&self, col: usize) -> Vec<T>;
fn set(&mut self, row: usize, col: usize, x: T); fn set(&mut self, row: usize, col: usize, x: T);
fn eye(size: usize) -> Self; fn eye(size: usize) -> Self;
@@ -51,17 +49,17 @@ pub trait BaseMatrix<T: FloatExt>: Clone + Debug {
fn fill(nrows: usize, ncols: usize, value: T) -> Self; fn fill(nrows: usize, ncols: usize, value: T) -> Self;
fn shape(&self) -> (usize, usize); fn shape(&self) -> (usize, usize);
fn v_stack(&self, other: &Self) -> Self; fn v_stack(&self, other: &Self) -> Self;
fn h_stack(&self, other: &Self) -> Self; fn h_stack(&self, other: &Self) -> Self;
fn dot(&self, other: &Self) -> Self; fn dot(&self, other: &Self) -> Self;
fn vector_dot(&self, other: &Self) -> T; fn vector_dot(&self, other: &Self) -> T;
fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self; fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self;
fn approximate_eq(&self, other: &Self, error: T) -> bool; fn approximate_eq(&self, other: &Self, error: T) -> bool;
@@ -113,37 +111,37 @@ pub trait BaseMatrix<T: FloatExt>: Clone + Debug {
fn div_scalar_mut(&mut self, scalar: T) -> &Self; fn div_scalar_mut(&mut self, scalar: T) -> &Self;
fn add_scalar(&self, scalar: T) -> Self{ fn add_scalar(&self, scalar: T) -> Self {
let mut r = self.clone(); let mut r = self.clone();
r.add_scalar_mut(scalar); r.add_scalar_mut(scalar);
r r
} }
fn sub_scalar(&self, scalar: T) -> Self{ fn sub_scalar(&self, scalar: T) -> Self {
let mut r = self.clone(); let mut r = self.clone();
r.sub_scalar_mut(scalar); r.sub_scalar_mut(scalar);
r r
} }
fn mul_scalar(&self, scalar: T) -> Self{ fn mul_scalar(&self, scalar: T) -> Self {
let mut r = self.clone(); let mut r = self.clone();
r.mul_scalar_mut(scalar); r.mul_scalar_mut(scalar);
r r
} }
fn div_scalar(&self, scalar: T) -> Self{ fn div_scalar(&self, scalar: T) -> Self {
let mut r = self.clone(); let mut r = self.clone();
r.div_scalar_mut(scalar); r.div_scalar_mut(scalar);
r r
} }
fn transpose(&self) -> Self; fn transpose(&self) -> Self;
fn rand(nrows: usize, ncols: usize) -> Self; fn rand(nrows: usize, ncols: usize) -> Self;
fn norm2(&self) -> T; fn norm2(&self) -> T;
fn norm(&self, p:T) -> T; fn norm(&self, p: T) -> T;
fn column_mean(&self) -> Vec<T>; fn column_mean(&self) -> Vec<T>;
@@ -153,7 +151,7 @@ pub trait BaseMatrix<T: FloatExt>: Clone + Debug {
let mut result = self.clone(); let mut result = self.clone();
result.negative_mut(); result.negative_mut();
result result
} }
fn reshape(&self, nrows: usize, ncols: usize) -> Self; fn reshape(&self, nrows: usize, ncols: usize) -> Self;
@@ -169,9 +167,9 @@ pub trait BaseMatrix<T: FloatExt>: Clone + Debug {
fn sum(&self) -> T; fn sum(&self) -> T;
fn max_diff(&self, other: &Self) -> T; fn max_diff(&self, other: &Self) -> T;
fn softmax_mut(&mut self); fn softmax_mut(&mut self);
fn pow_mut(&mut self, p: T) -> &Self; fn pow_mut(&mut self, p: T) -> &Self;
@@ -181,22 +179,30 @@ pub trait BaseMatrix<T: FloatExt>: Clone + Debug {
result result
} }
fn argmax(&self) -> Vec<usize>; fn argmax(&self) -> Vec<usize>;
fn unique(&self) -> Vec<T>;
fn cov(&self) -> Self;
fn unique(&self) -> Vec<T>;
fn cov(&self) -> Self;
} }
pub trait Matrix<T: FloatExt>: BaseMatrix<T> + SVDDecomposableMatrix<T> + EVDDecomposableMatrix<T> + QRDecomposableMatrix<T> + LUDecomposableMatrix<T> + PartialEq + Display {} pub trait Matrix<T: FloatExt>:
BaseMatrix<T>
+ SVDDecomposableMatrix<T>
+ EVDDecomposableMatrix<T>
+ QRDecomposableMatrix<T>
+ LUDecomposableMatrix<T>
+ PartialEq
+ Display
{
}
pub fn row_iter<F: FloatExt, M: BaseMatrix<F>>(m: &M) -> RowIter<F, M> { pub fn row_iter<F: FloatExt, M: BaseMatrix<F>>(m: &M) -> RowIter<F, M> {
RowIter{ RowIter {
m: m, m: m,
pos: 0, pos: 0,
max_pos: m.shape().0, max_pos: m.shape().0,
phantom: PhantomData phantom: PhantomData,
} }
} }
@@ -204,11 +210,10 @@ pub struct RowIter<'a, T: FloatExt, M: BaseMatrix<T>> {
m: &'a M, m: &'a M,
pos: usize, pos: usize,
max_pos: usize, max_pos: usize,
phantom: PhantomData<&'a T> phantom: PhantomData<&'a T>,
} }
impl<'a, T: FloatExt, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> { impl<'a, T: FloatExt, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
type Item = Vec<T>; type Item = Vec<T>;
fn next(&mut self) -> Option<Vec<T>> { fn next(&mut self) -> Option<Vec<T>> {
@@ -221,5 +226,4 @@ impl<'a, T: FloatExt, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
self.pos += 1; self.pos += 1;
res res
} }
}
}
+245 -259
View File
@@ -1,26 +1,26 @@
extern crate num; extern crate num;
use std::ops::Range;
use std::fmt; use std::fmt;
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::ops::Range;
use serde::{Serialize, Deserialize}; use serde::de::{Deserializer, MapAccess, SeqAccess, Visitor};
use serde::ser::{Serializer, SerializeStruct}; use serde::ser::{SerializeStruct, Serializer};
use serde::de::{Deserializer, Visitor, SeqAccess, MapAccess}; use serde::{Deserialize, Serialize};
use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::lu::LUDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix;
use crate::linalg::svd::SVDDecomposableMatrix;
use crate::linalg::Matrix; use crate::linalg::Matrix;
pub use crate::linalg::{BaseMatrix, BaseVector}; pub use crate::linalg::{BaseMatrix, BaseVector};
use crate::linalg::svd::SVDDecomposableMatrix;
use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix;
use crate::linalg::lu::LUDecomposableMatrix;
use crate::math::num::FloatExt; use crate::math::num::FloatExt;
impl<T: FloatExt> BaseVector<T> for Vec<T> { impl<T: FloatExt> BaseVector<T> for Vec<T> {
fn get(&self, i: usize) -> T { fn get(&self, i: usize) -> T {
self[i] self[i]
} }
fn set(&mut self, i: usize, x: T){ fn set(&mut self, i: usize, x: T) {
self[i] = x self[i] = x
} }
@@ -31,32 +31,34 @@ impl<T: FloatExt> BaseVector<T> for Vec<T> {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct DenseMatrix<T: FloatExt> { pub struct DenseMatrix<T: FloatExt> {
ncols: usize, ncols: usize,
nrows: usize, nrows: usize,
values: Vec<T> values: Vec<T>,
} }
impl<T: FloatExt> fmt::Display for DenseMatrix<T> { impl<T: FloatExt> fmt::Display for DenseMatrix<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut rows: Vec<Vec<f64>> = Vec::new(); let mut rows: Vec<Vec<f64>> = Vec::new();
for r in 0..self.nrows { for r in 0..self.nrows {
rows.push(self.get_row_as_vec(r).iter().map(|x| (x.to_f64().unwrap() * 1e4).round() / 1e4 ).collect()); rows.push(
} self.get_row_as_vec(r)
.iter()
.map(|x| (x.to_f64().unwrap() * 1e4).round() / 1e4)
.collect(),
);
}
write!(f, "{:?}", rows) write!(f, "{:?}", rows)
} }
} }
impl<T: FloatExt> DenseMatrix<T> { impl<T: FloatExt> DenseMatrix<T> {
fn new(nrows: usize, ncols: usize, values: Vec<T>) -> Self { fn new(nrows: usize, ncols: usize, values: Vec<T>) -> Self {
DenseMatrix { DenseMatrix {
ncols: ncols, ncols: ncols,
nrows: nrows, nrows: nrows,
values: values values: values,
} }
} }
pub fn from_array(values: &[&[T]]) -> Self { pub fn from_array(values: &[&[T]]) -> Self {
DenseMatrix::from_vec(&values.into_iter().map(|row| Vec::from(*row)).collect()) DenseMatrix::from_vec(&values.into_iter().map(|row| Vec::from(*row)).collect())
@@ -64,11 +66,14 @@ impl<T: FloatExt> DenseMatrix<T> {
pub fn from_vec(values: &Vec<Vec<T>>) -> DenseMatrix<T> { pub fn from_vec(values: &Vec<Vec<T>>) -> DenseMatrix<T> {
let nrows = values.len(); let nrows = values.len();
let ncols = values.first().unwrap_or_else(|| panic!("Cannot create 2d matrix from an empty vector")).len(); let ncols = values
.first()
.unwrap_or_else(|| panic!("Cannot create 2d matrix from an empty vector"))
.len();
let mut m = DenseMatrix { let mut m = DenseMatrix {
ncols: ncols, ncols: ncols,
nrows: nrows, nrows: nrows,
values: vec![T::zero(); ncols*nrows] values: vec![T::zero(); ncols * nrows],
}; };
for row in 0..nrows { for row in 0..nrows {
for col in 0..ncols { for col in 0..ncols {
@@ -76,17 +81,17 @@ impl<T: FloatExt> DenseMatrix<T> {
} }
} }
m m
} }
pub fn vector_from_array(values: &[T]) -> Self { pub fn vector_from_array(values: &[T]) -> Self {
DenseMatrix::vector_from_vec(Vec::from(values)) DenseMatrix::vector_from_vec(Vec::from(values))
} }
pub fn vector_from_vec(values: Vec<T>) -> Self { pub fn vector_from_vec(values: Vec<T>) -> Self {
DenseMatrix { DenseMatrix {
ncols: values.len(), ncols: values.len(),
nrows: 1, nrows: 1,
values: values values: values,
} }
} }
@@ -98,12 +103,11 @@ impl<T: FloatExt> DenseMatrix<T> {
for i in 0..self.values.len() { for i in 0..self.values.len() {
self.values[i] = self.values[i] / b.values[i]; self.values[i] = self.values[i] / b.values[i];
} }
} }
pub fn get_raw_values(&self) -> &Vec<T> { pub fn get_raw_values(&self) -> &Vec<T> {
&self.values &self.values
} }
} }
impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> { impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
@@ -111,31 +115,37 @@ impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for Dens
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(field_identifier, rename_all = "lowercase")] #[serde(field_identifier, rename_all = "lowercase")]
enum Field { NRows, NCols, Values } enum Field {
NRows,
struct DenseMatrixVisitor<T: FloatExt + fmt::Debug>{ NCols,
t: PhantomData<T> Values,
} }
struct DenseMatrixVisitor<T: FloatExt + fmt::Debug> {
t: PhantomData<T>,
}
impl<'a, T: FloatExt + fmt::Debug + Deserialize<'a>> Visitor<'a> for DenseMatrixVisitor<T> { impl<'a, T: FloatExt + fmt::Debug + Deserialize<'a>> Visitor<'a> for DenseMatrixVisitor<T> {
type Value = DenseMatrix<T>; type Value = DenseMatrix<T>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct DenseMatrix") formatter.write_str("struct DenseMatrix")
} }
fn visit_seq<V>(self, mut seq: V) -> Result<DenseMatrix<T>, V::Error> fn visit_seq<V>(self, mut seq: V) -> Result<DenseMatrix<T>, V::Error>
where where
V: SeqAccess<'a>, V: SeqAccess<'a>,
{ {
let nrows = seq.next_element()? let nrows = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?; .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
let ncols = seq.next_element()? let ncols = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?; .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
let values = seq.next_element()? let values = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(2, &self))?; .ok_or_else(|| serde::de::Error::invalid_length(2, &self))?;
Ok(DenseMatrix::new(nrows, ncols, values)) Ok(DenseMatrix::new(nrows, ncols, values))
} }
@@ -176,23 +186,26 @@ impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for Dens
} }
} }
const FIELDS: &'static [&'static str] = &["nrows", "ncols", "values"]; const FIELDS: &'static [&'static str] = &["nrows", "ncols", "values"];
deserializer.deserialize_struct("DenseMatrix", FIELDS, DenseMatrixVisitor { deserializer.deserialize_struct(
t: PhantomData "DenseMatrix",
}) FIELDS,
DenseMatrixVisitor { t: PhantomData },
)
} }
} }
impl<T: FloatExt + fmt::Debug + Serialize> Serialize for DenseMatrix<T> { impl<T: FloatExt + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where where
S: Serializer { S: Serializer,
{
let (nrows, ncols) = self.shape(); let (nrows, ncols) = self.shape();
let mut state = serializer.serialize_struct("DenseMatrix", 3)?; let mut state = serializer.serialize_struct("DenseMatrix", 3)?;
state.serialize_field("nrows", &nrows)?; state.serialize_field("nrows", &nrows)?;
state.serialize_field("ncols", &ncols)?; state.serialize_field("ncols", &ncols)?;
state.serialize_field("values", &self.values)?; state.serialize_field("values", &self.values)?;
state.end() state.end()
} }
} }
@@ -209,7 +222,7 @@ impl<T: FloatExt> Matrix<T> for DenseMatrix<T> {}
impl<T: FloatExt> PartialEq for DenseMatrix<T> { impl<T: FloatExt> PartialEq for DenseMatrix<T> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.ncols != other.ncols || self.nrows != other.nrows { if self.ncols != other.ncols || self.nrows != other.nrows {
return false return false;
} }
let len = self.values.len(); let len = self.values.len();
@@ -235,26 +248,28 @@ impl<T: FloatExt> Into<Vec<T>> for DenseMatrix<T> {
} }
} }
impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> { impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
type RowVector = Vec<T>; type RowVector = Vec<T>;
fn from_row_vector(vec: Self::RowVector) -> Self{ fn from_row_vector(vec: Self::RowVector) -> Self {
DenseMatrix::new(1, vec.len(), vec) DenseMatrix::new(1, vec.len(), vec)
} }
fn to_row_vector(self) -> Self::RowVector{ fn to_row_vector(self) -> Self::RowVector {
self.to_raw_vector() self.to_raw_vector()
} }
fn get(&self, row: usize, col: usize) -> T { fn get(&self, row: usize, col: usize) -> T {
if row >= self.nrows || col >= self.ncols { if row >= self.nrows || col >= self.ncols {
panic!("Invalid index ({},{}) for {}x{} matrix", row, col, self.nrows, self.ncols); panic!(
"Invalid index ({},{}) for {}x{} matrix",
row, col, self.nrows, self.ncols
);
} }
self.values[col*self.nrows + row] self.values[col * self.nrows + row]
} }
fn get_row_as_vec(&self, row: usize) -> Vec<T>{ fn get_row_as_vec(&self, row: usize) -> Vec<T> {
let mut result = vec![T::zero(); self.ncols]; let mut result = vec![T::zero(); self.ncols];
for c in 0..self.ncols { for c in 0..self.ncols {
result[c] = self.get(row, c); result[c] = self.get(row, c);
@@ -262,16 +277,16 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
result result
} }
fn get_col_as_vec(&self, col: usize) -> Vec<T>{ fn get_col_as_vec(&self, col: usize) -> Vec<T> {
let mut result = vec![T::zero(); self.nrows]; let mut result = vec![T::zero(); self.nrows];
for r in 0..self.nrows { for r in 0..self.nrows {
result[r] = self.get(r, col); result[r] = self.get(r, col);
} }
result result
} }
fn set(&mut self, row: usize, col: usize, x: T) { fn set(&mut self, row: usize, col: usize, x: T) {
self.values[col*self.nrows + row] = x; self.values[col * self.nrows + row] = x;
} }
fn zeros(nrows: usize, ncols: usize) -> Self { fn zeros(nrows: usize, ncols: usize) -> Self {
@@ -280,7 +295,7 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
fn ones(nrows: usize, ncols: usize) -> Self { fn ones(nrows: usize, ncols: usize) -> Self {
DenseMatrix::fill(nrows, ncols, T::one()) DenseMatrix::fill(nrows, ncols, T::one())
} }
fn eye(size: usize) -> Self { fn eye(size: usize) -> Self {
let mut matrix = Self::zeros(size, size); let mut matrix = Self::zeros(size, size);
@@ -292,15 +307,15 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
return matrix; return matrix;
} }
fn to_raw_vector(&self) -> Vec<T>{ fn to_raw_vector(&self) -> Vec<T> {
let mut v = vec![T::zero(); self.nrows * self.ncols]; let mut v = vec![T::zero(); self.nrows * self.ncols];
for r in 0..self.nrows{ for r in 0..self.nrows {
for c in 0..self.ncols { for c in 0..self.ncols {
v[r * self.ncols + c] = self.get(r, c); v[r * self.ncols + c] = self.get(r, c);
} }
} }
v v
} }
@@ -314,25 +329,25 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
} }
let mut result = Self::zeros(self.nrows + other.nrows, self.ncols); let mut result = Self::zeros(self.nrows + other.nrows, self.ncols);
for c in 0..self.ncols { for c in 0..self.ncols {
for r in 0..self.nrows+other.nrows { for r in 0..self.nrows + other.nrows {
if r < self.nrows { if r < self.nrows {
result.set(r, c, self.get(r, c)); result.set(r, c, self.get(r, c));
} else { } else {
result.set(r, c, other.get(r - self.nrows, c)); result.set(r, c, other.get(r - self.nrows, c));
} }
} }
} }
result result
} }
fn v_stack(&self, other: &Self) -> Self{ fn v_stack(&self, other: &Self) -> Self {
if self.nrows != other.nrows { if self.nrows != other.nrows {
panic!("Number of rows in both matrices should be equal"); panic!("Number of rows in both matrices should be equal");
} }
let mut result = Self::zeros(self.nrows, self.ncols + other.ncols); let mut result = Self::zeros(self.nrows, self.ncols + other.ncols);
for r in 0..self.nrows { for r in 0..self.nrows {
for c in 0..self.ncols+other.ncols { for c in 0..self.ncols + other.ncols {
if c < self.ncols { if c < self.ncols {
result.set(r, c, self.get(r, c)); result.set(r, c, self.get(r, c));
} else { } else {
result.set(r, c, other.get(r, c - self.ncols)); result.set(r, c, other.get(r, c - self.ncols));
@@ -343,7 +358,6 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
} }
fn dot(&self, other: &Self) -> Self { fn dot(&self, other: &Self) -> Self {
if self.ncols != other.nrows { if self.ncols != other.nrows {
panic!("Number of rows of A should equal number of columns of B"); panic!("Number of rows of A should equal number of columns of B");
} }
@@ -361,7 +375,7 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
} }
result result
} }
fn vector_dot(&self, other: &Self) -> T { fn vector_dot(&self, other: &Self) -> T {
if (self.nrows != 1 || self.nrows != 1) && (other.nrows != 1 || other.ncols != 1) { if (self.nrows != 1 || self.nrows != 1) && (other.nrows != 1 || other.ncols != 1) {
@@ -369,18 +383,17 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
} }
if self.nrows * self.ncols != other.nrows * other.ncols { if self.nrows * self.ncols != other.nrows * other.ncols {
panic!("A and B should have the same size"); panic!("A and B should have the same size");
} }
let mut result = T::zero(); let mut result = T::zero();
for i in 0..(self.nrows * self.ncols) { for i in 0..(self.nrows * self.ncols) {
result = result + self.values[i] * other.values[i]; result = result + self.values[i] * other.values[i];
} }
result result
} }
fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self { fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self {
let ncols = cols.len(); let ncols = cols.len();
let nrows = rows.len(); let nrows = rows.len();
@@ -388,22 +401,22 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
for r in rows.start..rows.end { for r in rows.start..rows.end {
for c in cols.start..cols.end { for c in cols.start..cols.end {
m.set(r-rows.start, c-cols.start, self.get(r, c)); m.set(r - rows.start, c - cols.start, self.get(r, c));
} }
} }
m m
} }
fn approximate_eq(&self, other: &Self, error: T) -> bool { fn approximate_eq(&self, other: &Self, error: T) -> bool {
if self.ncols != other.ncols || self.nrows != other.nrows { if self.ncols != other.ncols || self.nrows != other.nrows {
return false return false;
} }
for c in 0..self.ncols { for c in 0..self.ncols {
for r in 0..self.nrows { for r in 0..self.nrows {
if (self.get(r, c) - other.get(r, c)).abs() > error { if (self.get(r, c) - other.get(r, c)).abs() > error {
return false return false;
} }
} }
} }
@@ -418,7 +431,7 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
fn add_mut(&mut self, other: &Self) -> &Self { fn add_mut(&mut self, other: &Self) -> &Self {
if self.ncols != other.ncols || self.nrows != other.nrows { if self.ncols != other.ncols || self.nrows != other.nrows {
panic!("A and B should have the same shape"); panic!("A and B should have the same shape");
} }
for c in 0..self.ncols { for c in 0..self.ncols {
for r in 0..self.nrows { for r in 0..self.nrows {
self.add_element_mut(r, c, other.get(r, c)); self.add_element_mut(r, c, other.get(r, c));
@@ -431,7 +444,7 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
fn sub_mut(&mut self, other: &Self) -> &Self { fn sub_mut(&mut self, other: &Self) -> &Self {
if self.ncols != other.ncols || self.nrows != other.nrows { if self.ncols != other.ncols || self.nrows != other.nrows {
panic!("A and B should have the same shape"); panic!("A and B should have the same shape");
} }
for c in 0..self.ncols { for c in 0..self.ncols {
for r in 0..self.nrows { for r in 0..self.nrows {
self.sub_element_mut(r, c, other.get(r, c)); self.sub_element_mut(r, c, other.get(r, c));
@@ -444,7 +457,7 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
fn mul_mut(&mut self, other: &Self) -> &Self { fn mul_mut(&mut self, other: &Self) -> &Self {
if self.ncols != other.ncols || self.nrows != other.nrows { if self.ncols != other.ncols || self.nrows != other.nrows {
panic!("A and B should have the same shape"); panic!("A and B should have the same shape");
} }
for c in 0..self.ncols { for c in 0..self.ncols {
for r in 0..self.nrows { for r in 0..self.nrows {
self.mul_element_mut(r, c, other.get(r, c)); self.mul_element_mut(r, c, other.get(r, c));
@@ -457,7 +470,7 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
fn div_mut(&mut self, other: &Self) -> &Self { fn div_mut(&mut self, other: &Self) -> &Self {
if self.ncols != other.ncols || self.nrows != other.nrows { if self.ncols != other.ncols || self.nrows != other.nrows {
panic!("A and B should have the same shape"); panic!("A and B should have the same shape");
} }
for c in 0..self.ncols { for c in 0..self.ncols {
for r in 0..self.nrows { for r in 0..self.nrows {
self.div_element_mut(r, c, other.get(r, c)); self.div_element_mut(r, c, other.get(r, c));
@@ -468,26 +481,26 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
} }
fn div_element_mut(&mut self, row: usize, col: usize, x: T) { fn div_element_mut(&mut self, row: usize, col: usize, x: T) {
self.values[col*self.nrows + row] = self.values[col*self.nrows + row] / x; self.values[col * self.nrows + row] = self.values[col * self.nrows + row] / x;
} }
fn mul_element_mut(&mut self, row: usize, col: usize, x: T) { fn mul_element_mut(&mut self, row: usize, col: usize, x: T) {
self.values[col*self.nrows + row] = self.values[col*self.nrows + row] * x; self.values[col * self.nrows + row] = self.values[col * self.nrows + row] * x;
} }
fn add_element_mut(&mut self, row: usize, col: usize, x: T) { fn add_element_mut(&mut self, row: usize, col: usize, x: T) {
self.values[col*self.nrows + row] = self.values[col*self.nrows + row] + x self.values[col * self.nrows + row] = self.values[col * self.nrows + row] + x
} }
fn sub_element_mut(&mut self, row: usize, col: usize, x: T) { fn sub_element_mut(&mut self, row: usize, col: usize, x: T) {
self.values[col*self.nrows + row] = self.values[col*self.nrows + row] - x; self.values[col * self.nrows + row] = self.values[col * self.nrows + row] - x;
} }
fn transpose(&self) -> Self { fn transpose(&self) -> Self {
let mut m = DenseMatrix { let mut m = DenseMatrix {
ncols: self.nrows, ncols: self.nrows,
nrows: self.ncols, nrows: self.ncols,
values: vec![T::zero(); self.ncols * self.nrows] values: vec![T::zero(); self.ncols * self.nrows],
}; };
for c in 0..self.ncols { for c in 0..self.ncols {
for r in 0..self.nrows { for r in 0..self.nrows {
@@ -495,19 +508,16 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
} }
} }
m m
} }
fn rand(nrows: usize, ncols: usize) -> Self { fn rand(nrows: usize, ncols: usize) -> Self {
let values: Vec<T> = (0..nrows*ncols).map(|_| { let values: Vec<T> = (0..nrows * ncols).map(|_| T::rand()).collect();
T::rand()
}).collect();
DenseMatrix { DenseMatrix {
ncols: ncols, ncols: ncols,
nrows: nrows, nrows: nrows,
values: values values: values,
} }
} }
fn norm2(&self) -> T { fn norm2(&self) -> T {
let mut norm = T::zero(); let mut norm = T::zero();
@@ -519,21 +529,25 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
norm.sqrt() norm.sqrt()
} }
fn norm(&self, p:T) -> T { fn norm(&self, p: T) -> T {
if p.is_infinite() && p.is_sign_positive() { if p.is_infinite() && p.is_sign_positive() {
self.values.iter().map(|x| x.abs()).fold(T::neg_infinity(), |a, b| a.max(b)) self.values
.iter()
.map(|x| x.abs())
.fold(T::neg_infinity(), |a, b| a.max(b))
} else if p.is_infinite() && p.is_sign_negative() { } else if p.is_infinite() && p.is_sign_negative() {
self.values.iter().map(|x| x.abs()).fold(T::infinity(), |a, b| a.min(b)) self.values
.iter()
.map(|x| x.abs())
.fold(T::infinity(), |a, b| a.min(b))
} else { } else {
let mut norm = T::zero(); let mut norm = T::zero();
for xi in self.values.iter() { for xi in self.values.iter() {
norm = norm + xi.abs().powf(p); norm = norm + xi.abs().powf(p);
} }
norm.powf(T::one()/p) norm.powf(T::one() / p)
} }
} }
@@ -585,11 +599,14 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
for i in 0..self.values.len() { for i in 0..self.values.len() {
self.values[i] = -self.values[i]; self.values[i] = -self.values[i];
} }
} }
fn reshape(&self, nrows: usize, ncols: usize) -> Self { fn reshape(&self, nrows: usize, ncols: usize) -> Self {
if self.nrows * self.ncols != nrows * ncols { if self.nrows * self.ncols != nrows * ncols {
panic!("Can't reshape {}x{} matrix into {}x{}.", self.nrows, self.ncols, nrows, ncols); panic!(
"Can't reshape {}x{} matrix into {}x{}.",
self.nrows, self.ncols, nrows, ncols
);
} }
let mut dst = DenseMatrix::zeros(nrows, ncols); let mut dst = DenseMatrix::zeros(nrows, ncols);
let mut dst_r = 0; let mut dst_r = 0;
@@ -609,9 +626,11 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
} }
fn copy_from(&mut self, other: &Self) { fn copy_from(&mut self, other: &Self) {
if self.nrows != other.nrows || self.ncols != other.ncols { if self.nrows != other.nrows || self.ncols != other.ncols {
panic!("Can't copy {}x{} matrix into {}x{}.", self.nrows, self.ncols, other.nrows, other.ncols); panic!(
"Can't copy {}x{} matrix into {}x{}.",
self.nrows, self.ncols, other.nrows, other.ncols
);
} }
for i in 0..self.values.len() { for i in 0..self.values.len() {
@@ -619,20 +638,19 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
} }
} }
fn abs_mut(&mut self) -> &Self{ fn abs_mut(&mut self) -> &Self {
for i in 0..self.values.len() { for i in 0..self.values.len() {
self.values[i] = self.values[i].abs(); self.values[i] = self.values[i].abs();
} }
self self
} }
fn max_diff(&self, other: &Self) -> T{ fn max_diff(&self, other: &Self) -> T {
let mut max_diff = T::zero(); let mut max_diff = T::zero();
for i in 0..self.values.len() { for i in 0..self.values.len() {
max_diff = max_diff.max((self.values[i] - other.values[i]).abs()); max_diff = max_diff.max((self.values[i] - other.values[i]).abs());
} }
max_diff max_diff
} }
fn sum(&self) -> T { fn sum(&self) -> T {
@@ -644,7 +662,11 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
} }
fn softmax_mut(&mut self) { fn softmax_mut(&mut self) {
let max = self.values.iter().map(|x| x.abs()).fold(T::neg_infinity(), |a, b| a.max(b)); let max = self
.values
.iter()
.map(|x| x.abs())
.fold(T::neg_infinity(), |a, b| a.max(b));
let mut z = T::zero(); let mut z = T::zero();
for r in 0..self.nrows { for r in 0..self.nrows {
for c in 0..self.ncols { for c in 0..self.ncols {
@@ -668,7 +690,6 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
} }
fn argmax(&self) -> Vec<usize> { fn argmax(&self) -> Vec<usize> {
let mut res = vec![0usize; self.nrows]; let mut res = vec![0usize; self.nrows];
for r in 0..self.nrows { for r in 0..self.nrows {
@@ -676,16 +697,15 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
let mut max_pos = 0usize; let mut max_pos = 0usize;
for c in 0..self.ncols { for c in 0..self.ncols {
let v = self.get(r, c); let v = self.get(r, c);
if max < v{ if max < v {
max = v; max = v;
max_pos = c; max_pos = c;
} }
} }
res[r] = max_pos; res[r] = max_pos;
} }
res res
} }
fn unique(&self) -> Vec<T> { fn unique(&self) -> Vec<T> {
@@ -696,17 +716,16 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
} }
fn cov(&self) -> Self { fn cov(&self) -> Self {
let (m, n) = self.shape(); let (m, n) = self.shape();
let mu = self.column_mean(); let mu = self.column_mean();
let mut cov = Self::zeros(n, n); let mut cov = Self::zeros(n, n);
for k in 0..m { for k in 0..m {
for i in 0..n { for i in 0..n {
for j in 0..=i { for j in 0..=i {
cov.add_element_mut(i, j, (self.get(k, i) - mu[i]) * (self.get(k, j) - mu[j])); cov.add_element_mut(i, j, (self.get(k, i) - mu[i]) * (self.get(k, j) - mu[j]));
} }
} }
} }
@@ -715,127 +734,88 @@ impl<T: FloatExt> BaseMatrix<T> for DenseMatrix<T> {
for i in 0..n { for i in 0..n {
for j in 0..=i { for j in 0..=i {
cov.div_element_mut(i, j, m_t); cov.div_element_mut(i, j, m_t);
cov.set(j, i, cov.get(i, j)); cov.set(j, i, cov.get(i, j));
} }
} }
cov cov
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn from_to_row_vec() { fn from_to_row_vec() {
let vec = vec![1., 2., 3.];
let vec = vec![ 1., 2., 3.]; assert_eq!(
assert_eq!(DenseMatrix::from_row_vector(vec.clone()), DenseMatrix::new(1, 3, vec![1., 2., 3.])); DenseMatrix::from_row_vector(vec.clone()),
assert_eq!(DenseMatrix::from_row_vector(vec.clone()).to_row_vector(), vec![1., 2., 3.]); DenseMatrix::new(1, 3, vec![1., 2., 3.])
);
} assert_eq!(
DenseMatrix::from_row_vector(vec.clone()).to_row_vector(),
#[test] vec![1., 2., 3.]
fn h_stack() { );
let a = DenseMatrix::from_array(
&[
&[1., 2., 3.],
&[4., 5., 6.],
&[7., 8., 9.]]);
let b = DenseMatrix::from_array(
&[
&[1., 2., 3.],
&[4., 5., 6.]]);
let expected = DenseMatrix::from_array(
&[
&[1., 2., 3.],
&[4., 5., 6.],
&[7., 8., 9.],
&[1., 2., 3.],
&[4., 5., 6.]]);
let result = a.h_stack(&b);
assert_eq!(result, expected);
} }
#[test] #[test]
fn v_stack() { fn h_stack() {
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
let a = DenseMatrix::from_array( let b = DenseMatrix::from_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
&[ let expected = DenseMatrix::from_array(&[
&[1., 2., 3.], &[1., 2., 3.],
&[4., 5., 6.], &[4., 5., 6.],
&[7., 8., 9.]]); &[7., 8., 9.],
let b = DenseMatrix::from_array( &[1., 2., 3.],
&[ &[4., 5., 6.],
&[1., 2.], ]);
&[3., 4.], let result = a.h_stack(&b);
&[5., 6.]]); assert_eq!(result, expected);
let expected = DenseMatrix::from_array(
&[
&[1., 2., 3., 1., 2.],
&[4., 5., 6., 3., 4.],
&[7., 8., 9., 5., 6.]]);
let result = a.v_stack(&b);
assert_eq!(result, expected);
} }
#[test] #[test]
fn dot() { fn v_stack() {
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
let a = DenseMatrix::from_array( let b = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
&[ let expected = DenseMatrix::from_array(&[
&[1., 2., 3.], &[1., 2., 3., 1., 2.],
&[4., 5., 6.]]); &[4., 5., 6., 3., 4.],
let b = DenseMatrix::from_array( &[7., 8., 9., 5., 6.],
&[ ]);
&[1., 2.], let result = a.v_stack(&b);
&[3., 4.], assert_eq!(result, expected);
&[5., 6.]]);
let expected = DenseMatrix::from_array(
&[
&[22., 28.],
&[49., 64.]]);
let result = a.dot(&b);
assert_eq!(result, expected);
} }
#[test] #[test]
fn slice() { fn dot() {
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
let m = DenseMatrix::from_array( let b = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
&[ let expected = DenseMatrix::from_array(&[&[22., 28.], &[49., 64.]]);
&[1., 2., 3., 1., 2.], let result = a.dot(&b);
&[4., 5., 6., 3., 4.], assert_eq!(result, expected);
&[7., 8., 9., 5., 6.]]);
let expected = DenseMatrix::from_array(
&[
&[2., 3.],
&[5., 6.]]);
let result = m.slice(0..2, 1..3);
assert_eq!(result, expected);
} }
#[test] #[test]
fn approximate_eq() { fn slice() {
let m = DenseMatrix::from_array( let m = DenseMatrix::from_array(&[
&[ &[1., 2., 3., 1., 2.],
&[2., 3.], &[4., 5., 6., 3., 4.],
&[5., 6.]]); &[7., 8., 9., 5., 6.],
let m_eq = DenseMatrix::from_array( ]);
&[ let expected = DenseMatrix::from_array(&[&[2., 3.], &[5., 6.]]);
&[2.5, 3.0], let result = m.slice(0..2, 1..3);
&[5., 5.5]]); assert_eq!(result, expected);
let m_neq = DenseMatrix::from_array( }
&[
&[3.0, 3.0], #[test]
&[5., 6.5]]); fn approximate_eq() {
assert!(m.approximate_eq(&m_eq, 0.5)); let m = DenseMatrix::from_array(&[&[2., 3.], &[5., 6.]]);
assert!(!m.approximate_eq(&m_neq, 0.5)); let m_eq = DenseMatrix::from_array(&[&[2.5, 3.0], &[5., 5.5]]);
let m_neq = DenseMatrix::from_array(&[&[3.0, 3.0], &[5., 6.5]]);
assert!(m.approximate_eq(&m_eq, 0.5));
assert!(!m.approximate_eq(&m_neq, 0.5));
} }
#[test] #[test]
@@ -864,7 +844,7 @@ mod tests {
fn reshape() { fn reshape() {
let m_orig = DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6.]); let m_orig = DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6.]);
let m_2_by_3 = m_orig.reshape(2, 3); let m_2_by_3 = m_orig.reshape(2, 3);
let m_result = m_2_by_3.reshape(1, 6); let m_result = m_2_by_3.reshape(1, 6);
assert_eq!(m_2_by_3.shape(), (2, 3)); assert_eq!(m_2_by_3.shape(), (2, 3));
assert_eq!(m_2_by_3.get(1, 1), 5.); assert_eq!(m_2_by_3.get(1, 1), 5.);
assert_eq!(m_result.get(0, 1), 2.); assert_eq!(m_result.get(0, 1), 2.);
@@ -872,70 +852,76 @@ mod tests {
} }
#[test] #[test]
fn norm() { fn norm() {
let v = DenseMatrix::vector_from_array(&[3., -2., 6.]);
let v = DenseMatrix::vector_from_array(&[3., -2., 6.]); assert_eq!(v.norm(1.), 11.);
assert_eq!(v.norm(1.), 11.); assert_eq!(v.norm(2.), 7.);
assert_eq!(v.norm(2.), 7.); assert_eq!(v.norm(std::f64::INFINITY), 6.);
assert_eq!(v.norm(std::f64::INFINITY), 6.); assert_eq!(v.norm(std::f64::NEG_INFINITY), 2.);
assert_eq!(v.norm(std::f64::NEG_INFINITY), 2.);
} }
#[test] #[test]
fn softmax_mut() { fn softmax_mut() {
let mut prob: DenseMatrix<f64> = DenseMatrix::vector_from_array(&[1., 2., 3.]);
prob.softmax_mut();
assert!((prob.get(0, 0) - 0.09).abs() < 0.01);
assert!((prob.get(0, 1) - 0.24).abs() < 0.01);
assert!((prob.get(0, 2) - 0.66).abs() < 0.01);
}
let mut prob: DenseMatrix<f64> = DenseMatrix::vector_from_array(&[1., 2., 3.]);
prob.softmax_mut();
assert!((prob.get(0, 0) - 0.09).abs() < 0.01);
assert!((prob.get(0, 1) - 0.24).abs() < 0.01);
assert!((prob.get(0, 2) - 0.66).abs() < 0.01);
}
#[test] #[test]
fn col_mean(){ fn col_mean() {
let a = DenseMatrix::from_array(&[ let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
&[1., 2., 3.],
&[4., 5., 6.],
&[7., 8., 9.]]);
let res = a.column_mean(); let res = a.column_mean();
assert_eq!(res, vec![4., 5., 6.]); assert_eq!(res, vec![4., 5., 6.]);
} }
#[test] #[test]
fn eye(){ fn eye() {
let a = DenseMatrix::from_array(&[ let a = DenseMatrix::from_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0., 0., 1.]]);
&[1., 0., 0.],
&[0., 1., 0.],
&[0., 0., 1.]]);
let res = DenseMatrix::eye(3); let res = DenseMatrix::eye(3);
assert_eq!(res, a); assert_eq!(res, a);
} }
#[test] #[test]
fn to_from_json() { fn to_from_json() {
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let deserialized_a: DenseMatrix<f64> = serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap(); let deserialized_a: DenseMatrix<f64> =
assert_eq!(a, deserialized_a); serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap();
assert_eq!(a, deserialized_a);
} }
#[test] #[test]
fn to_from_bincode() { fn to_from_bincode() {
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let deserialized_a: DenseMatrix<f64> = bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap(); let deserialized_a: DenseMatrix<f64> =
assert_eq!(a, deserialized_a); bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap();
assert_eq!(a, deserialized_a);
} }
#[test] #[test]
fn to_string() { fn to_string() {
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
assert_eq!(format!("{}", a), "[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]"); assert_eq!(
format!("{}", a),
"[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]"
);
} }
#[test] #[test]
fn cov() { fn cov() {
let a = DenseMatrix::from_array(&[&[64.0, 580.0, 29.0], &[66.0, 570.0, 33.0], &[68.0, 590.0, 37.0], &[69.0, 660.0, 46.0], &[73.0, 600.0, 55.0]]); let a = DenseMatrix::from_array(&[
let expected = DenseMatrix::from_array(&[&[11.5, 50.0, 34.75], &[50.0, 1250.0, 205.0], &[34.75, 205.0, 110.0]]); &[64.0, 580.0, 29.0],
assert_eq!(a.cov(), expected); &[66.0, 570.0, 33.0],
&[68.0, 590.0, 37.0],
&[69.0, 660.0, 46.0],
&[73.0, 600.0, 55.0],
]);
let expected = DenseMatrix::from_array(&[
&[11.5, 50.0, 34.75],
&[50.0, 1250.0, 205.0],
&[34.75, 205.0, 110.0],
]);
assert_eq!(a.cov(), expected);
} }
} }
+1 -1
View File
@@ -1 +1 @@
pub mod dense_matrix; pub mod dense_matrix;
+185 -217
View File
@@ -1,38 +1,39 @@
use std::ops::{Range, AddAssign, SubAssign, MulAssign, DivAssign};
use std::iter::Sum; use std::iter::Sum;
use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign};
use nalgebra::{MatrixMN, DMatrix, Matrix, Scalar, Dynamic, U1, VecStorage}; use nalgebra::{DMatrix, Dynamic, Matrix, MatrixMN, Scalar, VecStorage, U1};
use crate::math::num::FloatExt;
use crate::linalg::{BaseMatrix, BaseVector};
use crate::linalg::Matrix as SmartCoreMatrix;
use crate::linalg::svd::SVDDecomposableMatrix;
use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix;
use crate::linalg::lu::LUDecomposableMatrix; use crate::linalg::lu::LUDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix;
use crate::linalg::svd::SVDDecomposableMatrix;
use crate::linalg::Matrix as SmartCoreMatrix;
use crate::linalg::{BaseMatrix, BaseVector};
use crate::math::num::FloatExt;
impl<T: FloatExt + 'static> BaseVector<T> for MatrixMN<T, U1, Dynamic> { impl<T: FloatExt + 'static> BaseVector<T> for MatrixMN<T, U1, Dynamic> {
fn get(&self, i: usize) -> T { fn get(&self, i: usize) -> T {
*self.get((0, i)).unwrap() *self.get((0, i)).unwrap()
} }
fn set(&mut self, i: usize, x: T){ fn set(&mut self, i: usize, x: T) {
*self.get_mut((0, i)).unwrap() = x; *self.get_mut((0, i)).unwrap() = x;
} }
fn len(&self) -> usize{ fn len(&self) -> usize {
self.len() self.len()
} }
} }
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static> BaseMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>> impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
BaseMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
{ {
type RowVector = MatrixMN<T, U1, Dynamic>; type RowVector = MatrixMN<T, U1, Dynamic>;
fn from_row_vector(vec: Self::RowVector) -> Self{ fn from_row_vector(vec: Self::RowVector) -> Self {
Matrix::from_rows(&[vec]) Matrix::from_rows(&[vec])
} }
fn to_row_vector(self) -> Self::RowVector{ fn to_row_vector(self) -> Self::RowVector {
self.row(0).into_owned() self.row(0).into_owned()
} }
@@ -50,13 +51,13 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
fn set(&mut self, row: usize, col: usize, x: T) { fn set(&mut self, row: usize, col: usize, x: T) {
*self.get_mut((row, col)).unwrap() = x; *self.get_mut((row, col)).unwrap() = x;
} }
fn eye(size: usize) -> Self { fn eye(size: usize) -> Self {
DMatrix::identity(size, size) DMatrix::identity(size, size)
} }
fn zeros(nrows: usize, ncols: usize) -> Self { fn zeros(nrows: usize, ncols: usize) -> Self {
DMatrix::zeros(nrows, ncols) DMatrix::zeros(nrows, ncols)
} }
@@ -70,7 +71,7 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
for (i, row) in self.row_iter().enumerate() { for (i, row) in self.row_iter().enumerate() {
for (j, v) in row.iter().enumerate() { for (j, v) in row.iter().enumerate() {
result[i * ncols + j] = *v; result[i * ncols + j] = *v;
} }
} }
result result
} }
@@ -83,25 +84,25 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
fn shape(&self) -> (usize, usize) { fn shape(&self) -> (usize, usize) {
self.shape() self.shape()
} }
fn v_stack(&self, other: &Self) -> Self { fn v_stack(&self, other: &Self) -> Self {
let mut columns = Vec::new(); let mut columns = Vec::new();
for r in 0..self.ncols(){ for r in 0..self.ncols() {
columns.push(self.column(r)); columns.push(self.column(r));
} }
for r in 0..other.ncols(){ for r in 0..other.ncols() {
columns.push(other.column(r)); columns.push(other.column(r));
} }
Matrix::from_columns(&columns) Matrix::from_columns(&columns)
} }
fn h_stack(&self, other: &Self) -> Self { fn h_stack(&self, other: &Self) -> Self {
let mut rows = Vec::new(); let mut rows = Vec::new();
for r in 0..self.nrows(){ for r in 0..self.nrows() {
rows.push(self.row(r)); rows.push(self.row(r));
} }
for r in 0..other.nrows(){ for r in 0..other.nrows() {
rows.push(other.row(r)); rows.push(other.row(r));
} }
Matrix::from_rows(&rows) Matrix::from_rows(&rows)
@@ -109,11 +110,11 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
fn dot(&self, other: &Self) -> Self { fn dot(&self, other: &Self) -> Self {
self * other self * other
} }
fn vector_dot(&self, other: &Self) -> T { fn vector_dot(&self, other: &Self) -> T {
self.dot(other) self.dot(other)
} }
fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self { fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self {
self.slice_range(rows, cols).into_owned() self.slice_range(rows, cols).into_owned()
@@ -141,46 +142,44 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
self self
} }
fn div_mut(&mut self, other: &Self) -> &Self{ fn div_mut(&mut self, other: &Self) -> &Self {
self.component_div_assign(other); self.component_div_assign(other);
self self
} }
fn add_scalar_mut(&mut self, scalar: T) -> &Self{ fn add_scalar_mut(&mut self, scalar: T) -> &Self {
Matrix::add_scalar_mut(self, scalar); Matrix::add_scalar_mut(self, scalar);
self self
} }
fn sub_scalar_mut(&mut self, scalar: T) -> &Self{ fn sub_scalar_mut(&mut self, scalar: T) -> &Self {
Matrix::add_scalar_mut(self, -scalar); Matrix::add_scalar_mut(self, -scalar);
self self
} }
fn mul_scalar_mut(&mut self, scalar: T) -> &Self{ fn mul_scalar_mut(&mut self, scalar: T) -> &Self {
*self *= scalar; *self *= scalar;
self self
} }
fn div_scalar_mut(&mut self, scalar: T) -> &Self{ fn div_scalar_mut(&mut self, scalar: T) -> &Self {
*self /= scalar; *self /= scalar;
self self
} }
fn transpose(&self) -> Self{ fn transpose(&self) -> Self {
self.transpose() self.transpose()
} }
fn rand(nrows: usize, ncols: usize) -> Self{ fn rand(nrows: usize, ncols: usize) -> Self {
DMatrix::from_iterator(nrows, ncols, (0..nrows*ncols).map(|_| { DMatrix::from_iterator(nrows, ncols, (0..nrows * ncols).map(|_| T::rand()))
T::rand()
}))
} }
fn norm2(&self) -> T{ fn norm2(&self) -> T {
self.iter().map(|x| *x * *x).sum::<T>().sqrt() self.iter().map(|x| *x * *x).sum::<T>().sqrt()
} }
fn norm(&self, p:T) -> T { fn norm(&self, p: T) -> T {
if p.is_infinite() && p.is_sign_positive() { if p.is_infinite() && p.is_sign_positive() {
self.iter().fold(T::neg_infinity(), |f, &val| { self.iter().fold(T::neg_infinity(), |f, &val| {
let v = val.abs(); let v = val.abs();
@@ -189,7 +188,7 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
} else { } else {
v v
} }
}) })
} else if p.is_infinite() && p.is_sign_negative() { } else if p.is_infinite() && p.is_sign_negative() {
self.iter().fold(T::infinity(), |f, &val| { self.iter().fold(T::infinity(), |f, &val| {
let v = val.abs(); let v = val.abs();
@@ -200,19 +199,17 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
} }
}) })
} else { } else {
let mut norm = T::zero(); let mut norm = T::zero();
for xi in self.iter() { for xi in self.iter() {
norm = norm + xi.abs().powf(p); norm = norm + xi.abs().powf(p);
} }
norm.powf(T::one()/p) norm.powf(T::one() / p)
} }
} }
fn column_mean(&self) -> Vec<T> { fn column_mean(&self) -> Vec<T> {
let mut res = Vec::new(); let mut res = Vec::new();
for column in self.column_iter() { for column in self.column_iter() {
@@ -221,68 +218,71 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
for v in column.iter() { for v in column.iter() {
sum += *v; sum += *v;
count += 1; count += 1;
} }
res.push(sum / T::from(count).unwrap()); res.push(sum / T::from(count).unwrap());
} }
res res
} }
fn div_element_mut(&mut self, row: usize, col: usize, x: T){ fn div_element_mut(&mut self, row: usize, col: usize, x: T) {
*self.get_mut((row, col)).unwrap() = *self.get((row, col)).unwrap() / x; *self.get_mut((row, col)).unwrap() = *self.get((row, col)).unwrap() / x;
} }
fn mul_element_mut(&mut self, row: usize, col: usize, x: T){ fn mul_element_mut(&mut self, row: usize, col: usize, x: T) {
*self.get_mut((row, col)).unwrap() = *self.get((row, col)).unwrap() * x; *self.get_mut((row, col)).unwrap() = *self.get((row, col)).unwrap() * x;
} }
fn add_element_mut(&mut self, row: usize, col: usize, x: T){ fn add_element_mut(&mut self, row: usize, col: usize, x: T) {
*self.get_mut((row, col)).unwrap() = *self.get((row, col)).unwrap() + x; *self.get_mut((row, col)).unwrap() = *self.get((row, col)).unwrap() + x;
} }
fn sub_element_mut(&mut self, row: usize, col: usize, x: T){ fn sub_element_mut(&mut self, row: usize, col: usize, x: T) {
*self.get_mut((row, col)).unwrap() = *self.get((row, col)).unwrap() - x; *self.get_mut((row, col)).unwrap() = *self.get((row, col)).unwrap() - x;
} }
fn negative_mut(&mut self){ fn negative_mut(&mut self) {
*self *= -T::one(); *self *= -T::one();
} }
fn reshape(&self, nrows: usize, ncols: usize) -> Self{ fn reshape(&self, nrows: usize, ncols: usize) -> Self {
DMatrix::from_row_slice(nrows, ncols, &self.to_raw_vector()) DMatrix::from_row_slice(nrows, ncols, &self.to_raw_vector())
} }
fn copy_from(&mut self, other: &Self){ fn copy_from(&mut self, other: &Self) {
Matrix::copy_from(self, other); Matrix::copy_from(self, other);
} }
fn abs_mut(&mut self) -> &Self{ fn abs_mut(&mut self) -> &Self {
for v in self.iter_mut(){ for v in self.iter_mut() {
*v = v.abs() *v = v.abs()
} }
self self
} }
fn sum(&self) -> T{ fn sum(&self) -> T {
let mut sum = T::zero(); let mut sum = T::zero();
for v in self.iter(){ for v in self.iter() {
sum += *v; sum += *v;
} }
sum sum
} }
fn max_diff(&self, other: &Self) -> T{ fn max_diff(&self, other: &Self) -> T {
let mut max_diff = T::zero(); let mut max_diff = T::zero();
for r in 0..self.nrows() { for r in 0..self.nrows() {
for c in 0..self.ncols() { for c in 0..self.ncols() {
max_diff = max_diff.max((self[(r, c)] - other[(r, c)]).abs()); max_diff = max_diff.max((self[(r, c)] - other[(r, c)]).abs());
} }
} }
max_diff max_diff
} }
fn softmax_mut(&mut self){ fn softmax_mut(&mut self) {
let max = self.iter().map(|x| x.abs()).fold(T::neg_infinity(), |a, b| a.max(b)); let max = self
.iter()
.map(|x| x.abs())
.fold(T::neg_infinity(), |a, b| a.max(b));
let mut z = T::zero(); let mut z = T::zero();
for r in 0..self.nrows() { for r in 0..self.nrows() {
for c in 0..self.ncols() { for c in 0..self.ncols() {
@@ -298,14 +298,14 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
} }
} }
fn pow_mut(&mut self, p: T) -> &Self{ fn pow_mut(&mut self, p: T) -> &Self {
for v in self.iter_mut(){ for v in self.iter_mut() {
*v = v.powf(p) *v = v.powf(p)
} }
self self
} }
fn argmax(&self) -> Vec<usize>{ fn argmax(&self) -> Vec<usize> {
let mut res = vec![0usize; self.nrows()]; let mut res = vec![0usize; self.nrows()];
for r in 0..self.nrows() { for r in 0..self.nrows() {
@@ -315,16 +315,15 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
let v = self[(r, c)]; let v = self[(r, c)];
if max < v { if max < v {
max = v; max = v;
max_pos = c; max_pos = c;
} }
} }
res[r] = max_pos; res[r] = max_pos;
} }
res res
} }
fn unique(&self) -> Vec<T> { fn unique(&self) -> Vec<T> {
let mut result: Vec<T> = self.iter().map(|v| *v).collect(); let mut result: Vec<T> = self.iter().map(|v| *v).collect();
result.sort_by(|a, b| a.partial_cmp(b).unwrap()); result.sort_by(|a, b| a.partial_cmp(b).unwrap());
@@ -335,95 +334,96 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
fn cov(&self) -> Self { fn cov(&self) -> Self {
panic!("Not implemented"); panic!("Not implemented");
} }
} }
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static> SVDDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>> {} impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
SVDDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
{
}
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static> EVDDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>> {} impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
EVDDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
{
}
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static> QRDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>> {} impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
QRDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
{
}
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static> LUDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>> {} impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
LUDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
{
}
impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static> SmartCoreMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>> {} impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
SmartCoreMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
{
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use nalgebra::{Matrix2x3, DMatrix, RowDVector}; use nalgebra::{DMatrix, Matrix2x3, RowDVector};
#[test] #[test]
fn vec_len() { fn vec_len() {
let v = RowDVector::from_vec(vec!(1., 2., 3.)); let v = RowDVector::from_vec(vec![1., 2., 3.]);
assert_eq!(3, v.len()); assert_eq!(3, v.len());
} }
#[test] #[test]
fn get_set_vector() { fn get_set_vector() {
let mut v = RowDVector::from_vec(vec!(1., 2., 3., 4.)); let mut v = RowDVector::from_vec(vec![1., 2., 3., 4.]);
let expected = RowDVector::from_vec(vec!(1., 5., 3., 4.)); let expected = RowDVector::from_vec(vec![1., 5., 3., 4.]);
v.set(1, 5.); v.set(1, 5.);
assert_eq!(v, expected); assert_eq!(v, expected);
assert_eq!(5., BaseVector::get(&v, 1)); assert_eq!(5., BaseVector::get(&v, 1));
} }
#[test] #[test]
fn get_set_dynamic() { fn get_set_dynamic() {
let mut m = DMatrix::from_row_slice( let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2,
3,
&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
);
let expected = Matrix2x3::new(1., 2., 3., 4., let expected = Matrix2x3::new(1., 2., 3., 4., 10., 6.);
10., 6.);
m.set(1, 1, 10.); m.set(1, 1, 10.);
assert_eq!(m, expected); assert_eq!(m, expected);
assert_eq!(10., BaseMatrix::get(&m, 1, 1)); assert_eq!(10., BaseMatrix::get(&m, 1, 1));
} }
#[test] #[test]
fn zeros() { fn zeros() {
let expected = DMatrix::from_row_slice( let expected = DMatrix::from_row_slice(2, 2, &[0., 0., 0., 0.]);
2,
2,
&[0., 0., 0., 0.],
);
let m:DMatrix<f64> = BaseMatrix::zeros(2, 2); let m: DMatrix<f64> = BaseMatrix::zeros(2, 2);
assert_eq!(m, expected); assert_eq!(m, expected);
} }
#[test] #[test]
fn ones() { fn ones() {
let expected = DMatrix::from_row_slice( let expected = DMatrix::from_row_slice(2, 2, &[1., 1., 1., 1.]);
2,
2,
&[1., 1., 1., 1.],
);
let m:DMatrix<f64> = BaseMatrix::ones(2, 2); let m: DMatrix<f64> = BaseMatrix::ones(2, 2);
assert_eq!(m, expected); assert_eq!(m, expected);
} }
#[test] #[test]
fn eye(){ fn eye() {
let expected = DMatrix::from_row_slice(3, 3, &[1., 0., 0., 0., 1., 0., 0., 0., 1.]); let expected = DMatrix::from_row_slice(3, 3, &[1., 0., 0., 0., 1., 0., 0., 0., 1.]);
let m: DMatrix<f64> = BaseMatrix::eye(3); let m: DMatrix<f64> = BaseMatrix::eye(3);
assert_eq!(m, expected); assert_eq!(m, expected);
} }
#[test] #[test]
fn shape() { fn shape() {
let m:DMatrix<f64> = BaseMatrix::zeros(5, 10); let m: DMatrix<f64> = BaseMatrix::zeros(5, 10);
let (nrows, ncols) = m.shape(); let (nrows, ncols) = m.shape();
assert_eq!(nrows, 5); assert_eq!(nrows, 5);
@@ -431,18 +431,10 @@ mod tests {
} }
#[test] #[test]
fn scalar_add_sub_mul_div(){ fn scalar_add_sub_mul_div() {
let mut m = DMatrix::from_row_slice( let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2,
3,
&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
);
let expected = DMatrix::from_row_slice( let expected = DMatrix::from_row_slice(2, 3, &[0.6, 0.8, 1., 1.2, 1.4, 1.6]);
2,
3,
&[0.6, 0.8, 1., 1.2, 1.4, 1.6],
);
m.add_scalar_mut(3.0); m.add_scalar_mut(3.0);
m.sub_scalar_mut(1.0); m.sub_scalar_mut(1.0);
@@ -452,79 +444,51 @@ mod tests {
} }
#[test] #[test]
fn add_sub_mul_div(){ fn add_sub_mul_div() {
let mut m = DMatrix::from_row_slice( let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
2,
2, let a = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
&[1.0, 2.0, 3.0, 4.0],
);
let a = DMatrix::from_row_slice(
2,
2,
&[1.0, 2.0, 3.0, 4.0],
);
let b: DMatrix<f64> = BaseMatrix::fill(2, 2, 10.); let b: DMatrix<f64> = BaseMatrix::fill(2, 2, 10.);
let expected = DMatrix::from_row_slice( let expected = DMatrix::from_row_slice(2, 2, &[0.1, 0.6, 1.5, 2.8]);
2,
2,
&[0.1, 0.6, 1.5, 2.8],
);
m.add_mut(&a); m.add_mut(&a);
m.mul_mut(&a); m.mul_mut(&a);
m.sub_mut(&a); m.sub_mut(&a);
m.div_mut(&b); m.div_mut(&b);
assert_eq!(m, expected); assert_eq!(m, expected);
} }
#[test] #[test]
fn to_from_row_vector(){ fn to_from_row_vector() {
let v = RowDVector::from_vec(vec!(1., 2., 3., 4.)); let v = RowDVector::from_vec(vec![1., 2., 3., 4.]);
let expected = v.clone(); let expected = v.clone();
let m: DMatrix<f64> = BaseMatrix::from_row_vector(v); let m: DMatrix<f64> = BaseMatrix::from_row_vector(v);
assert_eq!(m.to_row_vector(), expected); assert_eq!(m.to_row_vector(), expected);
} }
#[test] #[test]
fn get_row_col_as_vec(){ fn get_row_col_as_vec() {
let m = DMatrix::from_row_slice( let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
3,
3,
&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
);
assert_eq!(m.get_row_as_vec(1), vec!(4., 5., 6.)); assert_eq!(m.get_row_as_vec(1), vec!(4., 5., 6.));
assert_eq!(m.get_col_as_vec(1), vec!(2., 5., 8.)); assert_eq!(m.get_col_as_vec(1), vec!(2., 5., 8.));
} }
#[test] #[test]
fn to_raw_vector(){ fn to_raw_vector() {
let m = DMatrix::from_row_slice( let m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2,
3,
&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
);
assert_eq!(m.to_raw_vector(), vec!(1., 2., 3., 4., 5., 6.)); assert_eq!(m.to_raw_vector(), vec!(1., 2., 3., 4., 5., 6.));
} }
#[test] #[test]
fn element_add_sub_mul_div(){ fn element_add_sub_mul_div() {
let mut m = DMatrix::from_row_slice( let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
2,
2,
&[1.0, 2.0, 3.0, 4.0],
);
let expected = DMatrix::from_row_slice( let expected = DMatrix::from_row_slice(2, 2, &[4., 1., 6., 0.4]);
2,
2,
&[4., 1., 6., 0.4],
);
m.add_element_mut(0, 0, 3.0); m.add_element_mut(0, 0, 3.0);
m.sub_element_mut(0, 1, 1.0); m.sub_element_mut(0, 1, 1.0);
@@ -534,60 +498,65 @@ mod tests {
} }
#[test] #[test]
fn vstack_hstack() { fn vstack_hstack() {
let m1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
let m2 = DMatrix::from_row_slice(2, 1, &[7., 8.]);
let m1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]); let m3 = DMatrix::from_row_slice(1, 4, &[9., 10., 11., 12.]);
let m2 = DMatrix::from_row_slice(2, 1, &[ 7., 8.]);
let m3 = DMatrix::from_row_slice(1, 4, &[9., 10., 11., 12.]); let expected =
DMatrix::from_row_slice(3, 4, &[1., 2., 3., 7., 4., 5., 6., 8., 9., 10., 11., 12.]);
let expected = DMatrix::from_row_slice(3, 4, &[1., 2., 3., 7., 4., 5., 6., 8., 9., 10., 11., 12.]); let result = m1.v_stack(&m2).h_stack(&m3);
let result = m1.v_stack(&m2).h_stack(&m3);
assert_eq!(result, expected);
assert_eq!(result, expected);
} }
#[test] #[test]
fn dot() { fn dot() {
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]); let b = DMatrix::from_row_slice(3, 2, &[1., 2., 3., 4., 5., 6.]);
let b = DMatrix::from_row_slice(3, 2, &[1., 2., 3., 4., 5., 6.]); let expected = DMatrix::from_row_slice(2, 2, &[22., 28., 49., 64.]);
let expected = DMatrix::from_row_slice(2, 2, &[22., 28., 49., 64.]); let result = BaseMatrix::dot(&a, &b);
let result = BaseMatrix::dot(&a, &b); assert_eq!(result, expected);
assert_eq!(result, expected);
}
#[test]
fn vector_dot() {
let a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
let b = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
assert_eq!(14., a.vector_dot(&b));
} }
#[test] #[test]
fn slice() { fn vector_dot() {
let a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
let b = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
assert_eq!(14., a.vector_dot(&b));
}
let a = DMatrix::from_row_slice(3, 5, &[1., 2., 3., 1., 2., 4., 5., 6., 3., 4., 7., 8., 9., 5., 6.]); #[test]
let expected = DMatrix::from_row_slice(2, 2, &[2., 3., 5., 6.]); fn slice() {
let result = BaseMatrix::slice(&a, 0..2, 1..3); let a = DMatrix::from_row_slice(
assert_eq!(result, expected); 3,
5,
&[1., 2., 3., 1., 2., 4., 5., 6., 3., 4., 7., 8., 9., 5., 6.],
);
let expected = DMatrix::from_row_slice(2, 2, &[2., 3., 5., 6.]);
let result = BaseMatrix::slice(&a, 0..2, 1..3);
assert_eq!(result, expected);
} }
#[test] #[test]
fn approximate_eq() { fn approximate_eq() {
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]); let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
let noise = DMatrix::from_row_slice(3, 3, &[1e-5, 2e-5, 3e-5, 4e-5, 5e-5, 6e-5, 7e-5, 8e-5, 9e-5]); let noise = DMatrix::from_row_slice(
3,
3,
&[1e-5, 2e-5, 3e-5, 4e-5, 5e-5, 6e-5, 7e-5, 8e-5, 9e-5],
);
assert!(a.approximate_eq(&(&noise + &a), 1e-4)); assert!(a.approximate_eq(&(&noise + &a), 1e-4));
assert!(!a.approximate_eq(&(&noise + &a), 1e-5)); assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
} }
#[test] #[test]
fn negative_mut() { fn negative_mut() {
let mut v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]); let mut v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]);
v.negative_mut(); v.negative_mut();
assert_eq!(v, DMatrix::from_row_slice(1, 3, &[-3., 2., -6.])); assert_eq!(v, DMatrix::from_row_slice(1, 3, &[-3., 2., -6.]));
} }
#[test] #[test]
@@ -595,7 +564,7 @@ mod tests {
let m = DMatrix::from_row_slice(2, 2, &[1.0, 3.0, 2.0, 4.0]); let m = DMatrix::from_row_slice(2, 2, &[1.0, 3.0, 2.0, 4.0]);
let expected = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]); let expected = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
let m_transposed = m.transpose(); let m_transposed = m.transpose();
assert_eq!(m_transposed, expected); assert_eq!(m_transposed, expected);
} }
#[test] #[test]
@@ -609,8 +578,8 @@ mod tests {
} }
#[test] #[test]
fn norm() { fn norm() {
let v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]); let v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]);
assert_eq!(BaseMatrix::norm(&v, 1.), 11.); assert_eq!(BaseMatrix::norm(&v, 1.), 11.);
assert_eq!(BaseMatrix::norm(&v, 2.), 7.); assert_eq!(BaseMatrix::norm(&v, 2.), 7.);
assert_eq!(BaseMatrix::norm(&v, std::f64::INFINITY), 6.); assert_eq!(BaseMatrix::norm(&v, std::f64::INFINITY), 6.);
@@ -618,17 +587,17 @@ mod tests {
} }
#[test] #[test]
fn col_mean(){ fn col_mean() {
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]); let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
let res = BaseMatrix::column_mean(&a); let res = BaseMatrix::column_mean(&a);
assert_eq!(res, vec![4., 5., 6.]); assert_eq!(res, vec![4., 5., 6.]);
} }
#[test] #[test]
fn reshape() { fn reshape() {
let m_orig = DMatrix::from_row_slice(1, 6, &[1., 2., 3., 4., 5., 6.]); let m_orig = DMatrix::from_row_slice(1, 6, &[1., 2., 3., 4., 5., 6.]);
let m_2_by_3 = m_orig.reshape(2, 3); let m_2_by_3 = m_orig.reshape(2, 3);
let m_result = m_2_by_3.reshape(1, 6); let m_result = m_2_by_3.reshape(1, 6);
assert_eq!(BaseMatrix::shape(&m_2_by_3), (2, 3)); assert_eq!(BaseMatrix::shape(&m_2_by_3), (2, 3));
assert_eq!(BaseMatrix::get(&m_2_by_3, 1, 1), 5.); assert_eq!(BaseMatrix::get(&m_2_by_3, 1, 1), 5.);
assert_eq!(BaseMatrix::get(&m_result, 0, 1), 2.); assert_eq!(BaseMatrix::get(&m_result, 0, 1), 2.);
@@ -653,47 +622,46 @@ mod tests {
#[test] #[test]
fn sum() { fn sum() {
let a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]); let a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
assert_eq!(a.sum(), 6.); assert_eq!(a.sum(), 6.);
} }
#[test] #[test]
fn max_diff() { fn max_diff() {
let a1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., -5., 6.]); let a1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., -5., 6.]);
let a2 = DMatrix::from_row_slice(2, 3, &[2., 3., 4., 1., 0., -12.]); let a2 = DMatrix::from_row_slice(2, 3, &[2., 3., 4., 1., 0., -12.]);
assert_eq!(a1.max_diff(&a2), 18.); assert_eq!(a1.max_diff(&a2), 18.);
assert_eq!(a2.max_diff(&a2), 0.); assert_eq!(a2.max_diff(&a2), 0.);
} }
#[test] #[test]
fn softmax_mut(){ fn softmax_mut() {
let mut prob: DMatrix<f64> = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]); let mut prob: DMatrix<f64> = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
prob.softmax_mut(); prob.softmax_mut();
assert!((BaseMatrix::get(&prob, 0, 0) - 0.09).abs() < 0.01); assert!((BaseMatrix::get(&prob, 0, 0) - 0.09).abs() < 0.01);
assert!((BaseMatrix::get(&prob, 0, 1) - 0.24).abs() < 0.01); assert!((BaseMatrix::get(&prob, 0, 1) - 0.24).abs() < 0.01);
assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01); assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
} }
#[test] #[test]
fn pow_mut(){ fn pow_mut() {
let mut a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]); let mut a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
a.pow_mut(3.); a.pow_mut(3.);
assert_eq!(a, DMatrix::from_row_slice(1, 3, &[1., 8., 27.])); assert_eq!(a, DMatrix::from_row_slice(1, 3, &[1., 8., 27.]));
} }
#[test] #[test]
fn argmax(){ fn argmax() {
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., -5., -6., -7., 0.1, 0.2, 0.1]); let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., -5., -6., -7., 0.1, 0.2, 0.1]);
let res = a.argmax(); let res = a.argmax();
assert_eq!(res, vec![2, 0, 1]); assert_eq!(res, vec![2, 0, 1]);
} }
#[test] #[test]
fn unique(){ fn unique() {
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 2., -2., -6., -7., 2., 3., 4.]); let a = DMatrix::from_row_slice(3, 3, &[1., 2., 2., -2., -6., -7., 2., 3., 4.]);
let res = a.unique(); let res = a.unique();
assert_eq!(res.len(), 7); assert_eq!(res.len(), 7);
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]); assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
} }
}
}
+191 -225
View File
@@ -1,44 +1,45 @@
use std::ops::Range;
use std::iter::Sum; use std::iter::Sum;
use std::ops::AddAssign; use std::ops::AddAssign;
use std::ops::SubAssign;
use std::ops::MulAssign;
use std::ops::DivAssign; use std::ops::DivAssign;
use std::ops::MulAssign;
use std::ops::Range;
use std::ops::SubAssign;
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
use ndarray::ScalarOperand; use ndarray::ScalarOperand;
use ndarray::{s, stack, Array, ArrayBase, Axis, Ix1, Ix2, OwnedRepr};
use crate::math::num::FloatExt;
use crate::linalg::{BaseMatrix, BaseVector};
use crate::linalg::Matrix;
use crate::linalg::svd::SVDDecomposableMatrix;
use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix;
use crate::linalg::lu::LUDecomposableMatrix; use crate::linalg::lu::LUDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix;
use crate::linalg::svd::SVDDecomposableMatrix;
use crate::linalg::Matrix;
use crate::linalg::{BaseMatrix, BaseVector};
use crate::math::num::FloatExt;
impl<T: FloatExt> BaseVector<T> for ArrayBase<OwnedRepr<T>, Ix1> { impl<T: FloatExt> BaseVector<T> for ArrayBase<OwnedRepr<T>, Ix1> {
fn get(&self, i: usize) -> T { fn get(&self, i: usize) -> T {
self[i] self[i]
} }
fn set(&mut self, i: usize, x: T){ fn set(&mut self, i: usize, x: T) {
self[i] = x; self[i] = x;
} }
fn len(&self) -> usize{ fn len(&self) -> usize {
self.len() self.len()
} }
} }
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> BaseMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
BaseMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
{ {
type RowVector = ArrayBase<OwnedRepr<T>, Ix1>; type RowVector = ArrayBase<OwnedRepr<T>, Ix1>;
fn from_row_vector(vec: Self::RowVector) -> Self{ fn from_row_vector(vec: Self::RowVector) -> Self {
let vec_size = vec.len(); let vec_size = vec.len();
vec.into_shape((1, vec_size)).unwrap() vec.into_shape((1, vec_size)).unwrap()
} }
fn to_row_vector(self) -> Self::RowVector{ fn to_row_vector(self) -> Self::RowVector {
let vec_size = self.nrows() * self.ncols(); let vec_size = self.nrows() * self.ncols();
self.into_shape(vec_size).unwrap() self.into_shape(vec_size).unwrap()
} }
@@ -57,7 +58,7 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
fn set(&mut self, row: usize, col: usize, x: T) { fn set(&mut self, row: usize, col: usize, x: T) {
self[[row, col]] = x; self[[row, col]] = x;
} }
fn eye(size: usize) -> Self { fn eye(size: usize) -> Self {
Array::eye(size) Array::eye(size)
@@ -81,7 +82,7 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
fn shape(&self) -> (usize, usize) { fn shape(&self) -> (usize, usize) {
(self.nrows(), self.ncols()) (self.nrows(), self.ncols())
} }
fn v_stack(&self, other: &Self) -> Self { fn v_stack(&self, other: &Self) -> Self {
stack(Axis(1), &[self.view(), other.view()]).unwrap() stack(Axis(1), &[self.view(), other.view()]).unwrap()
@@ -92,12 +93,12 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
} }
fn dot(&self, other: &Self) -> Self { fn dot(&self, other: &Self) -> Self {
self.dot(other) self.dot(other)
} }
fn vector_dot(&self, other: &Self) -> T { fn vector_dot(&self, other: &Self) -> T {
self.dot(&other.view().reversed_axes())[[0, 0]] self.dot(&other.view().reversed_axes())[[0, 0]]
} }
fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self { fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self {
self.slice(s![rows, cols]).to_owned() self.slice(s![rows, cols]).to_owned()
@@ -109,7 +110,7 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
fn add_mut(&mut self, other: &Self) -> &Self { fn add_mut(&mut self, other: &Self) -> &Self {
*self += other; *self += other;
self self
} }
fn sub_mut(&mut self, other: &Self) -> &Self { fn sub_mut(&mut self, other: &Self) -> &Self {
@@ -119,50 +120,48 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
fn mul_mut(&mut self, other: &Self) -> &Self { fn mul_mut(&mut self, other: &Self) -> &Self {
*self *= other; *self *= other;
self self
} }
fn div_mut(&mut self, other: &Self) -> &Self{ fn div_mut(&mut self, other: &Self) -> &Self {
*self /= other; *self /= other;
self self
} }
fn add_scalar_mut(&mut self, scalar: T) -> &Self{ fn add_scalar_mut(&mut self, scalar: T) -> &Self {
*self += scalar; *self += scalar;
self self
} }
fn sub_scalar_mut(&mut self, scalar: T) -> &Self{ fn sub_scalar_mut(&mut self, scalar: T) -> &Self {
*self -= scalar; *self -= scalar;
self self
} }
fn mul_scalar_mut(&mut self, scalar: T) -> &Self{ fn mul_scalar_mut(&mut self, scalar: T) -> &Self {
*self *= scalar; *self *= scalar;
self self
} }
fn div_scalar_mut(&mut self, scalar: T) -> &Self{ fn div_scalar_mut(&mut self, scalar: T) -> &Self {
*self /= scalar; *self /= scalar;
self self
} }
fn transpose(&self) -> Self{ fn transpose(&self) -> Self {
self.clone().reversed_axes() self.clone().reversed_axes()
} }
fn rand(nrows: usize, ncols: usize) -> Self{ fn rand(nrows: usize, ncols: usize) -> Self {
let values: Vec<T> = (0..nrows*ncols).map(|_| { let values: Vec<T> = (0..nrows * ncols).map(|_| T::rand()).collect();
T::rand()
}).collect();
Array::from_shape_vec((nrows, ncols), values).unwrap() Array::from_shape_vec((nrows, ncols), values).unwrap()
} }
fn norm2(&self) -> T{ fn norm2(&self) -> T {
self.iter().map(|x| *x * *x).sum::<T>().sqrt() self.iter().map(|x| *x * *x).sum::<T>().sqrt()
} }
fn norm(&self, p:T) -> T { fn norm(&self, p: T) -> T {
if p.is_infinite() && p.is_sign_positive() { if p.is_infinite() && p.is_sign_positive() {
self.iter().fold(T::neg_infinity(), |f, &val| { self.iter().fold(T::neg_infinity(), |f, &val| {
let v = val.abs(); let v = val.abs();
@@ -171,7 +170,7 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
} else { } else {
v v
} }
}) })
} else if p.is_infinite() && p.is_sign_negative() { } else if p.is_infinite() && p.is_sign_negative() {
self.iter().fold(T::infinity(), |f, &val| { self.iter().fold(T::infinity(), |f, &val| {
let v = val.abs(); let v = val.abs();
@@ -182,14 +181,13 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
} }
}) })
} else { } else {
let mut norm = T::zero(); let mut norm = T::zero();
for xi in self.iter() { for xi in self.iter() {
norm = norm + xi.abs().powf(p); norm = norm + xi.abs().powf(p);
} }
norm.powf(T::one()/p) norm.powf(T::one() / p)
} }
} }
@@ -197,46 +195,46 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
self.mean_axis(Axis(0)).unwrap().to_vec() self.mean_axis(Axis(0)).unwrap().to_vec()
} }
fn div_element_mut(&mut self, row: usize, col: usize, x: T){ fn div_element_mut(&mut self, row: usize, col: usize, x: T) {
self[[row, col]] = self[[row, col]] / x; self[[row, col]] = self[[row, col]] / x;
} }
fn mul_element_mut(&mut self, row: usize, col: usize, x: T){ fn mul_element_mut(&mut self, row: usize, col: usize, x: T) {
self[[row, col]] = self[[row, col]] * x; self[[row, col]] = self[[row, col]] * x;
} }
fn add_element_mut(&mut self, row: usize, col: usize, x: T){ fn add_element_mut(&mut self, row: usize, col: usize, x: T) {
self[[row, col]] = self[[row, col]] + x; self[[row, col]] = self[[row, col]] + x;
} }
fn sub_element_mut(&mut self, row: usize, col: usize, x: T){ fn sub_element_mut(&mut self, row: usize, col: usize, x: T) {
self[[row, col]] = self[[row, col]] - x; self[[row, col]] = self[[row, col]] - x;
} }
fn negative_mut(&mut self){ fn negative_mut(&mut self) {
*self *= -T::one(); *self *= -T::one();
} }
fn reshape(&self, nrows: usize, ncols: usize) -> Self{ fn reshape(&self, nrows: usize, ncols: usize) -> Self {
self.clone().into_shape((nrows, ncols)).unwrap() self.clone().into_shape((nrows, ncols)).unwrap()
} }
fn copy_from(&mut self, other: &Self){ fn copy_from(&mut self, other: &Self) {
self.assign(&other); self.assign(&other);
} }
fn abs_mut(&mut self) -> &Self{ fn abs_mut(&mut self) -> &Self {
for v in self.iter_mut(){ for v in self.iter_mut() {
*v = v.abs() *v = v.abs()
} }
self self
} }
fn sum(&self) -> T{ fn sum(&self) -> T {
self.sum() self.sum()
} }
fn max_diff(&self, other: &Self) -> T{ fn max_diff(&self, other: &Self) -> T {
let mut max_diff = T::zero(); let mut max_diff = T::zero();
for r in 0..self.nrows() { for r in 0..self.nrows() {
for c in 0..self.ncols() { for c in 0..self.ncols() {
@@ -245,9 +243,12 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
} }
max_diff max_diff
} }
fn softmax_mut(&mut self){ fn softmax_mut(&mut self) {
let max = self.iter().map(|x| x.abs()).fold(T::neg_infinity(), |a, b| a.max(b)); let max = self
.iter()
.map(|x| x.abs())
.fold(T::neg_infinity(), |a, b| a.max(b));
let mut z = T::zero(); let mut z = T::zero();
for r in 0..self.nrows() { for r in 0..self.nrows() {
for c in 0..self.ncols() { for c in 0..self.ncols() {
@@ -263,7 +264,7 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
} }
} }
fn pow_mut(&mut self, p: T) -> &Self{ fn pow_mut(&mut self, p: T) -> &Self {
for r in 0..self.nrows() { for r in 0..self.nrows() {
for c in 0..self.ncols() { for c in 0..self.ncols() {
self.set(r, c, self[(r, c)].powf(p)); self.set(r, c, self[(r, c)].powf(p));
@@ -272,7 +273,7 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
self self
} }
fn argmax(&self) -> Vec<usize>{ fn argmax(&self) -> Vec<usize> {
let mut res = vec![0usize; self.nrows()]; let mut res = vec![0usize; self.nrows()];
for r in 0..self.nrows() { for r in 0..self.nrows() {
@@ -282,16 +283,15 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
let v = self[(r, c)]; let v = self[(r, c)];
if max < v { if max < v {
max = v; max = v;
max_pos = c; max_pos = c;
} }
} }
res[r] = max_pos; res[r] = max_pos;
} }
res res
} }
fn unique(&self) -> Vec<T> { fn unique(&self) -> Vec<T> {
let mut result = self.clone().into_raw_vec(); let mut result = self.clone().into_raw_vec();
result.sort_by(|a, b| a.partial_cmp(b).unwrap()); result.sort_by(|a, b| a.partial_cmp(b).unwrap());
@@ -302,231 +302,205 @@ impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign
fn cov(&self) -> Self { fn cov(&self) -> Self {
panic!("Not implemented"); panic!("Not implemented");
} }
} }
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> SVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {} impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
SVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
{
}
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> EVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {} impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
EVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
{
}
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> QRDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {} impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
QRDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
{
}
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> LUDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {} impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
LUDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
{
}
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {} impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T>
for ArrayBase<OwnedRepr<T>, Ix2>
{
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use ndarray::{arr1, arr2, Array2}; use ndarray::{arr1, arr2, Array2};
#[test] #[test]
fn vec_get_set() { fn vec_get_set() {
let mut result = arr1(&[1., 2., 3.]); let mut result = arr1(&[1., 2., 3.]);
let expected = arr1(&[1., 5., 3.]); let expected = arr1(&[1., 5., 3.]);
result.set(1, 5.); result.set(1, 5.);
assert_eq!(result, expected); assert_eq!(result, expected);
assert_eq!(5., BaseVector::get(&result, 1)); assert_eq!(5., BaseVector::get(&result, 1));
} }
#[test] #[test]
fn vec_len() { fn vec_len() {
let v = arr1(&[1., 2., 3.]); let v = arr1(&[1., 2., 3.]);
assert_eq!(3, v.len()); assert_eq!(3, v.len());
} }
#[test] #[test]
fn from_to_row_vec() { fn from_to_row_vec() {
let vec = arr1(&[1., 2., 3.]);
let vec = arr1(&[ 1., 2., 3.]);
assert_eq!(Array2::from_row_vector(vec.clone()), arr2(&[[1., 2., 3.]])); assert_eq!(Array2::from_row_vector(vec.clone()), arr2(&[[1., 2., 3.]]));
assert_eq!(Array2::from_row_vector(vec.clone()).to_row_vector(), arr1(&[1., 2., 3.])); assert_eq!(
Array2::from_row_vector(vec.clone()).to_row_vector(),
arr1(&[1., 2., 3.])
);
} }
#[test] #[test]
fn add_mut() { fn add_mut() {
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
let mut a1 = arr2(&[[ 1., 2., 3.],
[4., 5., 6.]]);
let a2 = a1.clone(); let a2 = a1.clone();
let a3 = a1.clone() + a2.clone(); let a3 = a1.clone() + a2.clone();
a1.add_mut(&a2); a1.add_mut(&a2);
assert_eq!(a1, a3); assert_eq!(a1, a3);
} }
#[test] #[test]
fn sub_mut() { fn sub_mut() {
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
let mut a1 = arr2(&[[ 1., 2., 3.],
[4., 5., 6.]]);
let a2 = a1.clone(); let a2 = a1.clone();
let a3 = a1.clone() - a2.clone(); let a3 = a1.clone() - a2.clone();
a1.sub_mut(&a2); a1.sub_mut(&a2);
assert_eq!(a1, a3); assert_eq!(a1, a3);
} }
#[test] #[test]
fn mul_mut() { fn mul_mut() {
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
let mut a1 = arr2(&[[ 1., 2., 3.],
[4., 5., 6.]]);
let a2 = a1.clone(); let a2 = a1.clone();
let a3 = a1.clone() * a2.clone(); let a3 = a1.clone() * a2.clone();
a1.mul_mut(&a2); a1.mul_mut(&a2);
assert_eq!(a1, a3);
assert_eq!(a1, a3);
} }
#[test] #[test]
fn div_mut() { fn div_mut() {
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
let mut a1 = arr2(&[[ 1., 2., 3.],
[4., 5., 6.]]);
let a2 = a1.clone(); let a2 = a1.clone();
let a3 = a1.clone() / a2.clone(); let a3 = a1.clone() / a2.clone();
a1.div_mut(&a2); a1.div_mut(&a2);
assert_eq!(a1, a3);
assert_eq!(a1, a3);
} }
#[test] #[test]
fn div_element_mut() { fn div_element_mut() {
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
let mut a = arr2(&[[ 1., 2., 3.],
[4., 5., 6.]]);
a.div_element_mut(1, 1, 5.); a.div_element_mut(1, 1, 5.);
assert_eq!(BaseMatrix::get(&a, 1, 1), 1.); assert_eq!(BaseMatrix::get(&a, 1, 1), 1.);
} }
#[test] #[test]
fn mul_element_mut() { fn mul_element_mut() {
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
let mut a = arr2(&[[ 1., 2., 3.],
[4., 5., 6.]]);
a.mul_element_mut(1, 1, 5.); a.mul_element_mut(1, 1, 5.);
assert_eq!(BaseMatrix::get(&a, 1, 1), 25.); assert_eq!(BaseMatrix::get(&a, 1, 1), 25.);
} }
#[test] #[test]
fn add_element_mut() { fn add_element_mut() {
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
let mut a = arr2(&[[ 1., 2., 3.],
[4., 5., 6.]]);
a.add_element_mut(1, 1, 5.); a.add_element_mut(1, 1, 5.);
assert_eq!(BaseMatrix::get(&a, 1, 1), 10.); assert_eq!(BaseMatrix::get(&a, 1, 1), 10.);
} }
#[test] #[test]
fn sub_element_mut() { fn sub_element_mut() {
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
let mut a = arr2(&[[ 1., 2., 3.],
[4., 5., 6.]]);
a.sub_element_mut(1, 1, 5.); a.sub_element_mut(1, 1, 5.);
assert_eq!(BaseMatrix::get(&a, 1, 1), 0.); assert_eq!(BaseMatrix::get(&a, 1, 1), 0.);
} }
#[test] #[test]
fn vstack_hstack() { fn vstack_hstack() {
let a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
let a1 = arr2(&[[1., 2., 3.], let a2 = arr2(&[[7.], [8.]]);
[4., 5., 6.]]);
let a2 = arr2(&[[ 7.], [8.]]);
let a3 = arr2(&[[9., 10., 11., 12.]]); let a3 = arr2(&[[9., 10., 11., 12.]]);
let expected = arr2(&[[1., 2., 3., 7.], let expected = arr2(&[[1., 2., 3., 7.], [4., 5., 6., 8.], [9., 10., 11., 12.]]);
[4., 5., 6., 8.],
[9., 10., 11., 12.]]);
let result = a1.v_stack(&a2).h_stack(&a3); let result = a1.v_stack(&a2).h_stack(&a3);
assert_eq!(result, expected);
assert_eq!(result, expected);
} }
#[test] #[test]
fn to_raw_vector() { fn to_raw_vector() {
let result = arr2(&[[1., 2., 3.], [4., 5., 6.]]).to_raw_vector(); let result = arr2(&[[1., 2., 3.], [4., 5., 6.]]).to_raw_vector();
let expected = vec![1., 2., 3., 4., 5., 6.]; let expected = vec![1., 2., 3., 4., 5., 6.];
assert_eq!(result, expected); assert_eq!(result, expected);
} }
#[test] #[test]
fn get_set() { fn get_set() {
let mut result = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let mut result = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
let expected = arr2(&[[1., 2., 3.], [4., 10., 6.]]); let expected = arr2(&[[1., 2., 3.], [4., 10., 6.]]);
result.set(1, 1, 10.); result.set(1, 1, 10.);
assert_eq!(result, expected); assert_eq!(result, expected);
assert_eq!(10., BaseMatrix::get(&result, 1, 1)); assert_eq!(10., BaseMatrix::get(&result, 1, 1));
}
#[test]
fn dot() {
let a = arr2(&[
[1., 2., 3.],
[4., 5., 6.]]);
let b = arr2(&[
[1., 2.],
[3., 4.],
[5., 6.]]);
let expected = arr2(&[
[22., 28.],
[49., 64.]]);
let result = BaseMatrix::dot(&a, &b);
assert_eq!(result, expected);
}
#[test]
fn vector_dot() {
let a = arr2(&[[1., 2., 3.]]);
let b = arr2(&[[1., 2., 3.]]);
assert_eq!(14., a.vector_dot(&b));
} }
#[test] #[test]
fn slice() { fn dot() {
let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
let a = arr2( let b = arr2(&[[1., 2.], [3., 4.], [5., 6.]]);
&[ let expected = arr2(&[[22., 28.], [49., 64.]]);
[1., 2., 3., 1., 2.], let result = BaseMatrix::dot(&a, &b);
[4., 5., 6., 3., 4.], assert_eq!(result, expected);
[7., 8., 9., 5., 6.]]);
let expected = arr2(
&[
[2., 3.],
[5., 6.]]);
let result = BaseMatrix::slice(&a, 0..2, 1..3);
assert_eq!(result, expected);
} }
#[test] #[test]
fn scalar_ops() { fn vector_dot() {
let a = arr2(&[[1., 2., 3.]]); let a = arr2(&[[1., 2., 3.]]);
assert_eq!(&arr2(&[[2., 3., 4.]]), a.clone().add_scalar_mut(1.)); let b = arr2(&[[1., 2., 3.]]);
assert_eq!(&arr2(&[[0., 1., 2.]]), a.clone().sub_scalar_mut(1.)); assert_eq!(14., a.vector_dot(&b));
assert_eq!(&arr2(&[[2., 4., 6.]]), a.clone().mul_scalar_mut(2.)); }
assert_eq!(&arr2(&[[0.5, 1., 1.5]]), a.clone().div_scalar_mut(2.));
#[test]
fn slice() {
let a = arr2(&[
[1., 2., 3., 1., 2.],
[4., 5., 6., 3., 4.],
[7., 8., 9., 5., 6.],
]);
let expected = arr2(&[[2., 3.], [5., 6.]]);
let result = BaseMatrix::slice(&a, 0..2, 1..3);
assert_eq!(result, expected);
}
#[test]
fn scalar_ops() {
let a = arr2(&[[1., 2., 3.]]);
assert_eq!(&arr2(&[[2., 3., 4.]]), a.clone().add_scalar_mut(1.));
assert_eq!(&arr2(&[[0., 1., 2.]]), a.clone().sub_scalar_mut(1.));
assert_eq!(&arr2(&[[2., 4., 6.]]), a.clone().mul_scalar_mut(2.));
assert_eq!(&arr2(&[[0.5, 1., 1.5]]), a.clone().div_scalar_mut(2.));
} }
#[test] #[test]
@@ -534,12 +508,12 @@ mod tests {
let m = arr2(&[[1.0, 3.0], [2.0, 4.0]]); let m = arr2(&[[1.0, 3.0], [2.0, 4.0]]);
let expected = arr2(&[[1.0, 2.0], [3.0, 4.0]]); let expected = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
let m_transposed = m.transpose(); let m_transposed = m.transpose();
assert_eq!(m_transposed, expected); assert_eq!(m_transposed, expected);
} }
#[test] #[test]
fn norm() { fn norm() {
let v = arr2(&[[3., -2., 6.]]); let v = arr2(&[[3., -2., 6.]]);
assert_eq!(v.norm(1.), 11.); assert_eq!(v.norm(1.), 11.);
assert_eq!(v.norm(2.), 7.); assert_eq!(v.norm(2.), 7.);
assert_eq!(v.norm(std::f64::INFINITY), 6.); assert_eq!(v.norm(std::f64::INFINITY), 6.);
@@ -547,17 +521,17 @@ mod tests {
} }
#[test] #[test]
fn negative_mut() { fn negative_mut() {
let mut v = arr2(&[[3., -2., 6.]]); let mut v = arr2(&[[3., -2., 6.]]);
v.negative_mut(); v.negative_mut();
assert_eq!(v, arr2(&[[-3., 2., -6.]])); assert_eq!(v, arr2(&[[-3., 2., -6.]]));
} }
#[test] #[test]
fn reshape() { fn reshape() {
let m_orig = arr2(&[[1., 2., 3., 4., 5., 6.]]); let m_orig = arr2(&[[1., 2., 3., 4., 5., 6.]]);
let m_2_by_3 = BaseMatrix::reshape(&m_orig, 2, 3); let m_2_by_3 = BaseMatrix::reshape(&m_orig, 2, 3);
let m_result = BaseMatrix::reshape(&m_2_by_3, 1, 6); let m_result = BaseMatrix::reshape(&m_2_by_3, 1, 6);
assert_eq!(BaseMatrix::shape(&m_2_by_3), (2, 3)); assert_eq!(BaseMatrix::shape(&m_2_by_3), (2, 3));
assert_eq!(BaseMatrix::get(&m_2_by_3, 1, 1), 5.); assert_eq!(BaseMatrix::get(&m_2_by_3, 1, 1), 5.);
assert_eq!(BaseMatrix::get(&m_result, 0, 1), 2.); assert_eq!(BaseMatrix::get(&m_result, 0, 1), 2.);
@@ -567,84 +541,80 @@ mod tests {
#[test] #[test]
fn copy_from() { fn copy_from() {
let mut src = arr2(&[[1., 2., 3.]]); let mut src = arr2(&[[1., 2., 3.]]);
let dst = Array2::<f64>::zeros((1, 3)); let dst = Array2::<f64>::zeros((1, 3));
src.copy_from(&dst); src.copy_from(&dst);
assert_eq!(src, dst); assert_eq!(src, dst);
} }
#[test] #[test]
fn sum() { fn sum() {
let src = arr2(&[[1., 2., 3.]]); let src = arr2(&[[1., 2., 3.]]);
assert_eq!(src.sum(), 6.); assert_eq!(src.sum(), 6.);
} }
#[test] #[test]
fn max_diff() { fn max_diff() {
let a1 = arr2(&[[1., 2., 3.], [4., -5., 6.]]); let a1 = arr2(&[[1., 2., 3.], [4., -5., 6.]]);
let a2 = arr2(&[[2., 3., 4.], [1., 0., -12.]]); let a2 = arr2(&[[2., 3., 4.], [1., 0., -12.]]);
assert_eq!(a1.max_diff(&a2), 18.); assert_eq!(a1.max_diff(&a2), 18.);
assert_eq!(a2.max_diff(&a2), 0.); assert_eq!(a2.max_diff(&a2), 0.);
} }
#[test] #[test]
fn softmax_mut(){ fn softmax_mut() {
let mut prob: Array2<f64> = arr2(&[[1., 2., 3.]]); let mut prob: Array2<f64> = arr2(&[[1., 2., 3.]]);
prob.softmax_mut(); prob.softmax_mut();
assert!((BaseMatrix::get(&prob, 0, 0) - 0.09).abs() < 0.01); assert!((BaseMatrix::get(&prob, 0, 0) - 0.09).abs() < 0.01);
assert!((BaseMatrix::get(&prob, 0, 1) - 0.24).abs() < 0.01); assert!((BaseMatrix::get(&prob, 0, 1) - 0.24).abs() < 0.01);
assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01); assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
} }
#[test] #[test]
fn pow_mut(){ fn pow_mut() {
let mut a = arr2(&[[1., 2., 3.]]); let mut a = arr2(&[[1., 2., 3.]]);
a.pow_mut(3.); a.pow_mut(3.);
assert_eq!(a, arr2(&[[1., 8., 27.]])); assert_eq!(a, arr2(&[[1., 8., 27.]]));
} }
#[test] #[test]
fn argmax(){ fn argmax() {
let a = arr2(&[[1., 2., 3.], [-5., -6., -7.], [0.1, 0.2, 0.1]]); let a = arr2(&[[1., 2., 3.], [-5., -6., -7.], [0.1, 0.2, 0.1]]);
let res = a.argmax(); let res = a.argmax();
assert_eq!(res, vec![2, 0, 1]); assert_eq!(res, vec![2, 0, 1]);
} }
#[test] #[test]
fn unique(){ fn unique() {
let a = arr2(&[[1., 2., 2.], [-2., -6., -7.], [2., 3., 4.]]); let a = arr2(&[[1., 2., 2.], [-2., -6., -7.], [2., 3., 4.]]);
let res = a.unique(); let res = a.unique();
assert_eq!(res.len(), 7); assert_eq!(res.len(), 7);
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]); assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
} }
#[test] #[test]
fn get_row_as_vector(){ fn get_row_as_vector() {
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]); let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
let res = a.get_row_as_vec(1); let res = a.get_row_as_vec(1);
assert_eq!(res, vec![4., 5., 6.]); assert_eq!(res, vec![4., 5., 6.]);
} }
#[test] #[test]
fn get_col_as_vector(){ fn get_col_as_vector() {
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]); let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
let res = a.get_col_as_vec(1); let res = a.get_col_as_vec(1);
assert_eq!(res, vec![2., 5., 8.]); assert_eq!(res, vec![2., 5., 8.]);
} }
#[test] #[test]
fn col_mean(){ fn col_mean() {
let a = arr2(&[[1., 2., 3.], let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
[4., 5., 6.],
[7., 8., 9.]]);
let res = a.column_mean(); let res = a.column_mean();
assert_eq!(res, vec![4., 5., 6.]); assert_eq!(res, vec![4., 5., 6.]);
} }
#[test] #[test]
fn eye(){ fn eye() {
let a = arr2(&[[1., 0., 0.], let a = arr2(&[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]);
[0., 1., 0.],
[0., 0., 1.]]);
let res: Array2<f64> = BaseMatrix::eye(3); let res: Array2<f64> = BaseMatrix::eye(3);
assert_eq!(res, a); assert_eq!(res, a);
} }
@@ -661,12 +631,8 @@ mod tests {
#[test] #[test]
fn approximate_eq() { fn approximate_eq() {
let a = arr2(&[[1., 2., 3.], let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
[4., 5., 6.], let noise = arr2(&[[1e-5, 2e-5, 3e-5], [4e-5, 5e-5, 6e-5], [7e-5, 8e-5, 9e-5]]);
[7., 8., 9.]]);
let noise = arr2(&[[1e-5, 2e-5, 3e-5],
[4e-5, 5e-5, 6e-5],
[7e-5, 8e-5, 9e-5]]);
assert!(a.approximate_eq(&(&noise + &a), 1e-4)); assert!(a.approximate_eq(&(&noise + &a), 1e-4));
assert!(!a.approximate_eq(&(&noise + &a), 1e-5)); assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
} }
@@ -678,4 +644,4 @@ mod tests {
a.abs_mut(); a.abs_mut();
assert_eq!(a, expected); assert_eq!(a, expected);
} }
} }
+49 -55
View File
@@ -2,19 +2,18 @@
use std::fmt::Debug; use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
use crate::math::num::FloatExt;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct QR<T: FloatExt, M: BaseMatrix<T>> { pub struct QR<T: FloatExt, M: BaseMatrix<T>> {
QR: M, QR: M,
tau: Vec<T>, tau: Vec<T>,
singular: bool singular: bool,
} }
impl<T: FloatExt, M: BaseMatrix<T>> QR<T, M> { impl<T: FloatExt, M: BaseMatrix<T>> QR<T, M> {
pub fn new(QR: M, tau: Vec<T>) -> QR<T, M> { pub fn new(QR: M, tau: Vec<T>) -> QR<T, M> {
let mut singular = false; let mut singular = false;
for j in 0..tau.len() { for j in 0..tau.len() {
if tau[j] == T::zero() { if tau[j] == T::zero() {
@@ -26,7 +25,7 @@ impl<T: FloatExt, M: BaseMatrix<T>> QR<T, M> {
QR { QR {
QR: QR, QR: QR,
tau: tau, tau: tau,
singular: singular singular: singular,
} }
} }
@@ -35,13 +34,13 @@ impl<T: FloatExt, M: BaseMatrix<T>> QR<T, M> {
let mut R = M::zeros(n, n); let mut R = M::zeros(n, n);
for i in 0..n { for i in 0..n {
R.set(i, i, self.tau[i]); R.set(i, i, self.tau[i]);
for j in i+1..n { for j in i + 1..n {
R.set(i, j, self.QR.get(i, j)); R.set(i, j, self.QR.get(i, j));
} }
} }
return R; return R;
} }
pub fn Q(&self) -> M { pub fn Q(&self) -> M {
let (m, n) = self.QR.shape(); let (m, n) = self.QR.shape();
let mut Q = M::zeros(m, n); let mut Q = M::zeros(m, n);
@@ -63,19 +62,21 @@ impl<T: FloatExt, M: BaseMatrix<T>> QR<T, M> {
if k == 0 { if k == 0 {
break; break;
} else { } else {
k -= 1; k -= 1;
} }
} }
return Q; return Q;
} }
fn solve(&self, mut b: M) -> M { fn solve(&self, mut b: M) -> M {
let (m, n) = self.QR.shape(); let (m, n) = self.QR.shape();
let (b_nrows, b_ncols) = b.shape(); let (b_nrows, b_ncols) = b.shape();
if b_nrows != m { if b_nrows != m {
panic!("Row dimensions do not agree: A is {} x {}, but B is {} x {}", m, n, b_nrows, b_ncols); panic!(
"Row dimensions do not agree: A is {} x {}, but B is {} x {}",
m, n, b_nrows, b_ncols
);
} }
if self.singular { if self.singular {
@@ -93,13 +94,13 @@ impl<T: FloatExt, M: BaseMatrix<T>> QR<T, M> {
b.add_element_mut(i, j, s * self.QR.get(i, k)); b.add_element_mut(i, j, s * self.QR.get(i, k));
} }
} }
} }
for k in (0..n).rev() { for k in (0..n).rev() {
for j in 0..b_ncols { for j in 0..b_ncols {
b.set(k, j, b.get(k, j) / self.tau[k]); b.set(k, j, b.get(k, j) / self.tau[k]);
} }
for i in 0..k { for i in 0..k {
for j in 0..b_ncols { for j in 0..b_ncols {
b.sub_element_mut(i, j, b.get(k, j) * self.QR.get(i, k)); b.sub_element_mut(i, j, b.get(k, j) * self.QR.get(i, k));
@@ -108,19 +109,16 @@ impl<T: FloatExt, M: BaseMatrix<T>> QR<T, M> {
} }
b b
} }
} }
pub trait QRDecomposableMatrix<T: FloatExt>: BaseMatrix<T> { pub trait QRDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
fn qr(&self) -> QR<T, Self> { fn qr(&self) -> QR<T, Self> {
self.clone().qr_mut() self.clone().qr_mut()
} }
fn qr_mut(mut self) -> QR<T, Self> { fn qr_mut(mut self) -> QR<T, Self> {
let (m, n) = self.shape();
let (m, n) = self.shape();
let mut r_diagonal: Vec<T> = vec![T::zero(); n]; let mut r_diagonal: Vec<T> = vec![T::zero(); n];
@@ -131,7 +129,6 @@ pub trait QRDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
} }
if nrm.abs() > T::epsilon() { if nrm.abs() > T::epsilon() {
if self.get(k, k) < T::zero() { if self.get(k, k) < T::zero() {
nrm = -nrm; nrm = -nrm;
} }
@@ -140,7 +137,7 @@ pub trait QRDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
} }
self.add_element_mut(k, k, T::one()); self.add_element_mut(k, k, T::one());
for j in k+1..n { for j in k + 1..n {
let mut s = T::zero(); let mut s = T::zero();
for i in k..m { for i in k..m {
s = s + self.get(i, k) * self.get(i, j); s = s + self.get(i, k) * self.get(i, j);
@@ -150,54 +147,51 @@ pub trait QRDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
self.add_element_mut(i, j, s * self.get(i, k)); self.add_element_mut(i, j, s * self.get(i, k));
} }
} }
} }
r_diagonal[k] = -nrm; r_diagonal[k] = -nrm;
} }
QR::new(self, r_diagonal) QR::new(self, r_diagonal)
} }
fn qr_solve_mut(self, b: Self) -> Self { fn qr_solve_mut(self, b: Self) -> Self {
self.qr_mut().solve(b)
self.qr_mut().solve(b) }
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
#[test] #[test]
fn decompose() { fn decompose() {
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let q = DenseMatrix::from_array(&[
let q = DenseMatrix::from_array(&[ &[-0.7448, 0.2436, 0.6212],
&[-0.7448, 0.2436, 0.6212], &[-0.331, -0.9432, -0.027],
&[-0.331, -0.9432, -0.027], &[-0.5793, 0.2257, -0.7832],
&[-0.5793, 0.2257, -0.7832]]); ]);
let r = DenseMatrix::from_array(&[ let r = DenseMatrix::from_array(&[
&[-1.2083, -0.6373, -1.0842], &[-1.2083, -0.6373, -1.0842],
&[0.0, -0.3064, 0.0682], &[0.0, -0.3064, 0.0682],
&[0.0, 0.0, -0.1999]]); &[0.0, 0.0, -0.1999],
let qr = a.qr(); ]);
assert!(qr.Q().abs().approximate_eq(&q.abs(), 1e-4)); let qr = a.qr();
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4)); assert!(qr.Q().abs().approximate_eq(&q.abs(), 1e-4));
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
} }
#[test] #[test]
fn qr_solve_mut() { fn qr_solve_mut() {
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let b = DenseMatrix::from_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
let b = DenseMatrix::from_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]); let expected_w = DenseMatrix::from_array(&[
let expected_w = DenseMatrix::from_array(&[ &[-0.2027027, -1.2837838],
&[-0.2027027, -1.2837838], &[0.8783784, 2.2297297],
&[0.8783784, 2.2297297], &[0.4729730, 0.6621622],
&[0.4729730, 0.6621622] ]);
]); let w = a.qr_solve_mut(b);
let w = a.qr_solve_mut(b); assert!(w.approximate_eq(&expected_w, 1e-2));
assert!(w.approximate_eq(&expected_w, 1e-2));
} }
} }
+221 -74
View File
@@ -12,17 +12,16 @@ pub struct SVD<T: FloatExt, M: SVDDecomposableMatrix<T>> {
full: bool, full: bool,
m: usize, m: usize,
n: usize, n: usize,
tol: T tol: T,
} }
pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> { pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
fn svd_solve_mut(self, b: Self) -> Self { fn svd_solve_mut(self, b: Self) -> Self {
self.svd_mut().solve(b) self.svd_mut().solve(b)
} }
fn svd_solve(&self, b: Self) -> Self { fn svd_solve(&self, b: Self) -> Self {
self.svd().solve(b) self.svd().solve(b)
} }
fn svd(&self) -> SVD<T, Self> { fn svd(&self) -> SVD<T, Self> {
@@ -30,14 +29,13 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
} }
fn svd_mut(self) -> SVD<T, Self> { fn svd_mut(self) -> SVD<T, Self> {
let mut U = self;
let mut U = self; let (m, n) = U.shape();
let (m, n) = U.shape();
let (mut l, mut nm) = (0usize, 0usize); let (mut l, mut nm) = (0usize, 0usize);
let (mut anorm, mut g, mut scale) = (T::zero(), T::zero(), T::zero()); let (mut anorm, mut g, mut scale) = (T::zero(), T::zero(), T::zero());
let mut v = Self::zeros(n, n); let mut v = Self::zeros(n, n);
let mut w = vec![T::zero(); n]; let mut w = vec![T::zero(); n];
let mut rv1 = vec![T::zero(); n]; let mut rv1 = vec![T::zero(); n];
@@ -55,7 +53,6 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
} }
if scale.abs() > T::epsilon() { if scale.abs() > T::epsilon() {
for k in i..m { for k in i..m {
U.div_element_mut(k, i, scale); U.div_element_mut(k, i, scale);
s = s + U.get(k, i) * U.get(k, i); s = s + U.get(k, i) * U.get(k, i);
@@ -98,7 +95,7 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
} }
let f = U.get(i, l - 1); let f = U.get(i, l - 1);
g = -s.sqrt().copysign(f); g = -s.sqrt().copysign(f);
let h = f * g - s; let h = f * g - s;
U.set(i, l - 1, f - g); U.set(i, l - 1, f - g);
@@ -123,7 +120,6 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
} }
} }
anorm = T::max(anorm, w[i].abs() + rv1[i].abs()); anorm = T::max(anorm, w[i].abs() + rv1[i].abs());
} }
@@ -186,11 +182,11 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
for k in (0..n).rev() { for k in (0..n).rev() {
for iteration in 0..30 { for iteration in 0..30 {
let mut flag = true; let mut flag = true;
l = k; l = k;
while l != 0 { while l != 0 {
if l == 0 || rv1[l].abs() <= T::epsilon() * anorm { if l == 0 || rv1[l].abs() <= T::epsilon() * anorm {
flag = false; flag = false;
break; break;
} }
nm = l - 1; nm = l - 1;
@@ -203,7 +199,7 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
if flag { if flag {
let mut c = T::zero(); let mut c = T::zero();
let mut s = T::one(); let mut s = T::one();
for i in l..k+1 { for i in l..k + 1 {
let f = s * rv1[i]; let f = s * rv1[i];
rv1[i] = c * rv1[i]; rv1[i] = c * rv1[i];
if f.abs() <= T::epsilon() * anorm { if f.abs() <= T::epsilon() * anorm {
@@ -219,7 +215,7 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
let y = U.get(j, nm); let y = U.get(j, nm);
let z = U.get(j, i); let z = U.get(j, i);
U.set(j, nm, y * c + z * s); U.set(j, nm, y * c + z * s);
U.set(j, i, z * c - y * s); U.set(j, i, z * c - y * s);
} }
} }
} }
@@ -295,11 +291,11 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
w[k] = x; w[k] = x;
} }
} }
let mut inc = 1usize; let mut inc = 1usize;
let mut su = vec![T::zero(); m]; let mut su = vec![T::zero(); m];
let mut sv = vec![T::zero(); n]; let mut sv = vec![T::zero(); n];
loop { loop {
inc *= 3; inc *= 3;
inc += 1; inc += 1;
@@ -310,7 +306,7 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
loop { loop {
inc /= 3; inc /= 3;
for i in inc..n { for i in inc..n {
let sw = w[i]; let sw = w[i];
for k in 0..m { for k in 0..m {
su[k] = U.get(k, i); su[k] = U.get(k, i);
@@ -339,7 +335,6 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
for k in 0..n { for k in 0..n {
v.set(k, j, sv[k]); v.set(k, j, sv[k]);
} }
} }
if inc <= 1 { if inc <= 1 {
break; break;
@@ -366,18 +361,17 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
v.set(j, k, -v.get(j, k)); v.set(j, k, -v.get(j, k));
} }
} }
} }
SVD::new(U, v, w) SVD::new(U, v, w)
} }
} }
impl<T: FloatExt, M: SVDDecomposableMatrix<T>> SVD<T, M> { impl<T: FloatExt, M: SVDDecomposableMatrix<T>> SVD<T, M> {
pub fn new(U: M, V: M, s: Vec<T>) -> SVD<T, M> { pub fn new(U: M, V: M, s: Vec<T>) -> SVD<T, M> {
let m = U.shape().0; let m = U.shape().0;
let n = V.shape().0; let n = V.shape().0;
let full = s.len() == m.min(n); let full = s.len() == m.min(n);
let tol = T::half() * (T::from(m + n).unwrap() + T::one()).sqrt() * s[0] * T::epsilon(); let tol = T::half() * (T::from(m + n).unwrap() + T::one()).sqrt() * s[0] * T::epsilon();
SVD { SVD {
U: U, U: U,
@@ -386,7 +380,7 @@ impl<T: FloatExt, M: SVDDecomposableMatrix<T>> SVD<T, M> {
full: full, full: full,
m: m, m: m,
n: n, n: n,
tol: tol tol: tol,
} }
} }
@@ -394,7 +388,11 @@ impl<T: FloatExt, M: SVDDecomposableMatrix<T>> SVD<T, M> {
let p = b.shape().1; let p = b.shape().1;
if self.U.shape().0 != b.shape().0 { if self.U.shape().0 != b.shape().0 {
panic!("Dimensions do not agree. U.nrows should equal b.nrows but is {}, {}", self.U.shape().0, b.shape().0); panic!(
"Dimensions do not agree. U.nrows should equal b.nrows but is {}, {}",
self.U.shape().0,
b.shape().0
);
} }
for k in 0..p { for k in 0..p {
@@ -424,30 +422,30 @@ impl<T: FloatExt, M: SVDDecomposableMatrix<T>> SVD<T, M> {
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test] #[test]
fn decompose_symmetric() { fn decompose_symmetric() {
let A = DenseMatrix::from_array(&[ let A = DenseMatrix::from_array(&[
&[0.9000, 0.4000, 0.7000], &[0.9000, 0.4000, 0.7000],
&[0.4000, 0.5000, 0.3000], &[0.4000, 0.5000, 0.3000],
&[0.7000, 0.3000, 0.8000]]); &[0.7000, 0.3000, 0.8000],
]);
let s: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834]; let s: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
let U = DenseMatrix::from_array(&[ let U = DenseMatrix::from_array(&[
&[0.6881997, -0.07121225, 0.7220180], &[0.6881997, -0.07121225, 0.7220180],
&[0.3700456, 0.89044952, -0.2648886], &[0.3700456, 0.89044952, -0.2648886],
&[0.6240573, -0.44947578, -0.639158] &[0.6240573, -0.44947578, -0.639158],
]); ]);
let V = DenseMatrix::from_array(&[ let V = DenseMatrix::from_array(&[
&[0.6881997, -0.07121225, 0.7220180], &[0.6881997, -0.07121225, 0.7220180],
&[0.3700456, 0.89044952, -0.2648886], &[0.3700456, 0.89044952, -0.2648886],
&[0.6240573, -0.44947578, -0.6391588] &[0.6240573, -0.44947578, -0.6391588],
]); ]);
let svd = A.svd(); let svd = A.svd();
@@ -457,43 +455,199 @@ mod tests {
for i in 0..s.len() { for i in 0..s.len() {
assert!((s[i] - svd.s[i]).abs() < 1e-4); assert!((s[i] - svd.s[i]).abs() < 1e-4);
} }
} }
#[test] #[test]
fn decompose_asymmetric() { fn decompose_asymmetric() {
let A = DenseMatrix::from_array(&[ let A = DenseMatrix::from_array(&[
&[1.19720880, -1.8391378, 0.3019585, -1.1165701, -1.7210814, 0.4918882, -0.04247433], &[
&[0.06605075, 1.0315583, 0.8294362, -0.3646043, -1.6038017, -0.9188110, -0.63760340], 1.19720880,
&[-1.02637715, 1.0747931, -0.8089055, -0.4726863, -0.2064826, -0.3325532, 0.17966051], -1.8391378,
&[-1.45817729, -0.8942353, 0.3459245, 1.5068363, -2.0180708, -0.3696350, -1.19575563], 0.3019585,
&[-0.07318103, -0.2783787, 1.2237598, 0.1995332, 0.2545336, -0.1392502, -1.88207227], -1.1165701,
&[0.88248425, -0.9360321, 0.1393172, 0.1393281, -0.3277873, -0.5553013, 1.63805985], -1.7210814,
&[0.12641406, -0.8710055, -0.2712301, 0.2296515, 1.1781535, -0.2158704, -0.27529472] 0.4918882,
-0.04247433,
],
&[
0.06605075,
1.0315583,
0.8294362,
-0.3646043,
-1.6038017,
-0.9188110,
-0.63760340,
],
&[
-1.02637715,
1.0747931,
-0.8089055,
-0.4726863,
-0.2064826,
-0.3325532,
0.17966051,
],
&[
-1.45817729,
-0.8942353,
0.3459245,
1.5068363,
-2.0180708,
-0.3696350,
-1.19575563,
],
&[
-0.07318103,
-0.2783787,
1.2237598,
0.1995332,
0.2545336,
-0.1392502,
-1.88207227,
],
&[
0.88248425, -0.9360321, 0.1393172, 0.1393281, -0.3277873, -0.5553013, 1.63805985,
],
&[
0.12641406,
-0.8710055,
-0.2712301,
0.2296515,
1.1781535,
-0.2158704,
-0.27529472,
],
]); ]);
let s: Vec<f64> = vec![3.8589375, 3.4396766, 2.6487176, 2.2317399, 1.5165054, 0.8109055, 0.2706515]; let s: Vec<f64> = vec![
3.8589375, 3.4396766, 2.6487176, 2.2317399, 1.5165054, 0.8109055, 0.2706515,
];
let U = DenseMatrix::from_array(&[ let U = DenseMatrix::from_array(&[
&[-0.3082776, 0.77676231, 0.01330514, 0.23231424, -0.47682758, 0.13927109, 0.02640713], &[
&[-0.4013477, -0.09112050, 0.48754440, 0.47371793, 0.40636608, 0.24600706, -0.37796295], -0.3082776,
&[0.0599719, -0.31406586, 0.45428229, -0.08071283, -0.38432597, 0.57320261, 0.45673993], 0.77676231,
&[-0.7694214, -0.12681435, -0.05536793, -0.62189972, -0.02075522, -0.01724911, -0.03681864], 0.01330514,
&[-0.3319069, -0.17984404, -0.54466777, 0.45335157, 0.19377726, 0.12333423, 0.55003852], 0.23231424,
&[0.1259351, 0.49087824, 0.16349687, -0.32080176, 0.64828744, 0.20643772, 0.38812467], -0.47682758,
&[0.1491884, 0.01768604, -0.47884363, -0.14108924, 0.03922507, 0.73034065, -0.43965505] 0.13927109,
0.02640713,
],
&[
-0.4013477,
-0.09112050,
0.48754440,
0.47371793,
0.40636608,
0.24600706,
-0.37796295,
],
&[
0.0599719,
-0.31406586,
0.45428229,
-0.08071283,
-0.38432597,
0.57320261,
0.45673993,
],
&[
-0.7694214,
-0.12681435,
-0.05536793,
-0.62189972,
-0.02075522,
-0.01724911,
-0.03681864,
],
&[
-0.3319069,
-0.17984404,
-0.54466777,
0.45335157,
0.19377726,
0.12333423,
0.55003852,
],
&[
0.1259351,
0.49087824,
0.16349687,
-0.32080176,
0.64828744,
0.20643772,
0.38812467,
],
&[
0.1491884,
0.01768604,
-0.47884363,
-0.14108924,
0.03922507,
0.73034065,
-0.43965505,
],
]); ]);
let V = DenseMatrix::from_array(&[ let V = DenseMatrix::from_array(&[
&[-0.2122609, -0.54650056, 0.08071332, -0.43239135, -0.2925067, 0.1414550, 0.59769207], &[
&[-0.1943605, 0.63132116, -0.54059857, -0.37089970, -0.1363031, 0.2892641, 0.17774114], -0.2122609,
&[0.3031265, -0.06182488, 0.18579097, -0.38606409, -0.5364911, 0.2983466, -0.58642548], -0.54650056,
&[0.1844063, 0.24425278, 0.25923756, 0.59043765, -0.4435443, 0.3959057, 0.37019098], 0.08071332,
&[-0.7164205, 0.30694911, 0.58264743, -0.07458095, -0.1142140, -0.1311972, -0.13124764], -0.43239135,
&[-0.1103067, -0.10633600, 0.18257905, -0.03638501, 0.5722925, 0.7784398, -0.09153611], -0.2925067,
&[-0.5156083, -0.36573746, -0.47613340, 0.41342817, -0.2659765, 0.1654796, -0.32346758] 0.1414550,
]); 0.59769207,
],
&[
-0.1943605,
0.63132116,
-0.54059857,
-0.37089970,
-0.1363031,
0.2892641,
0.17774114,
],
&[
0.3031265,
-0.06182488,
0.18579097,
-0.38606409,
-0.5364911,
0.2983466,
-0.58642548,
],
&[
0.1844063, 0.24425278, 0.25923756, 0.59043765, -0.4435443, 0.3959057, 0.37019098,
],
&[
-0.7164205,
0.30694911,
0.58264743,
-0.07458095,
-0.1142140,
-0.1311972,
-0.13124764,
],
&[
-0.1103067,
-0.10633600,
0.18257905,
-0.03638501,
0.5722925,
0.7784398,
-0.09153611,
],
&[
-0.5156083,
-0.36573746,
-0.47613340,
0.41342817,
-0.2659765,
0.1654796,
-0.32346758,
],
]);
let svd = A.svd(); let svd = A.svd();
@@ -502,21 +656,14 @@ mod tests {
for i in 0..s.len() { for i in 0..s.len() {
assert!((s[i] - svd.s[i]).abs() < 1e-4); assert!((s[i] - svd.s[i]).abs() < 1e-4);
} }
} }
#[test] #[test]
fn solve() { fn solve() {
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let b = DenseMatrix::from_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
let b = DenseMatrix::from_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]); let expected_w = DenseMatrix::from_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
let expected_w = DenseMatrix::from_array(&[ let w = a.svd_solve_mut(b);
&[-0.20, -1.28], assert!(w.approximate_eq(&expected_w, 1e-2));
&[0.87, 2.22],
&[0.47, 0.66]
]);
let w = a.svd_solve_mut(b);
assert!(w.approximate_eq(&expected_w, 1e-2));
} }
}
}
+98 -82
View File
@@ -1,34 +1,32 @@
use std::fmt::Debug; use std::fmt::Debug;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::FloatExt;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub enum LinearRegressionSolver { pub enum LinearRegressionSolver {
QR, QR,
SVD SVD,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct LinearRegression<T: FloatExt, M: Matrix<T>> { pub struct LinearRegression<T: FloatExt, M: Matrix<T>> {
coefficients: M, coefficients: M,
intercept: T, intercept: T,
solver: LinearRegressionSolver solver: LinearRegressionSolver,
} }
impl<T: FloatExt, M: Matrix<T>> PartialEq for LinearRegression<T, M> { impl<T: FloatExt, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.coefficients == other.coefficients && self.coefficients == other.coefficients
(self.intercept - other.intercept).abs() <= T::epsilon() && (self.intercept - other.intercept).abs() <= T::epsilon()
} }
} }
impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> { impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
pub fn fit(x: &M, y: &M::RowVector, solver: LinearRegressionSolver) -> LinearRegression<T, M> {
pub fn fit(x: &M, y: &M::RowVector, solver: LinearRegressionSolver) -> LinearRegression<T, M>{
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let b = y_m.transpose(); let b = y_m.transpose();
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
@@ -37,20 +35,20 @@ impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
if x_nrows != y_nrows { if x_nrows != y_nrows {
panic!("Number of rows of X doesn't match number of rows of Y"); panic!("Number of rows of X doesn't match number of rows of Y");
} }
let a = x.v_stack(&M::ones(x_nrows, 1)); let a = x.v_stack(&M::ones(x_nrows, 1));
let w = match solver { let w = match solver {
LinearRegressionSolver::QR => a.qr_solve_mut(b), LinearRegressionSolver::QR => a.qr_solve_mut(b),
LinearRegressionSolver::SVD => a.svd_solve_mut(b) LinearRegressionSolver::SVD => a.svd_solve_mut(b),
}; };
let wights = w.slice(0..num_attributes, 0..1); let wights = w.slice(0..num_attributes, 0..1);
LinearRegression { LinearRegression {
intercept: w.get(num_attributes, 0), intercept: w.get(num_attributes, 0),
coefficients: wights, coefficients: wights,
solver: solver solver: solver,
} }
} }
@@ -60,81 +58,54 @@ impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
y_hat.add_mut(&M::fill(nrows, 1, self.intercept)); y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
y_hat.transpose().to_row_vector() y_hat.transpose().to_row_vector()
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*;
use nalgebra::{DMatrix, RowDVector}; use nalgebra::{DMatrix, RowDVector};
use crate::linalg::naive::dense_matrix::*;
#[test] #[test]
fn ols_fit_predict() { fn ols_fit_predict() {
let x = DMatrix::from_row_slice(
16,
6,
&[
234.289, 235.6, 159.0, 107.608, 1947., 60.323, 259.426, 232.5, 145.6, 108.632,
1948., 61.122, 258.054, 368.2, 161.6, 109.773, 1949., 60.171, 284.599, 335.1,
165.0, 110.929, 1950., 61.187, 328.975, 209.9, 309.9, 112.075, 1951., 63.221,
346.999, 193.2, 359.4, 113.270, 1952., 63.639, 365.385, 187.0, 354.7, 115.094,
1953., 64.989, 363.112, 357.8, 335.0, 116.219, 1954., 63.761, 397.469, 290.4,
304.8, 117.388, 1955., 66.019, 419.180, 282.2, 285.7, 118.734, 1956., 67.857,
442.769, 293.6, 279.8, 120.445, 1957., 68.169, 444.546, 468.1, 263.7, 121.950,
1958., 66.513, 482.704, 381.3, 255.2, 123.366, 1959., 68.655, 502.601, 393.1,
251.4, 125.368, 1960., 69.564, 518.173, 480.6, 257.2, 127.852, 1961., 69.331,
554.894, 400.7, 282.7, 130.081, 1962., 70.551,
],
);
let x = DMatrix::from_row_slice(16, 6, &[ let y: RowDVector<f64> = RowDVector::from_vec(vec![
234.289, 235.6, 159.0, 107.608, 1947., 60.323, 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
259.426, 232.5, 145.6, 108.632, 1948., 61.122, 114.2, 115.7, 116.9,
258.054, 368.2, 161.6, 109.773, 1949., 60.171, ]);
284.599, 335.1, 165.0, 110.929, 1950., 61.187,
328.975, 209.9, 309.9, 112.075, 1951., 63.221,
346.999, 193.2, 359.4, 113.270, 1952., 63.639,
365.385, 187.0, 354.7, 115.094, 1953., 64.989,
363.112, 357.8, 335.0, 116.219, 1954., 63.761,
397.469, 290.4, 304.8, 117.388, 1955., 66.019,
419.180, 282.2, 285.7, 118.734, 1956., 67.857,
442.769, 293.6, 279.8, 120.445, 1957., 68.169,
444.546, 468.1, 263.7, 121.950, 1958., 66.513,
482.704, 381.3, 255.2, 123.366, 1959., 68.655,
502.601, 393.1, 251.4, 125.368, 1960., 69.564,
518.173, 480.6, 257.2, 127.852, 1961., 69.331,
554.894, 400.7, 282.7, 130.081, 1962., 70.551]);
let y: RowDVector<f64> = RowDVector::from_vec(vec!(83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9));
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x); let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x); let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x);
assert!(y.iter().zip(y_hat_qr.iter()).all(|(&a, &b)| (a - b).abs() <= 5.0));
assert!(y.iter().zip(y_hat_svd.iter()).all(|(&a, &b)| (a - b).abs() <= 5.0));
assert!(y
.iter()
.zip(y_hat_qr.iter())
.all(|(&a, &b)| (a - b).abs() <= 5.0));
assert!(y
.iter()
.zip(y_hat_svd.iter())
.all(|(&a, &b)| (a - b).abs() <= 5.0));
} }
#[test] #[test]
fn ols_fit_predict_nalgebra() { fn ols_fit_predict_nalgebra() {
let x = DenseMatrix::from_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
let y: Vec<f64> = vec!(83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9);
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x);
assert!(y.iter().zip(y_hat_qr.iter()).all(|(&a, &b)| (a - b).abs() <= 5.0));
assert!(y.iter().zip(y_hat_svd.iter()).all(|(&a, &b)| (a - b).abs() <= 5.0));
}
#[test]
fn serde(){
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323], &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122], &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
@@ -151,14 +122,59 @@ mod tests {
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655], &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564], &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331], &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
let y = vec!(83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9);
let y: Vec<f64> = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x);
assert!(y
.iter()
.zip(y_hat_qr.iter())
.all(|(&a, &b)| (a - b).abs() <= 5.0));
assert!(y
.iter()
.zip(y_hat_svd.iter())
.all(|(&a, &b)| (a - b).abs() <= 5.0));
}
#[test]
fn serde() {
let x = DenseMatrix::from_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
let lr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR); let lr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR);
let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> = serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
assert_eq!(lr, deserialized_lr); let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
assert_eq!(lr, deserialized_lr);
} }
} }
+234 -225
View File
@@ -1,21 +1,21 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::optimization::FunctionOrder; use crate::math::num::FloatExt;
use crate::optimization::first_order::lbfgs::LBFGS;
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult}; use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
use crate::optimization::line_search::Backtracking; use crate::optimization::line_search::Backtracking;
use crate::optimization::first_order::lbfgs::LBFGS; use crate::optimization::FunctionOrder;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct LogisticRegression<T: FloatExt, M: Matrix<T>> { pub struct LogisticRegression<T: FloatExt, M: Matrix<T>> {
weights: M, weights: M,
classes: Vec<T>, classes: Vec<T>,
num_attributes: usize, num_attributes: usize,
num_classes: usize num_classes: usize,
} }
trait ObjectiveFunction<T: FloatExt, M: Matrix<T>> { trait ObjectiveFunction<T: FloatExt, M: Matrix<T>> {
@@ -24,11 +24,11 @@ trait ObjectiveFunction<T: FloatExt, M: Matrix<T>> {
fn partial_dot(w: &M, x: &M, v_col: usize, m_row: usize) -> T { fn partial_dot(w: &M, x: &M, v_col: usize, m_row: usize) -> T {
let mut sum = T::zero(); let mut sum = T::zero();
let p = x.shape().1; let p = x.shape().1;
for i in 0..p { for i in 0..p {
sum = sum + x.get(m_row, i) * w.get(0, i + v_col); sum = sum + x.get(m_row, i) * w.get(0, i + v_col);
} }
sum + w.get(0, p + v_col) sum + w.get(0, p + v_col)
} }
} }
@@ -36,121 +36,119 @@ trait ObjectiveFunction<T: FloatExt, M: Matrix<T>> {
struct BinaryObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> { struct BinaryObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
x: &'a M, x: &'a M,
y: Vec<usize>, y: Vec<usize>,
phantom: PhantomData<&'a T> phantom: PhantomData<&'a T>,
} }
impl<T: FloatExt, M: Matrix<T>> PartialEq for LogisticRegression<T, M> { impl<T: FloatExt, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.num_classes != other.num_classes
if self.num_classes != other.num_classes || || self.num_attributes != other.num_attributes
self.num_attributes != other.num_attributes || || self.classes.len() != other.classes.len()
self.classes.len() != other.classes.len() { {
return false return false;
} else { } else {
for i in 0..self.classes.len() { for i in 0..self.classes.len() {
if (self.classes[i] - other.classes[i]).abs() > T::epsilon(){ if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
return false return false;
} }
} }
return self.weights == other.weights return self.weights == other.weights;
} }
} }
} }
impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveFunction<'a, T, M> { impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveFunction<'a, T, M> {
fn f(&self, w_bias: &M) -> T {
let mut f = T::zero();
let (n, _) = self.x.shape();
fn f(&self, w_bias: &M) -> T {
let mut f = T::zero();
let (n, _) = self.x.shape();
for i in 0..n { for i in 0..n {
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i); let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
f = f + (wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx); f = f + (wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx);
} }
f f
} }
fn df(&self, g: &mut M, w_bias: &M) { fn df(&self, g: &mut M, w_bias: &M) {
g.copy_from(&M::zeros(1, g.shape().1)); g.copy_from(&M::zeros(1, g.shape().1));
let (n, p) = self.x.shape(); let (n, p) = self.x.shape();
for i in 0..n { for i in 0..n {
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
let dyi = (T::from(self.y[i]).unwrap()) - wx.sigmoid(); let dyi = (T::from(self.y[i]).unwrap()) - wx.sigmoid();
for j in 0..p { for j in 0..p {
g.set(0, j, g.get(0, j) - dyi * self.x.get(i, j)); g.set(0, j, g.get(0, j) - dyi * self.x.get(i, j));
} }
g.set(0, p, g.get(0, p) - dyi); g.set(0, p, g.get(0, p) - dyi);
} }
}
}
} }
struct MultiClassObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> { struct MultiClassObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
x: &'a M, x: &'a M,
y: Vec<usize>, y: Vec<usize>,
k: usize, k: usize,
phantom: PhantomData<&'a T> phantom: PhantomData<&'a T>,
} }
impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for MultiClassObjectiveFunction<'a, T, M> { impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M>
for MultiClassObjectiveFunction<'a, T, M>
fn f(&self, w_bias: &M) -> T { {
fn f(&self, w_bias: &M) -> T {
let mut f = T::zero(); let mut f = T::zero();
let mut prob = M::zeros(1, self.k); let mut prob = M::zeros(1, self.k);
let (n, p) = self.x.shape(); let (n, p) = self.x.shape();
for i in 0..n { for i in 0..n {
for j in 0..self.k { for j in 0..self.k {
prob.set(0, j, MultiClassObjectiveFunction::partial_dot(w_bias, self.x, j * (p + 1), i)); prob.set(
0,
j,
MultiClassObjectiveFunction::partial_dot(w_bias, self.x, j * (p + 1), i),
);
} }
prob.softmax_mut(); prob.softmax_mut();
f = f - prob.get(0, self.y[i]).ln(); f = f - prob.get(0, self.y[i]).ln();
} }
f f
} }
fn df(&self, g: &mut M, w: &M) { fn df(&self, g: &mut M, w: &M) {
g.copy_from(&M::zeros(1, g.shape().1)); g.copy_from(&M::zeros(1, g.shape().1));
let mut prob = M::zeros(1, self.k);
let (n, p) = self.x.shape();
for i in 0..n {
for j in 0..self.k {
prob.set(0, j, MultiClassObjectiveFunction::partial_dot(w, self.x, j * (p + 1), i));
}
prob.softmax_mut(); let mut prob = M::zeros(1, self.k);
let (n, p) = self.x.shape();
for i in 0..n {
for j in 0..self.k {
prob.set(
0,
j,
MultiClassObjectiveFunction::partial_dot(w, self.x, j * (p + 1), i),
);
}
prob.softmax_mut();
for j in 0..self.k { for j in 0..self.k {
let yi =(if self.y[i] == j { T::one() } else { T::zero() }) - prob.get(0, j); let yi = (if self.y[i] == j { T::one() } else { T::zero() }) - prob.get(0, j);
for l in 0..p { for l in 0..p {
let pos = j * (p + 1); let pos = j * (p + 1);
g.set(0, pos + l, g.get(0, pos + l) - yi * self.x.get(i, l)); g.set(0, pos + l, g.get(0, pos + l) - yi * self.x.get(i, l));
} }
g.set(0, j * (p + 1) + p, g.get(0, j * (p + 1) + p) - yi); g.set(0, j * (p + 1) + p, g.get(0, j * (p + 1) + p) - yi);
} }
} }
} }
} }
impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> { impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M> {
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M>{
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
let (_, y_nrows) = y_m.shape(); let (_, y_nrows) = y_m.shape();
@@ -158,271 +156,277 @@ impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
if x_nrows != y_nrows { if x_nrows != y_nrows {
panic!("Number of rows of X doesn't match number of rows of Y"); panic!("Number of rows of X doesn't match number of rows of Y");
} }
let classes = y_m.unique();
let k = classes.len(); let classes = y_m.unique();
let k = classes.len();
let mut yi: Vec<usize> = vec![0; y_nrows]; let mut yi: Vec<usize> = vec![0; y_nrows];
for i in 0..y_nrows { for i in 0..y_nrows {
let yc = y_m.get(0, i); let yc = y_m.get(0, i);
yi[i] = classes.iter().position(|c| yc == *c).unwrap(); yi[i] = classes.iter().position(|c| yc == *c).unwrap();
} }
if k < 2 { if k < 2 {
panic!("Incorrect number of classes: {}", k); panic!("Incorrect number of classes: {}", k);
} else if k == 2 { } else if k == 2 {
let x0 = M::zeros(1, num_attributes + 1); let x0 = M::zeros(1, num_attributes + 1);
let objective = BinaryObjectiveFunction{ let objective = BinaryObjectiveFunction {
x: x, x: x,
y: yi, y: yi,
phantom: PhantomData phantom: PhantomData,
}; };
let result = LogisticRegression::minimize(x0, objective); let result = LogisticRegression::minimize(x0, objective);
LogisticRegression { LogisticRegression {
weights: result.x, weights: result.x,
classes: classes, classes: classes,
num_attributes: num_attributes, num_attributes: num_attributes,
num_classes: k, num_classes: k,
} }
} else { } else {
let x0 = M::zeros(1, (num_attributes + 1) * k); let x0 = M::zeros(1, (num_attributes + 1) * k);
let objective = MultiClassObjectiveFunction{ let objective = MultiClassObjectiveFunction {
x: x, x: x,
y: yi, y: yi,
k: k, k: k,
phantom: PhantomData phantom: PhantomData,
}; };
let result = LogisticRegression::minimize(x0, objective); let result = LogisticRegression::minimize(x0, objective);
let weights = result.x.reshape(k, num_attributes + 1); let weights = result.x.reshape(k, num_attributes + 1);
LogisticRegression { LogisticRegression {
weights: weights, weights: weights,
classes: classes, classes: classes,
num_attributes: num_attributes, num_attributes: num_attributes,
num_classes: k num_classes: k,
} }
} }
} }
pub fn predict(&self, x: &M) -> M::RowVector { pub fn predict(&self, x: &M) -> M::RowVector {
let n = x.shape().0; let n = x.shape().0;
let mut result = M::zeros(1, n); let mut result = M::zeros(1, n);
if self.num_classes == 2 { if self.num_classes == 2 {
let (nrows, _) = x.shape(); let (nrows, _) = x.shape();
let x_and_bias = x.v_stack(&M::ones(nrows, 1)); let x_and_bias = x.v_stack(&M::ones(nrows, 1));
let y_hat: Vec<T> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector(); let y_hat: Vec<T> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
for i in 0..n { for i in 0..n {
result.set(0, i, self.classes[if y_hat[i].sigmoid() > T::half() { 1 } else { 0 }]); result.set(
} 0,
i,
self.classes[if y_hat[i].sigmoid() > T::half() { 1 } else { 0 }],
);
}
} else { } else {
let (nrows, _) = x.shape(); let (nrows, _) = x.shape();
let x_and_bias = x.v_stack(&M::ones(nrows, 1)); let x_and_bias = x.v_stack(&M::ones(nrows, 1));
let y_hat = x_and_bias.dot(&self.weights.transpose()); let y_hat = x_and_bias.dot(&self.weights.transpose());
let class_idxs = y_hat.argmax(); let class_idxs = y_hat.argmax();
for i in 0..n { for i in 0..n {
result.set(0, i, self.classes[class_idxs[i]]); result.set(0, i, self.classes[class_idxs[i]]);
} }
} }
result.to_row_vector() result.to_row_vector()
} }
pub fn coefficients(&self) -> M { pub fn coefficients(&self) -> M {
self.weights.slice(0..self.num_classes, 0..self.num_attributes) self.weights
.slice(0..self.num_classes, 0..self.num_attributes)
} }
pub fn intercept(&self) -> M { pub fn intercept(&self) -> M {
self.weights.slice(0..self.num_classes, self.num_attributes..self.num_attributes+1) self.weights.slice(
} 0..self.num_classes,
self.num_attributes..self.num_attributes + 1,
)
}
fn minimize(x0: M, objective: impl ObjectiveFunction<T, M>) -> OptimizerResult<T, M> { fn minimize(x0: M, objective: impl ObjectiveFunction<T, M>) -> OptimizerResult<T, M> {
let f = |w: &M| -> T { let f = |w: &M| -> T { objective.f(w) };
objective.f(w)
};
let df = |g: &mut M, w: &M| { let df = |g: &mut M, w: &M| objective.df(g, w);
objective.df(g, w)
};
let mut ls: Backtracking<T> = Default::default(); let mut ls: Backtracking<T> = Default::default();
ls.order = FunctionOrder::THIRD; ls.order = FunctionOrder::THIRD;
let optimizer: LBFGS<T> = Default::default(); let optimizer: LBFGS<T> = Default::default();
optimizer.optimize(&f, &df, &x0, &ls) optimizer.optimize(&f, &df, &x0, &ls)
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
use ndarray::{arr1, arr2, Array1};
use crate::metrics::*; use crate::metrics::*;
use ndarray::{arr1, arr2, Array1};
#[test] #[test]
fn multiclass_objective_f() { fn multiclass_objective_f() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[1., -5.], &[1., -5.],
&[ 2., 5.], &[2., 5.],
&[ 3., -2.], &[3., -2.],
&[ 1., 2.], &[1., 2.],
&[ 2., 0.], &[2., 0.],
&[ 6., -5.], &[6., -5.],
&[ 7., 5.], &[7., 5.],
&[ 6., -2.], &[6., -2.],
&[ 7., 2.], &[7., 2.],
&[ 6., 0.], &[6., 0.],
&[ 8., -5.], &[8., -5.],
&[ 9., 5.], &[9., 5.],
&[10., -2.], &[10., -2.],
&[ 8., 2.], &[8., 2.],
&[ 9., 0.]]); &[9., 0.],
]);
let y = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1]; let y = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];
let objective = MultiClassObjectiveFunction{ let objective = MultiClassObjectiveFunction {
x: &x, x: &x,
y: y, y: y,
k: 3, k: 3,
phantom: PhantomData phantom: PhantomData,
}; };
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 9); let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 9);
objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.])); objective.df(
objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.])); &mut g,
&DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]),
assert!((g.get(0, 0) + 33.000068218163484).abs() < std::f64::EPSILON); );
objective.df(
&mut g,
&DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]),
);
let f = objective.f(&DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.])); assert!((g.get(0, 0) + 33.000068218163484).abs() < std::f64::EPSILON);
assert!((f - 408.0052230582765).abs() < std::f64::EPSILON); let f = objective.f(&DenseMatrix::vector_from_array(&[
1., 2., 3., 4., 5., 6., 7., 8., 9.,
]));
assert!((f - 408.0052230582765).abs() < std::f64::EPSILON);
} }
#[test] #[test]
fn binary_objective_f() { fn binary_objective_f() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[1., -5.], &[1., -5.],
&[ 2., 5.], &[2., 5.],
&[ 3., -2.], &[3., -2.],
&[ 1., 2.], &[1., 2.],
&[ 2., 0.], &[2., 0.],
&[ 6., -5.], &[6., -5.],
&[ 7., 5.], &[7., 5.],
&[ 6., -2.], &[6., -2.],
&[ 7., 2.], &[7., 2.],
&[ 6., 0.], &[6., 0.],
&[ 8., -5.], &[8., -5.],
&[ 9., 5.], &[9., 5.],
&[10., -2.], &[10., -2.],
&[ 8., 2.], &[8., 2.],
&[ 9., 0.]]); &[9., 0.],
]);
let y = vec![0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1]; let y = vec![0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1];
let objective = BinaryObjectiveFunction{ let objective = BinaryObjectiveFunction {
x: &x, x: &x,
y: y, y: y,
phantom: PhantomData phantom: PhantomData,
}; };
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 3); let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 3);
objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3.])); objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3.]));
objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3.])); objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3.]));
assert!((g.get(0, 0) - 26.051064349381285).abs() < std::f64::EPSILON);
assert!((g.get(0, 1) - 10.239000702928523).abs() < std::f64::EPSILON);
assert!((g.get(0, 2) - 3.869294270156324).abs() < std::f64::EPSILON);
let f = objective.f(&DenseMatrix::vector_from_array(&[1., 2., 3.])); assert!((g.get(0, 0) - 26.051064349381285).abs() < std::f64::EPSILON);
assert!((g.get(0, 1) - 10.239000702928523).abs() < std::f64::EPSILON);
assert!((g.get(0, 2) - 3.869294270156324).abs() < std::f64::EPSILON);
assert!((f - 59.76994756647412).abs() < std::f64::EPSILON); let f = objective.f(&DenseMatrix::vector_from_array(&[1., 2., 3.]));
assert!((f - 59.76994756647412).abs() < std::f64::EPSILON);
} }
#[test] #[test]
fn lr_fit_predict() { fn lr_fit_predict() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[1., -5.], &[1., -5.],
&[ 2., 5.], &[2., 5.],
&[ 3., -2.], &[3., -2.],
&[ 1., 2.], &[1., 2.],
&[ 2., 0.], &[2., 0.],
&[ 6., -5.], &[6., -5.],
&[ 7., 5.], &[7., 5.],
&[ 6., -2.], &[6., -2.],
&[ 7., 2.], &[7., 2.],
&[ 6., 0.], &[6., 0.],
&[ 8., -5.], &[8., -5.],
&[ 9., 5.], &[9., 5.],
&[10., -2.], &[10., -2.],
&[ 8., 2.], &[8., 2.],
&[ 9., 0.]]); &[9., 0.],
]);
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.]; let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
let lr = LogisticRegression::fit(&x, &y); let lr = LogisticRegression::fit(&x, &y);
assert_eq!(lr.coefficients().shape(), (3, 2)); assert_eq!(lr.coefficients().shape(), (3, 2));
assert_eq!(lr.intercept().shape(), (3, 1)); assert_eq!(lr.intercept().shape(), (3, 1));
assert!((lr.coefficients().get(0, 0) - 0.0435).abs() < 1e-4); assert!((lr.coefficients().get(0, 0) - 0.0435).abs() < 1e-4);
assert!((lr.intercept().get(0, 0) - 0.1250).abs() < 1e-4); assert!((lr.intercept().get(0, 0) - 0.1250).abs() < 1e-4);
let y_hat = lr.predict(&x); let y_hat = lr.predict(&x);
assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]); assert_eq!(
y_hat,
vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
} );
#[test]
fn serde(){
let x = DenseMatrix::from_array(&[
&[1., -5.],
&[ 2., 5.],
&[ 3., -2.],
&[ 1., 2.],
&[ 2., 0.],
&[ 6., -5.],
&[ 7., 5.],
&[ 6., -2.],
&[ 7., 2.],
&[ 6., 0.],
&[ 8., -5.],
&[ 9., 5.],
&[10., -2.],
&[ 8., 2.],
&[ 9., 0.]]);
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
let lr = LogisticRegression::fit(&x, &y);
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> = serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
assert_eq!(lr, deserialized_lr);
} }
#[test] #[test]
fn lr_fit_predict_iris() { fn serde() {
let x = DenseMatrix::from_array(&[
&[1., -5.],
&[2., 5.],
&[3., -2.],
&[1., 2.],
&[2., 0.],
&[6., -5.],
&[7., 5.],
&[6., -2.],
&[7., 2.],
&[6., 0.],
&[8., -5.],
&[9., 5.],
&[10., -2.],
&[8., 2.],
&[9., 0.],
]);
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
let lr = LogisticRegression::fit(&x, &y);
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
assert_eq!(lr, deserialized_lr);
}
#[test]
fn lr_fit_predict_iris() {
let x = arr2(&[ let x = arr2(&[
[5.1, 3.5, 1.4, 0.2], [5.1, 3.5, 1.4, 0.2],
[4.9, 3.0, 1.4, 0.2], [4.9, 3.0, 1.4, 0.2],
@@ -443,17 +447,22 @@ mod tests {
[6.3, 3.3, 4.7, 1.6], [6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1.0], [4.9, 2.4, 3.3, 1.0],
[6.6, 2.9, 4.6, 1.3], [6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4]]); [5.2, 2.7, 3.9, 1.4],
let y: Array1<f64> = arr1(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]); ]);
let y: Array1<f64> = arr1(&[
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
]);
let lr = LogisticRegression::fit(&x, &y); let lr = LogisticRegression::fit(&x, &y);
let y_hat = lr.predict(&x); let y_hat = lr.predict(&x);
let error: f64 = y.into_iter().zip(y_hat.into_iter()).map(|(&a, &b)| (a - b).abs()).sum(); let error: f64 = y
.into_iter()
.zip(y_hat.into_iter())
.map(|(&a, &b)| (a - b).abs())
.sum();
assert!(error <= 1.0); assert!(error <= 1.0);
} }
}
}
+1 -1
View File
@@ -1,2 +1,2 @@
pub mod linear_regression; pub mod linear_regression;
pub mod logistic_regression; pub mod logistic_regression;
+13 -19
View File
@@ -1,50 +1,44 @@
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt; use crate::math::num::FloatExt;
use super::Distance; use super::Distance;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct Euclidian { pub struct Euclidian {}
}
impl Euclidian { impl Euclidian {
pub fn squared_distance<T: FloatExt>(x: &Vec<T>,y: &Vec<T>) -> T { pub fn squared_distance<T: FloatExt>(x: &Vec<T>, y: &Vec<T>) -> T {
if x.len() != y.len() { if x.len() != y.len() {
panic!("Input vector sizes are different."); panic!("Input vector sizes are different.");
} }
let mut sum = T::zero(); let mut sum = T::zero();
for i in 0..x.len() { for i in 0..x.len() {
sum = sum + (x[i] - y[i]).powf(T::two()); sum = sum + (x[i] - y[i]).powf(T::two());
} }
sum sum
} }
} }
impl<T: FloatExt> Distance<Vec<T>, T> for Euclidian { impl<T: FloatExt> Distance<Vec<T>, T> for Euclidian {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
Euclidian::squared_distance(x, y).sqrt() Euclidian::squared_distance(x, y).sqrt()
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn squared_distance() { fn squared_distance() {
let a = vec![1., 2., 3.]; let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.]; let b = vec![4., 5., 6.];
let l2: f64 = Euclidian{}.distance(&a, &b);
assert!((l2 - 5.19615242).abs() < 1e-8); let l2: f64 = Euclidian {}.distance(&a, &b);
}
} assert!((l2 - 5.19615242).abs() < 1e-8);
}
}
+15 -20
View File
@@ -1,45 +1,40 @@
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt; use crate::math::num::FloatExt;
use super::Distance; use super::Distance;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct Hamming { pub struct Hamming {}
}
impl<T: PartialEq, F: FloatExt> Distance<Vec<T>, F> for Hamming { impl<T: PartialEq, F: FloatExt> Distance<Vec<T>, F> for Hamming {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> F {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> F {
if x.len() != y.len() { if x.len() != y.len() {
panic!("Input vector sizes are different"); panic!("Input vector sizes are different");
} }
let mut dist = 0; let mut dist = 0;
for i in 0..x.len() { for i in 0..x.len() {
if x[i] != y[i]{ if x[i] != y[i] {
dist += 1; dist += 1;
} }
} }
F::from_i64(dist).unwrap() / F::from_usize(x.len()).unwrap() F::from_i64(dist).unwrap() / F::from_usize(x.len()).unwrap()
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn minkowski_distance() { fn minkowski_distance() {
let a = vec![1, 0, 0, 1, 0, 0, 1]; let a = vec![1, 0, 0, 1, 0, 0, 1];
let b = vec![1, 1, 0, 0, 1, 0, 1]; let b = vec![1, 1, 0, 0, 1, 0, 1];
let h: f64 = Hamming{}.distance(&a, &b); let h: f64 = Hamming {}.distance(&a, &b);
assert!((h - 0.42857142).abs() < 1e-8); assert!((h - 0.42857142).abs() < 1e-8);
} }
}
}
+29 -22
View File
@@ -2,7 +2,7 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt; use crate::math::num::FloatExt;
@@ -13,7 +13,7 @@ use crate::linalg::Matrix;
pub struct Mahalanobis<T: FloatExt, M: Matrix<T>> { pub struct Mahalanobis<T: FloatExt, M: Matrix<T>> {
pub sigma: M, pub sigma: M,
pub sigmaInv: M, pub sigmaInv: M,
t: PhantomData<T> t: PhantomData<T>,
} }
impl<T: FloatExt, M: Matrix<T>> Mahalanobis<T, M> { impl<T: FloatExt, M: Matrix<T>> Mahalanobis<T, M> {
@@ -23,7 +23,7 @@ impl<T: FloatExt, M: Matrix<T>> Mahalanobis<T, M> {
Mahalanobis { Mahalanobis {
sigma: sigma, sigma: sigma,
sigmaInv: sigmaInv, sigmaInv: sigmaInv,
t: PhantomData t: PhantomData,
} }
} }
@@ -33,21 +33,30 @@ impl<T: FloatExt, M: Matrix<T>> Mahalanobis<T, M> {
Mahalanobis { Mahalanobis {
sigma: sigma, sigma: sigma,
sigmaInv: sigmaInv, sigmaInv: sigmaInv,
t: PhantomData t: PhantomData,
} }
} }
} }
impl<T: FloatExt, M: Matrix<T>> Distance<Vec<T>, T> for Mahalanobis<T, M> { impl<T: FloatExt, M: Matrix<T>> Distance<Vec<T>, T> for Mahalanobis<T, M> {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T { fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
let (nrows, ncols) = self.sigma.shape(); let (nrows, ncols) = self.sigma.shape();
if x.len() != nrows { if x.len() != nrows {
panic!("Array x[{}] has different dimension with Sigma[{}][{}].", x.len(), nrows, ncols); panic!(
"Array x[{}] has different dimension with Sigma[{}][{}].",
x.len(),
nrows,
ncols
);
} }
if y.len() != nrows { if y.len() != nrows {
panic!("Array y[{}] has different dimension with Sigma[{}][{}].", y.len(), nrows, ncols); panic!(
"Array y[{}] has different dimension with Sigma[{}][{}].",
y.len(),
nrows,
ncols
);
} }
println!("{}", self.sigmaInv); println!("{}", self.sigmaInv);
@@ -56,7 +65,7 @@ impl<T: FloatExt, M: Matrix<T>> Distance<Vec<T>, T> for Mahalanobis<T, M> {
let mut z = vec![T::zero(); n]; let mut z = vec![T::zero(); n];
for i in 0..n { for i in 0..n {
z[i] = x[i] - y[i]; z[i] = x[i] - y[i];
} }
// np.dot(np.dot((a-b),VI),(a-b).T) // np.dot(np.dot((a-b),VI),(a-b).T)
let mut s = T::zero(); let mut s = T::zero();
@@ -67,31 +76,29 @@ impl<T: FloatExt, M: Matrix<T>> Distance<Vec<T>, T> for Mahalanobis<T, M> {
} }
s.sqrt() s.sqrt()
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
#[test] #[test]
fn mahalanobis_distance() { fn mahalanobis_distance() {
let data = DenseMatrix::from_array(&[ let data = DenseMatrix::from_array(&[
&[ 64., 580., 29.], &[64., 580., 29.],
&[ 66., 570., 33.], &[66., 570., 33.],
&[ 68., 590., 37.], &[68., 590., 37.],
&[ 69., 660., 46.], &[69., 660., 46.],
&[ 73., 600., 55.]]); &[73., 600., 55.],
]);
let a = data.column_mean(); let a = data.column_mean();
let b = vec![66., 640., 44.]; let b = vec![66., 640., 44.];
let mahalanobis = Mahalanobis::new(&data); let mahalanobis = Mahalanobis::new(&data);
println!("{}", mahalanobis.distance(&a, &b)); println!("{}", mahalanobis.distance(&a, &b));
} }
}
}
+15 -20
View File
@@ -1,43 +1,38 @@
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt; use crate::math::num::FloatExt;
use super::Distance; use super::Distance;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct Manhattan { pub struct Manhattan {}
}
impl<T: FloatExt> Distance<Vec<T>, T> for Manhattan { impl<T: FloatExt> Distance<Vec<T>, T> for Manhattan {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
if x.len() != y.len() { if x.len() != y.len() {
panic!("Input vector sizes are different"); panic!("Input vector sizes are different");
} }
let mut dist = T::zero(); let mut dist = T::zero();
for i in 0..x.len() { for i in 0..x.len() {
dist = dist + (x[i] - y[i]).abs(); dist = dist + (x[i] - y[i]).abs();
} }
dist dist
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn manhattan_distance() { fn manhattan_distance() {
let a = vec![1., 2., 3.]; let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.]; let b = vec![4., 5., 6.];
let l1: f64 = Manhattan{}.distance(&a, &b);
assert!((l1 - 9.0).abs() < 1e-8); let l1: f64 = Manhattan {}.distance(&a, &b);
}
} assert!((l1 - 9.0).abs() < 1e-8);
}
}
+21 -27
View File
@@ -1,4 +1,4 @@
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt; use crate::math::num::FloatExt;
@@ -6,58 +6,52 @@ use super::Distance;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct Minkowski<T: FloatExt> { pub struct Minkowski<T: FloatExt> {
pub p: T pub p: T,
} }
impl<T: FloatExt> Distance<Vec<T>, T> for Minkowski<T> { impl<T: FloatExt> Distance<Vec<T>, T> for Minkowski<T> {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
if x.len() != y.len() { if x.len() != y.len() {
panic!("Input vector sizes are different"); panic!("Input vector sizes are different");
} }
if self.p < T::one() { if self.p < T::one() {
panic!("p must be at least 1"); panic!("p must be at least 1");
} }
let mut dist = T::zero(); let mut dist = T::zero();
for i in 0..x.len() { for i in 0..x.len() {
let d = (x[i] - y[i]).abs(); let d = (x[i] - y[i]).abs();
dist = dist + d.powf(self.p); dist = dist + d.powf(self.p);
} }
dist.powf(T::one()/self.p)
}
dist.powf(T::one() / self.p)
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn minkowski_distance() { fn minkowski_distance() {
let a = vec![1., 2., 3.]; let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.]; let b = vec![4., 5., 6.];
let l1: f64 = Minkowski{p: 1.0}.distance(&a, &b); let l1: f64 = Minkowski { p: 1.0 }.distance(&a, &b);
let l2: f64 = Minkowski{p: 2.0}.distance(&a, &b); let l2: f64 = Minkowski { p: 2.0 }.distance(&a, &b);
let l3: f64 = Minkowski{p: 3.0}.distance(&a, &b); let l3: f64 = Minkowski { p: 3.0 }.distance(&a, &b);
assert!((l1 - 9.0).abs() < 1e-8); assert!((l1 - 9.0).abs() < 1e-8);
assert!((l2 - 5.19615242).abs() < 1e-8); assert!((l2 - 5.19615242).abs() < 1e-8);
assert!((l3 - 4.32674871).abs() < 1e-8); assert!((l3 - 4.32674871).abs() < 1e-8);
} }
#[test] #[test]
#[should_panic(expected = "p must be at least 1")] #[should_panic(expected = "p must be at least 1")]
fn minkowski_distance_negative_p() { fn minkowski_distance_negative_p() {
let a = vec![1., 2., 3.]; let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.]; let b = vec![4., 5., 6.];
let _: f64 = Minkowski{p: 0.0}.distance(&a, &b); let _: f64 = Minkowski { p: 0.0 }.distance(&a, &b);
} }
}
}
+10 -11
View File
@@ -1,32 +1,31 @@
pub mod euclidian; pub mod euclidian;
pub mod minkowski;
pub mod manhattan;
pub mod hamming; pub mod hamming;
pub mod mahalanobis; pub mod mahalanobis;
pub mod manhattan;
pub mod minkowski;
use crate::math::num::FloatExt; use crate::math::num::FloatExt;
pub trait Distance<T, F: FloatExt>{ pub trait Distance<T, F: FloatExt> {
fn distance(&self, a: &T, b: &T) -> F; fn distance(&self, a: &T, b: &T) -> F;
} }
pub struct Distances{ pub struct Distances {}
}
impl Distances { impl Distances {
pub fn euclidian() -> euclidian::Euclidian{ pub fn euclidian() -> euclidian::Euclidian {
euclidian::Euclidian {} euclidian::Euclidian {}
} }
pub fn minkowski<T: FloatExt>(p: T) -> minkowski::Minkowski<T>{ pub fn minkowski<T: FloatExt>(p: T) -> minkowski::Minkowski<T> {
minkowski::Minkowski {p: p} minkowski::Minkowski { p: p }
} }
pub fn manhattan() -> manhattan::Manhattan{ pub fn manhattan() -> manhattan::Manhattan {
manhattan::Manhattan {} manhattan::Manhattan {}
} }
pub fn hamming() -> hamming::Hamming{ pub fn hamming() -> hamming::Hamming {
hamming::Hamming {} hamming::Hamming {}
} }
} }
+1 -1
View File
@@ -1,2 +1,2 @@
pub mod distance; pub mod distance;
pub mod num; pub mod num;
+11 -19
View File
@@ -1,9 +1,8 @@
use std::fmt::{Debug, Display};
use num_traits::{Float, FromPrimitive}; use num_traits::{Float, FromPrimitive};
use rand::prelude::*; use rand::prelude::*;
use std::fmt::{Debug, Display};
pub trait FloatExt: Float + FromPrimitive + Debug + Display + Copy { pub trait FloatExt: Float + FromPrimitive + Debug + Display + Copy {
fn copysign(self, sign: Self) -> Self; fn copysign(self, sign: Self) -> Self;
fn ln_1pe(self) -> Self; fn ln_1pe(self) -> Self;
@@ -15,33 +14,29 @@ pub trait FloatExt: Float + FromPrimitive + Debug + Display + Copy {
fn two() -> Self; fn two() -> Self;
fn half() -> Self; fn half() -> Self;
} }
impl FloatExt for f64 { impl FloatExt for f64 {
fn copysign(self, sign: Self) -> Self{ fn copysign(self, sign: Self) -> Self {
self.copysign(sign) self.copysign(sign)
} }
fn ln_1pe(self) -> f64{ fn ln_1pe(self) -> f64 {
if self > 15. { if self > 15. {
return self; return self;
} else { } else {
return self.exp().ln_1p(); return self.exp().ln_1p();
} }
} }
fn sigmoid(self) -> f64 { fn sigmoid(self) -> f64 {
if self < -40. { if self < -40. {
return 0.; return 0.;
} else if self > 40. { } else if self > 40. {
return 1.; return 1.;
} else { } else {
return 1. / (1. + f64::exp(-self)) return 1. / (1. + f64::exp(-self));
} }
} }
fn rand() -> f64 { fn rand() -> f64 {
@@ -59,29 +54,26 @@ impl FloatExt for f64 {
} }
impl FloatExt for f32 { impl FloatExt for f32 {
fn copysign(self, sign: Self) -> Self{ fn copysign(self, sign: Self) -> Self {
self.copysign(sign) self.copysign(sign)
} }
fn ln_1pe(self) -> f32{ fn ln_1pe(self) -> f32 {
if self > 15. { if self > 15. {
return self; return self;
} else { } else {
return self.exp().ln_1p(); return self.exp().ln_1p();
} }
} }
fn sigmoid(self) -> f32 { fn sigmoid(self) -> f32 {
if self < -40. { if self < -40. {
return 0.; return 0.;
} else if self > 40. { } else if self > 40. {
return 1.; return 1.;
} else { } else {
return 1. / (1. + f32::exp(-self)) return 1. / (1. + f32::exp(-self));
} }
} }
fn rand() -> f32 { fn rand() -> f32 {
@@ -99,13 +91,13 @@ impl FloatExt for f32 {
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn sigmoid() { fn sigmoid() {
assert_eq!(1.0.sigmoid(), 0.7310585786300049); assert_eq!(1.0.sigmoid(), 0.7310585786300049);
assert_eq!(41.0.sigmoid(), 1.); assert_eq!(41.0.sigmoid(), 1.);
assert_eq!((-41.0).sigmoid(), 0.); assert_eq!((-41.0).sigmoid(), 0.);
} }
} }
+14 -12
View File
@@ -1,15 +1,19 @@
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt;
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
use crate::math::num::FloatExt;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct Accuracy{} pub struct Accuracy {}
impl Accuracy { impl Accuracy {
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T { pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T {
if y_true.len() != y_prod.len() { if y_true.len() != y_prod.len() {
panic!("The vector sizes don't match: {} != {}", y_true.len(), y_prod.len()); panic!(
"The vector sizes don't match: {} != {}",
y_true.len(),
y_prod.len()
);
} }
let n = y_true.len(); let n = y_true.len();
@@ -23,23 +27,21 @@ impl Accuracy {
T::from_i64(positive).unwrap() / T::from_usize(n).unwrap() T::from_i64(positive).unwrap() / T::from_usize(n).unwrap()
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn accuracy() { fn accuracy() {
let y_pred: Vec<f64> = vec![0., 2., 1., 3.]; let y_pred: Vec<f64> = vec![0., 2., 1., 3.];
let y_true: Vec<f64> = vec![0., 1., 2., 3.]; let y_true: Vec<f64> = vec![0., 1., 2., 3.];
let score1: f64 = Accuracy{}.get_score(&y_pred, &y_true); let score1: f64 = Accuracy {}.get_score(&y_pred, &y_true);
let score2: f64 = Accuracy{}.get_score(&y_true, &y_true); let score2: f64 = Accuracy {}.get_score(&y_true, &y_true);
assert!((score1 - 0.5).abs() < 1e-8); assert!((score1 - 0.5).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8); assert!((score2 - 1.0).abs() < 1e-8);
} }
}
}
+11 -11
View File
@@ -1,34 +1,34 @@
pub mod accuracy; pub mod accuracy;
pub mod recall;
pub mod precision; pub mod precision;
pub mod recall;
use crate::math::num::FloatExt;
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
use crate::math::num::FloatExt;
pub struct ClassificationMetrics{} pub struct ClassificationMetrics {}
impl ClassificationMetrics { impl ClassificationMetrics {
pub fn accuracy() -> accuracy::Accuracy{ pub fn accuracy() -> accuracy::Accuracy {
accuracy::Accuracy {} accuracy::Accuracy {}
} }
pub fn recall() -> recall::Recall{ pub fn recall() -> recall::Recall {
recall::Recall {} recall::Recall {}
} }
pub fn precision() -> precision::Precision{ pub fn precision() -> precision::Precision {
precision::Precision {} precision::Precision {}
} }
} }
pub fn accuracy<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_prod: &V) -> T{ pub fn accuracy<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_prod: &V) -> T {
ClassificationMetrics::accuracy().get_score(y_true, y_prod) ClassificationMetrics::accuracy().get_score(y_true, y_prod)
} }
pub fn recall<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_prod: &V) -> T{ pub fn recall<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_prod: &V) -> T {
ClassificationMetrics::recall().get_score(y_true, y_prod) ClassificationMetrics::recall().get_score(y_true, y_prod)
} }
pub fn precision<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_prod: &V) -> T{ pub fn precision<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_prod: &V) -> T {
ClassificationMetrics::precision().get_score(y_true, y_prod) ClassificationMetrics::precision().get_score(y_true, y_prod)
} }
+25 -17
View File
@@ -1,15 +1,19 @@
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt;
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
use crate::math::num::FloatExt;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct Precision{} pub struct Precision {}
impl Precision { impl Precision {
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T { pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T {
if y_true.len() != y_prod.len() { if y_true.len() != y_prod.len() {
panic!("The vector sizes don't match: {} != {}", y_true.len(), y_prod.len()); panic!(
"The vector sizes don't match: {} != {}",
y_true.len(),
y_prod.len()
);
} }
let mut tp = 0; let mut tp = 0;
@@ -17,11 +21,17 @@ impl Precision {
let n = y_true.len(); let n = y_true.len();
for i in 0..n { for i in 0..n {
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() { if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
panic!("Precision can only be applied to binary classification: {}", y_true.get(i)); panic!(
"Precision can only be applied to binary classification: {}",
y_true.get(i)
);
} }
if y_prod.get(i) != T::zero() && y_prod.get(i) != T::one() { if y_prod.get(i) != T::zero() && y_prod.get(i) != T::one() {
panic!("Precision can only be applied to binary classification: {}", y_prod.get(i)); panic!(
"Precision can only be applied to binary classification: {}",
y_prod.get(i)
);
} }
if y_prod.get(i) == T::one() { if y_prod.get(i) == T::one() {
@@ -31,27 +41,25 @@ impl Precision {
tp += 1; tp += 1;
} }
} }
} }
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap() T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn precision() { fn precision() {
let y_true: Vec<f64> = vec![0., 1., 1., 0.]; let y_true: Vec<f64> = vec![0., 1., 1., 0.];
let y_pred: Vec<f64> = vec![0., 0., 1., 1.]; let y_pred: Vec<f64> = vec![0., 0., 1., 1.];
let score1: f64 = Precision{}.get_score(&y_pred, &y_true);
let score2: f64 = Precision{}.get_score(&y_pred, &y_pred);
assert!((score1 - 0.5).abs() < 1e-8); let score1: f64 = Precision {}.get_score(&y_pred, &y_true);
let score2: f64 = Precision {}.get_score(&y_pred, &y_pred);
assert!((score1 - 0.5).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8); assert!((score2 - 1.0).abs() < 1e-8);
} }
}
}
+25 -17
View File
@@ -1,15 +1,19 @@
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt;
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
use crate::math::num::FloatExt;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct Recall{} pub struct Recall {}
impl Recall { impl Recall {
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T { pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T {
if y_true.len() != y_prod.len() { if y_true.len() != y_prod.len() {
panic!("The vector sizes don't match: {} != {}", y_true.len(), y_prod.len()); panic!(
"The vector sizes don't match: {} != {}",
y_true.len(),
y_prod.len()
);
} }
let mut tp = 0; let mut tp = 0;
@@ -17,11 +21,17 @@ impl Recall {
let n = y_true.len(); let n = y_true.len();
for i in 0..n { for i in 0..n {
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() { if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
panic!("Recall can only be applied to binary classification: {}", y_true.get(i)); panic!(
"Recall can only be applied to binary classification: {}",
y_true.get(i)
);
} }
if y_prod.get(i) != T::zero() && y_prod.get(i) != T::one() { if y_prod.get(i) != T::zero() && y_prod.get(i) != T::one() {
panic!("Recall can only be applied to binary classification: {}", y_prod.get(i)); panic!(
"Recall can only be applied to binary classification: {}",
y_prod.get(i)
);
} }
if y_true.get(i) == T::one() { if y_true.get(i) == T::one() {
@@ -31,27 +41,25 @@ impl Recall {
tp += 1; tp += 1;
} }
} }
} }
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap() T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn recall() { fn recall() {
let y_true: Vec<f64> = vec![0., 1., 1., 0.]; let y_true: Vec<f64> = vec![0., 1., 1., 0.];
let y_pred: Vec<f64> = vec![0., 0., 1., 1.]; let y_pred: Vec<f64> = vec![0., 0., 1., 1.];
let score1: f64 = Recall{}.get_score(&y_pred, &y_true);
let score2: f64 = Recall{}.get_score(&y_pred, &y_pred);
assert!((score1 - 0.5).abs() < 1e-8); let score1: f64 = Recall {}.get_score(&y_pred, &y_true);
let score2: f64 = Recall {}.get_score(&y_pred, &y_pred);
assert!((score1 - 0.5).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8); assert!((score2 - 1.0).abs() < 1e-8);
} }
}
}
+89 -71
View File
@@ -1,66 +1,70 @@
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt;
use crate::math::distance::Distance;
use crate::linalg::{Matrix, row_iter};
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
use crate::algorithm::neighbour::cover_tree::CoverTree; use crate::algorithm::neighbour::cover_tree::CoverTree;
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
use crate::linalg::{row_iter, Matrix};
use crate::math::distance::Distance;
use crate::math::num::FloatExt;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct KNNClassifier<T: FloatExt, D: Distance<Vec<T>, T>> { pub struct KNNClassifier<T: FloatExt, D: Distance<Vec<T>, T>> {
classes: Vec<T>, classes: Vec<T>,
y: Vec<usize>, y: Vec<usize>,
knn_algorithm: KNNAlgorithmV<T, D>, knn_algorithm: KNNAlgorithmV<T, D>,
k: usize k: usize,
} }
pub enum KNNAlgorithmName { pub enum KNNAlgorithmName {
LinearSearch, LinearSearch,
CoverTree CoverTree,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub enum KNNAlgorithmV<T: FloatExt, D: Distance<Vec<T>, T>> { pub enum KNNAlgorithmV<T: FloatExt, D: Distance<Vec<T>, T>> {
LinearSearch(LinearKNNSearch<Vec<T>, T, D>), LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
CoverTree(CoverTree<Vec<T>, T, D>) CoverTree(CoverTree<Vec<T>, T, D>),
} }
impl KNNAlgorithmName { impl KNNAlgorithmName {
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(&self, data: Vec<Vec<T>>, distance: D) -> KNNAlgorithmV<T, D> { &self,
data: Vec<Vec<T>>,
distance: D,
) -> KNNAlgorithmV<T, D> {
match *self { match *self {
KNNAlgorithmName::LinearSearch => KNNAlgorithmV::LinearSearch(LinearKNNSearch::new(data, distance)), KNNAlgorithmName::LinearSearch => {
KNNAlgorithmV::LinearSearch(LinearKNNSearch::new(data, distance))
}
KNNAlgorithmName::CoverTree => KNNAlgorithmV::CoverTree(CoverTree::new(data, distance)), KNNAlgorithmName::CoverTree => KNNAlgorithmV::CoverTree(CoverTree::new(data, distance)),
} }
} }
} }
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithmV<T, D> { impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithmV<T, D> {
fn find(&self, from: &Vec<T>, k: usize) -> Vec<usize>{ fn find(&self, from: &Vec<T>, k: usize) -> Vec<usize> {
match *self { match *self {
KNNAlgorithmV::LinearSearch(ref linear) => linear.find(from, k), KNNAlgorithmV::LinearSearch(ref linear) => linear.find(from, k),
KNNAlgorithmV::CoverTree(ref cover) => cover.find(from, k) KNNAlgorithmV::CoverTree(ref cover) => cover.find(from, k),
} }
} }
} }
impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.classes.len() != other.classes.len() || if self.classes.len() != other.classes.len()
self.k != other.k || || self.k != other.k
self.y.len() != other.y.len() { || self.y.len() != other.y.len()
return false {
return false;
} else { } else {
for i in 0..self.classes.len() { for i in 0..self.classes.len() {
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() { if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
return false return false;
} }
} }
for i in 0..self.y.len() { for i in 0..self.y.len() {
if self.y[i] != other.y[i] { if self.y[i] != other.y[i] {
return false return false;
} }
} }
true true
@@ -69,96 +73,110 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
} }
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> { impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
pub fn fit<M: Matrix<T>>(
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, k: usize, distance: D, algorithm: KNNAlgorithmName) -> KNNClassifier<T, D> { x: &M,
y: &M::RowVector,
k: usize,
distance: D,
algorithm: KNNAlgorithmName,
) -> KNNClassifier<T, D> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_n) = y_m.shape(); let (_, y_n) = y_m.shape();
let (x_n, _) = x.shape(); let (x_n, _) = x.shape();
let data = row_iter(x).collect(); let data = row_iter(x).collect();
let mut yi: Vec<usize> = vec![0; y_n]; let mut yi: Vec<usize> = vec![0; y_n];
let classes = y_m.unique(); let classes = y_m.unique();
for i in 0..y_n { for i in 0..y_n {
let yc = y_m.get(0, i); let yc = y_m.get(0, i);
yi[i] = classes.iter().position(|c| yc == *c).unwrap(); yi[i] = classes.iter().position(|c| yc == *c).unwrap();
} }
assert!(x_n == y_n, format!("Size of x should equal size of y; |x|=[{}], |y|=[{}]", x_n, y_n)); assert!(
x_n == y_n,
format!(
"Size of x should equal size of y; |x|=[{}], |y|=[{}]",
x_n, y_n
)
);
assert!(k > 1, format!("k should be > 1, k=[{}]", k)); assert!(k > 1, format!("k should be > 1, k=[{}]", k));
KNNClassifier{classes:classes, y: yi, k: k, knn_algorithm: algorithm.fit(data, distance)}
KNNClassifier {
classes: classes,
y: yi,
k: k,
knn_algorithm: algorithm.fit(data, distance),
}
} }
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let mut result = M::zeros(1, x.shape().0); let mut result = M::zeros(1, x.shape().0);
row_iter(x).enumerate().for_each(|(i, x)| result.set(0, i, self.classes[self.predict_for_row(x)])); row_iter(x)
.enumerate()
.for_each(|(i, x)| result.set(0, i, self.classes[self.predict_for_row(x)]));
result.to_row_vector() result.to_row_vector()
} }
fn predict_for_row(&self, x: Vec<T>) -> usize { fn predict_for_row(&self, x: Vec<T>) -> usize {
let idxs = self.knn_algorithm.find(&x, self.k);
let idxs = self.knn_algorithm.find(&x, self.k);
let mut c = vec![0; self.classes.len()]; let mut c = vec![0; self.classes.len()];
let mut max_c = 0; let mut max_c = 0;
let mut max_i = 0; let mut max_i = 0;
for i in idxs { for i in idxs {
c[self.y[i]] += 1; c[self.y[i]] += 1;
if c[self.y[i]] > max_c { if c[self.y[i]] > max_c {
max_c = c[self.y[i]]; max_c = c[self.y[i]];
max_i = self.y[i]; max_i = self.y[i];
} }
} }
max_i max_i
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::math::distance::Distances;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::math::distance::Distances;
#[test] #[test]
fn knn_fit_predict() { fn knn_fit_predict() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
&[1., 2.], let y = vec![2., 2., 2., 3., 3.];
&[3., 4.], let knn = KNNClassifier::fit(
&[5., 6.], &x,
&[7., 8.], &y,
&[9., 10.]]); 3,
let y = vec![2., 2., 2., 3., 3.]; Distances::euclidian(),
let knn = KNNClassifier::fit(&x, &y, 3, Distances::euclidian(), KNNAlgorithmName::LinearSearch); KNNAlgorithmName::LinearSearch,
);
let r = knn.predict(&x); let r = knn.predict(&x);
assert_eq!(5, Vec::len(&r)); assert_eq!(5, Vec::len(&r));
assert_eq!(y.to_vec(), r); assert_eq!(y.to_vec(), r);
} }
#[test] #[test]
fn serde() { fn serde() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
&[1., 2.], let y = vec![2., 2., 2., 3., 3.];
&[3., 4.],
&[5., 6.],
&[7., 8.],
&[9., 10.]]);
let y = vec![2., 2., 2., 3., 3.];
let knn = KNNClassifier::fit(&x, &y, 3, Distances::euclidian(), KNNAlgorithmName::CoverTree); let knn = KNNClassifier::fit(
&x,
&y,
3,
Distances::euclidian(),
KNNAlgorithmName::CoverTree,
);
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap(); let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
assert_eq!(knn, deserialized_knn); assert_eq!(knn, deserialized_knn);
} }
} }
+1 -1
View File
@@ -1 +1 @@
pub mod knn; pub mod knn;
@@ -1,15 +1,15 @@
use std::default::Default; use std::default::Default;
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::optimization::{F, DF}; use crate::math::num::FloatExt;
use crate::optimization::line_search::LineSearchMethod;
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult}; use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
use crate::optimization::line_search::LineSearchMethod;
use crate::optimization::{DF, F};
pub struct GradientDescent<T: FloatExt> { pub struct GradientDescent<T: FloatExt> {
pub max_iter: usize, pub max_iter: usize,
pub g_rtol: T, pub g_rtol: T,
pub g_atol: T pub g_atol: T,
} }
impl<T: FloatExt> Default for GradientDescent<T> { impl<T: FloatExt> Default for GradientDescent<T> {
@@ -17,31 +17,34 @@ impl<T: FloatExt> Default for GradientDescent<T> {
GradientDescent { GradientDescent {
max_iter: 10000, max_iter: 10000,
g_rtol: T::epsilon().sqrt(), g_rtol: T::epsilon().sqrt(),
g_atol: T::epsilon() g_atol: T::epsilon(),
} }
} }
} }
impl<T: FloatExt> FirstOrderOptimizer<T> for GradientDescent<T> impl<T: FloatExt> FirstOrderOptimizer<T> for GradientDescent<T> {
{ fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(
&self,
fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(&self, f: &'a F<T, X>, df: &'a DF<X>, x0: &X, ls: &'a LS) -> OptimizerResult<T, X> { f: &'a F<T, X>,
df: &'a DF<X>,
let mut x = x0.clone(); x0: &X,
ls: &'a LS,
) -> OptimizerResult<T, X> {
let mut x = x0.clone();
let mut fx = f(&x); let mut fx = f(&x);
let mut gvec = x0.clone(); let mut gvec = x0.clone();
let mut gnorm = gvec.norm2(); let mut gnorm = gvec.norm2();
let gtol = (gvec.norm2() * self.g_rtol).max(self.g_atol); let gtol = (gvec.norm2() * self.g_rtol).max(self.g_atol);
let mut iter = 0; let mut iter = 0;
let mut alpha = T::one(); let mut alpha = T::one();
df(&mut gvec, &x); df(&mut gvec, &x);
while iter < self.max_iter && (iter == 0 || gnorm > gtol) { while iter < self.max_iter && (iter == 0 || gnorm > gtol) {
iter += 1; iter += 1;
let mut step = gvec.negative(); let mut step = gvec.negative();
let f_alpha = |alpha: T| -> T { let f_alpha = |alpha: T| -> T {
@@ -50,7 +53,7 @@ impl<T: FloatExt> FirstOrderOptimizer<T> for GradientDescent<T>
f(&dx.add_mut(&x)) // f(x) = f(x .+ gvec .* alpha) f(&dx.add_mut(&x)) // f(x) = f(x .+ gvec .* alpha)
}; };
let df_alpha = |alpha: T| -> T { let df_alpha = |alpha: T| -> T {
let mut dx = step.clone(); let mut dx = step.clone();
let mut dg = gvec.clone(); let mut dg = gvec.clone();
dx.mul_scalar_mut(alpha); dx.mul_scalar_mut(alpha);
@@ -58,56 +61,58 @@ impl<T: FloatExt> FirstOrderOptimizer<T> for GradientDescent<T>
gvec.vector_dot(&dg) gvec.vector_dot(&dg)
}; };
let df0 = step.vector_dot(&gvec); let df0 = step.vector_dot(&gvec);
let ls_r = ls.search(&f_alpha, &df_alpha, alpha, fx, df0); let ls_r = ls.search(&f_alpha, &df_alpha, alpha, fx, df0);
alpha = ls_r.alpha; alpha = ls_r.alpha;
fx = ls_r.f_x; fx = ls_r.f_x;
x.add_mut(&step.mul_scalar_mut(alpha)); x.add_mut(&step.mul_scalar_mut(alpha));
df(&mut gvec, &x); df(&mut gvec, &x);
gnorm = gvec.norm2(); gnorm = gvec.norm2();
} }
let f_x = f(&x); let f_x = f(&x);
OptimizerResult{ OptimizerResult {
x: x, x: x,
f_x: f_x, f_x: f_x,
iterations: iter iterations: iter,
} }
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
use crate::optimization::line_search::Backtracking; use crate::optimization::line_search::Backtracking;
use crate::optimization::FunctionOrder; use crate::optimization::FunctionOrder;
#[test] #[test]
fn gradient_descent() { fn gradient_descent() {
let x0 = DenseMatrix::vector_from_array(&[-1., 1.]); let x0 = DenseMatrix::vector_from_array(&[-1., 1.]);
let f = |x: &DenseMatrix<f64>| { let f = |x: &DenseMatrix<f64>| {
(1.0 - x.get(0, 0)).powf(2.) + 100.0 * (x.get(0, 1) - x.get(0, 0).powf(2.)).powf(2.) (1.0 - x.get(0, 0)).powf(2.) + 100.0 * (x.get(0, 1) - x.get(0, 0).powf(2.)).powf(2.)
}; };
let df = |g: &mut DenseMatrix<f64>, x: &DenseMatrix<f64>| { let df = |g: &mut DenseMatrix<f64>, x: &DenseMatrix<f64>| {
g.set(0, 0, -2. * (1. - x.get(0, 0)) - 400. * (x.get(0, 1) - x.get(0, 0).powf(2.)) * x.get(0, 0)); g.set(
g.set(0, 1, 200. * (x.get(0, 1) - x.get(0, 0).powf(2.))); 0,
0,
-2. * (1. - x.get(0, 0))
- 400. * (x.get(0, 1) - x.get(0, 0).powf(2.)) * x.get(0, 0),
);
g.set(0, 1, 200. * (x.get(0, 1) - x.get(0, 0).powf(2.)));
}; };
let mut ls: Backtracking<f64> = Default::default(); let mut ls: Backtracking<f64> = Default::default();
ls.order = FunctionOrder::THIRD; ls.order = FunctionOrder::THIRD;
let optimizer: GradientDescent<f64> = Default::default(); let optimizer: GradientDescent<f64> = Default::default();
let result = optimizer.optimize(&f, &df, &x0, &ls); let result = optimizer.optimize(&f, &df, &x0, &ls);
assert!((result.f_x - 0.0).abs() < 1e-5); assert!((result.f_x - 0.0).abs() < 1e-5);
assert!((result.x.get(0, 0) - 1.0).abs() < 1e-2); assert!((result.x.get(0, 0) - 1.0).abs() < 1e-2);
assert!((result.x.get(0, 1) - 1.0).abs() < 1e-2); assert!((result.x.get(0, 1) - 1.0).abs() < 1e-2);
} }
}
}
+94 -83
View File
@@ -1,26 +1,26 @@
use std::default::Default; use std::default::Default;
use std::fmt::Debug; use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::optimization::{F, DF}; use crate::math::num::FloatExt;
use crate::optimization::line_search::LineSearchMethod;
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult}; use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
use crate::optimization::line_search::LineSearchMethod;
use crate::optimization::{DF, F};
pub struct LBFGS<T: FloatExt> { pub struct LBFGS<T: FloatExt> {
pub max_iter: usize, pub max_iter: usize,
pub g_rtol: T, pub g_rtol: T,
pub g_atol: T, pub g_atol: T,
pub x_atol: T, pub x_atol: T,
pub x_rtol: T, pub x_rtol: T,
pub f_abstol: T, pub f_abstol: T,
pub f_reltol: T, pub f_reltol: T,
pub successive_f_tol: usize, pub successive_f_tol: usize,
pub m: usize pub m: usize,
} }
impl<T: FloatExt> Default for LBFGS<T> { impl<T: FloatExt> Default for LBFGS<T> {
fn default() -> Self { fn default() -> Self {
LBFGS { LBFGS {
max_iter: 1000, max_iter: 1000,
g_rtol: T::from(1e-8).unwrap(), g_rtol: T::from(1e-8).unwrap(),
@@ -30,48 +30,49 @@ impl<T: FloatExt> Default for LBFGS<T> {
f_abstol: T::zero(), f_abstol: T::zero(),
f_reltol: T::zero(), f_reltol: T::zero(),
successive_f_tol: 1, successive_f_tol: 1,
m: 10 m: 10,
} }
} }
} }
impl<T: FloatExt> LBFGS<T> { impl<T: FloatExt> LBFGS<T> {
fn two_loops<X: Matrix<T>>(&self, state: &mut LBFGSState<T, X>) {
fn two_loops<X: Matrix<T>>(&self, state: &mut LBFGSState<T, X>) {
let lower = state.iteration.max(self.m) - self.m; let lower = state.iteration.max(self.m) - self.m;
let upper = state.iteration; let upper = state.iteration;
state.twoloop_q.copy_from(&state.x_df); state.twoloop_q.copy_from(&state.x_df);
for index in (lower..upper).rev() { for index in (lower..upper).rev() {
let i = index.rem_euclid(self.m); let i = index.rem_euclid(self.m);
let dgi = &state.dg_history[i]; let dgi = &state.dg_history[i];
let dxi = &state.dx_history[i]; let dxi = &state.dx_history[i];
state.twoloop_alpha[i] = state.rho[i] * dxi.vector_dot(&state.twoloop_q); state.twoloop_alpha[i] = state.rho[i] * dxi.vector_dot(&state.twoloop_q);
state.twoloop_q.sub_mut(&dgi.mul_scalar(state.twoloop_alpha[i])); state
} .twoloop_q
.sub_mut(&dgi.mul_scalar(state.twoloop_alpha[i]));
}
if state.iteration > 0 { if state.iteration > 0 {
let i = (upper - 1).rem_euclid(self.m); let i = (upper - 1).rem_euclid(self.m);
let dxi = &state.dx_history[i]; let dxi = &state.dx_history[i];
let dgi = &state.dg_history[i]; let dgi = &state.dg_history[i];
let scaling = dxi.vector_dot(dgi) / dgi.abs().pow_mut(T::two()).sum(); let scaling = dxi.vector_dot(dgi) / dgi.abs().pow_mut(T::two()).sum();
state.s.copy_from(&state.twoloop_q.mul_scalar(scaling)); state.s.copy_from(&state.twoloop_q.mul_scalar(scaling));
} else { } else {
state.s.copy_from(&state.twoloop_q); state.s.copy_from(&state.twoloop_q);
} }
for index in lower..upper { for index in lower..upper {
let i = index.rem_euclid(self.m); let i = index.rem_euclid(self.m);
let dgi = &state.dg_history[i]; let dgi = &state.dg_history[i];
let dxi = &state.dx_history[i]; let dxi = &state.dx_history[i];
let beta = state.rho[i] * dgi.vector_dot(&state.s); let beta = state.rho[i] * dgi.vector_dot(&state.s);
state.s.add_mut(&dxi.mul_scalar(state.twoloop_alpha[i] - beta)); state
} .s
.add_mut(&dxi.mul_scalar(state.twoloop_alpha[i] - beta));
}
state.s.mul_scalar_mut(-T::one()); state.s.mul_scalar_mut(-T::one());
} }
fn init_state<X: Matrix<T>>(&self, x: &X) -> LBFGSState<T, X> { fn init_state<X: Matrix<T>>(&self, x: &X) -> LBFGSState<T, X> {
@@ -80,31 +81,37 @@ impl<T: FloatExt> LBFGS<T> {
x_prev: x.clone(), x_prev: x.clone(),
x_f: T::nan(), x_f: T::nan(),
x_f_prev: T::nan(), x_f_prev: T::nan(),
x_df: x.clone(), x_df: x.clone(),
x_df_prev: x.clone(), x_df_prev: x.clone(),
rho: vec![T::zero(); self.m], rho: vec![T::zero(); self.m],
dx_history: vec![x.clone(); self.m], dx_history: vec![x.clone(); self.m],
dg_history: vec![x.clone(); self.m], dg_history: vec![x.clone(); self.m],
dx: x.clone(), dx: x.clone(),
dg: x.clone(), dg: x.clone(),
twoloop_q: x.clone(), twoloop_q: x.clone(),
twoloop_alpha: vec![T::zero(); self.m], twoloop_alpha: vec![T::zero(); self.m],
iteration: 0, iteration: 0,
counter_f_tol: 0, counter_f_tol: 0,
s: x.clone(), s: x.clone(),
alpha: T::one() alpha: T::one(),
} }
} }
fn update_state<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(&self, f: &'a F<T, X>, df: &'a DF<X>, ls: &'a LS, state: &mut LBFGSState<T, X>) { fn update_state<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(
self.two_loops(state); &self,
f: &'a F<T, X>,
df: &'a DF<X>,
ls: &'a LS,
state: &mut LBFGSState<T, X>,
) {
self.two_loops(state);
df(&mut state.x_df_prev, &state.x); df(&mut state.x_df_prev, &state.x);
state.x_f_prev = f(&state.x); state.x_f_prev = f(&state.x);
state.x_prev.copy_from(&state.x); state.x_prev.copy_from(&state.x);
let df0 = state.x_df.vector_dot(&state.s); let df0 = state.x_df.vector_dot(&state.s);
let f_alpha = |alpha: T| -> T { let f_alpha = |alpha: T| -> T {
let mut dx = state.s.clone(); let mut dx = state.s.clone();
@@ -112,22 +119,21 @@ impl<T: FloatExt> LBFGS<T> {
f(&dx.add_mut(&state.x)) // f(x) = f(x .+ gvec .* alpha) f(&dx.add_mut(&state.x)) // f(x) = f(x .+ gvec .* alpha)
}; };
let df_alpha = |alpha: T| -> T { let df_alpha = |alpha: T| -> T {
let mut dx = state.s.clone(); let mut dx = state.s.clone();
let mut dg = state.x_df.clone(); let mut dg = state.x_df.clone();
dx.mul_scalar_mut(alpha); dx.mul_scalar_mut(alpha);
df(&mut dg, &dx.add_mut(&state.x)); //df(x) = df(x .+ gvec .* alpha) df(&mut dg, &dx.add_mut(&state.x)); //df(x) = df(x .+ gvec .* alpha)
state.x_df.vector_dot(&dg) state.x_df.vector_dot(&dg)
}; };
let ls_r = ls.search(&f_alpha, &df_alpha, T::one(), state.x_f_prev, df0); let ls_r = ls.search(&f_alpha, &df_alpha, T::one(), state.x_f_prev, df0);
state.alpha = ls_r.alpha; state.alpha = ls_r.alpha;
state.dx.copy_from(state.s.mul_scalar_mut(state.alpha)); state.dx.copy_from(state.s.mul_scalar_mut(state.alpha));
state.x.add_mut(&state.dx); state.x.add_mut(&state.dx);
state.x_f = f(&state.x); state.x_f = f(&state.x);
df(&mut state.x_df, &state.x); df(&mut state.x_df, &state.x);
} }
fn assess_convergence<X: Matrix<T>>(&self, state: &mut LBFGSState<T, X>) -> bool { fn assess_convergence<X: Matrix<T>>(&self, state: &mut LBFGSState<T, X>) -> bool {
@@ -139,9 +145,9 @@ impl<T: FloatExt> LBFGS<T> {
if state.x.max_diff(&state.x_prev) <= self.x_rtol * state.x.norm(T::infinity()) { if state.x.max_diff(&state.x_prev) <= self.x_rtol * state.x.norm(T::infinity()) {
x_converged = true; x_converged = true;
} }
if (state.x_f - state.x_f_prev).abs() <= self.f_abstol { if (state.x_f - state.x_f_prev).abs() <= self.f_abstol {
state.counter_f_tol += 1; state.counter_f_tol += 1;
} }
@@ -151,20 +157,20 @@ impl<T: FloatExt> LBFGS<T> {
if state.x_df.norm(T::infinity()) <= self.g_atol { if state.x_df.norm(T::infinity()) <= self.g_atol {
g_converged = true; g_converged = true;
} }
g_converged || x_converged || state.counter_f_tol > self.successive_f_tol g_converged || x_converged || state.counter_f_tol > self.successive_f_tol
} }
fn update_hessian<'a, X: Matrix<T>>(&self, _: &'a DF<X>, state: &mut LBFGSState<T, X>) { fn update_hessian<'a, X: Matrix<T>>(&self, _: &'a DF<X>, state: &mut LBFGSState<T, X>) {
state.dg = state.x_df.sub(&state.x_df_prev); state.dg = state.x_df.sub(&state.x_df_prev);
let rho_iteration = T::one() / state.dx.vector_dot(&state.dg); let rho_iteration = T::one() / state.dx.vector_dot(&state.dg);
if !rho_iteration.is_infinite() { if !rho_iteration.is_infinite() {
let idx = state.iteration.rem_euclid(self.m); let idx = state.iteration.rem_euclid(self.m);
state.dx_history[idx].copy_from(&state.dx); state.dx_history[idx].copy_from(&state.dx);
state.dg_history[idx].copy_from(&state.dg); state.dg_history[idx].copy_from(&state.dg);
state.rho[idx] = rho_iteration; state.rho[idx] = rho_iteration;
} }
} }
} }
@@ -174,84 +180,89 @@ struct LBFGSState<T: FloatExt, X: Matrix<T>> {
x_prev: X, x_prev: X,
x_f: T, x_f: T,
x_f_prev: T, x_f_prev: T,
x_df: X, x_df: X,
x_df_prev: X, x_df_prev: X,
rho: Vec<T>, rho: Vec<T>,
dx_history: Vec<X>, dx_history: Vec<X>,
dg_history: Vec<X>, dg_history: Vec<X>,
dx: X, dx: X,
dg: X, dg: X,
twoloop_q: X, twoloop_q: X,
twoloop_alpha: Vec<T>, twoloop_alpha: Vec<T>,
iteration: usize, iteration: usize,
counter_f_tol: usize, counter_f_tol: usize,
s: X, s: X,
alpha: T alpha: T,
} }
impl<T: FloatExt> FirstOrderOptimizer<T> for LBFGS<T> { impl<T: FloatExt> FirstOrderOptimizer<T> for LBFGS<T> {
fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(
fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(&self, f: &F<T, X>, df: &'a DF<X>, x0: &X, ls: &'a LS) -> OptimizerResult<T, X> { &self,
f: &F<T, X>,
df: &'a DF<X>,
x0: &X,
ls: &'a LS,
) -> OptimizerResult<T, X> {
let mut state = self.init_state(x0); let mut state = self.init_state(x0);
df(&mut state.x_df, &x0); df(&mut state.x_df, &x0);
let g_converged = state.x_df.norm(T::infinity()) < self.g_atol; let g_converged = state.x_df.norm(T::infinity()) < self.g_atol;
let mut converged = g_converged; let mut converged = g_converged;
let stopped = false; let stopped = false;
while !converged && !stopped && state.iteration < self.max_iter { while !converged && !stopped && state.iteration < self.max_iter {
self.update_state(f, df, ls, &mut state);
self.update_state(f, df, ls, &mut state);
converged = self.assess_convergence(&mut state); converged = self.assess_convergence(&mut state);
if !converged { if !converged {
self.update_hessian(df, &mut state); self.update_hessian(df, &mut state);
} }
state.iteration += 1; state.iteration += 1;
}
} OptimizerResult {
OptimizerResult{
x: state.x, x: state.x,
f_x: state.x_f, f_x: state.x_f,
iterations: state.iteration iterations: state.iteration,
} }
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
use crate::optimization::line_search::Backtracking; use crate::optimization::line_search::Backtracking;
use crate::optimization::FunctionOrder; use crate::optimization::FunctionOrder;
#[test] #[test]
fn lbfgs() { fn lbfgs() {
let x0 = DenseMatrix::vector_from_array(&[0., 0.]); let x0 = DenseMatrix::vector_from_array(&[0., 0.]);
let f = |x: &DenseMatrix<f64>| { let f = |x: &DenseMatrix<f64>| {
(1.0 - x.get(0, 0)).powf(2.) + 100.0 * (x.get(0, 1) - x.get(0, 0).powf(2.)).powf(2.) (1.0 - x.get(0, 0)).powf(2.) + 100.0 * (x.get(0, 1) - x.get(0, 0).powf(2.)).powf(2.)
}; };
let df = |g: &mut DenseMatrix<f64>, x: &DenseMatrix<f64>| { let df = |g: &mut DenseMatrix<f64>, x: &DenseMatrix<f64>| {
g.set(0, 0, -2. * (1. - x.get(0, 0)) - 400. * (x.get(0, 1) - x.get(0, 0).powf(2.)) * x.get(0, 0)); g.set(
g.set(0, 1, 200. * (x.get(0, 1) - x.get(0, 0).powf(2.))); 0,
0,
-2. * (1. - x.get(0, 0))
- 400. * (x.get(0, 1) - x.get(0, 0).powf(2.)) * x.get(0, 0),
);
g.set(0, 1, 200. * (x.get(0, 1) - x.get(0, 0).powf(2.)));
}; };
let mut ls: Backtracking<f64> = Default::default(); let mut ls: Backtracking<f64> = Default::default();
ls.order = FunctionOrder::THIRD; ls.order = FunctionOrder::THIRD;
let optimizer: LBFGS<f64> = Default::default(); let optimizer: LBFGS<f64> = Default::default();
let result = optimizer.optimize(&f, &df, &x0, &ls); let result = optimizer.optimize(&f, &df, &x0, &ls);
assert!((result.f_x - 0.0).abs() < std::f64::EPSILON); assert!((result.f_x - 0.0).abs() < std::f64::EPSILON);
assert!((result.x.get(0, 0) - 1.0).abs() < 1e-8); assert!((result.x.get(0, 0) - 1.0).abs() < 1e-8);
assert!((result.x.get(0, 1) - 1.0).abs() < 1e-8); assert!((result.x.get(0, 1) - 1.0).abs() < 1e-8);
assert!(result.iterations <= 24); assert!(result.iterations <= 24);
} }
} }
+13 -8
View File
@@ -1,22 +1,27 @@
pub mod lbfgs;
pub mod gradient_descent; pub mod gradient_descent;
pub mod lbfgs;
use std::clone::Clone; use std::clone::Clone;
use std::fmt::Debug; use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::FloatExt;
use crate::optimization::line_search::LineSearchMethod; use crate::optimization::line_search::LineSearchMethod;
use crate::optimization::{F, DF}; use crate::optimization::{DF, F};
pub trait FirstOrderOptimizer<T: FloatExt> { pub trait FirstOrderOptimizer<T: FloatExt> {
fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(&self, f: &F<T, X>, df: &'a DF<X>, x0: &X, ls: &'a LS) -> OptimizerResult<T, X>; fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(
&self,
f: &F<T, X>,
df: &'a DF<X>,
x0: &X,
ls: &'a LS,
) -> OptimizerResult<T, X>;
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct OptimizerResult<T: FloatExt, X: Matrix<T>> pub struct OptimizerResult<T: FloatExt, X: Matrix<T>> {
{
pub x: X, pub x: X,
pub f_x: T, pub f_x: T,
pub iterations: usize pub iterations: usize,
} }
+51 -44
View File
@@ -1,14 +1,21 @@
use num_traits::Float;
use crate::optimization::FunctionOrder; use crate::optimization::FunctionOrder;
use num_traits::Float;
pub trait LineSearchMethod<T: Float> { pub trait LineSearchMethod<T: Float> {
fn search<'a>(&self, f: &(dyn Fn(T) -> T), df: &(dyn Fn(T) -> T), alpha: T, f0: T, df0: T) -> LineSearchResult<T>; fn search<'a>(
&self,
f: &(dyn Fn(T) -> T),
df: &(dyn Fn(T) -> T),
alpha: T,
f0: T,
df0: T,
) -> LineSearchResult<T>;
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct LineSearchResult<T: Float> { pub struct LineSearchResult<T: Float> {
pub alpha: T, pub alpha: T,
pub f_x: T pub f_x: T,
} }
pub struct Backtracking<T: Float> { pub struct Backtracking<T: Float> {
@@ -17,31 +24,36 @@ pub struct Backtracking<T: Float> {
pub max_infinity_iterations: usize, pub max_infinity_iterations: usize,
pub phi: T, pub phi: T,
pub plo: T, pub plo: T,
pub order: FunctionOrder pub order: FunctionOrder,
} }
impl<T: Float> Default for Backtracking<T> { impl<T: Float> Default for Backtracking<T> {
fn default() -> Self { fn default() -> Self {
Backtracking { Backtracking {
c1: T::from(1e-4).unwrap(), c1: T::from(1e-4).unwrap(),
max_iterations: 1000, max_iterations: 1000,
max_infinity_iterations: (-T::epsilon().log2()).to_usize().unwrap(), max_infinity_iterations: (-T::epsilon().log2()).to_usize().unwrap(),
phi: T::from(0.5).unwrap(), phi: T::from(0.5).unwrap(),
plo: T::from(0.1).unwrap(), plo: T::from(0.1).unwrap(),
order: FunctionOrder::SECOND order: FunctionOrder::SECOND,
} }
} }
} }
impl<T: Float> LineSearchMethod<T> for Backtracking<T> { impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
fn search<'a>(
fn search<'a>(&self, f: &(dyn Fn(T) -> T), _: &(dyn Fn(T) -> T), alpha: T, f0: T, df0: T) -> LineSearchResult<T> { &self,
f: &(dyn Fn(T) -> T),
_: &(dyn Fn(T) -> T),
alpha: T,
f0: T,
df0: T,
) -> LineSearchResult<T> {
let two = T::from(2.).unwrap(); let two = T::from(2.).unwrap();
let three = T::from(3.).unwrap(); let three = T::from(3.).unwrap();
let (mut a1, mut a2) = (alpha, alpha); let (mut a1, mut a2) = (alpha, alpha);
let (mut fx0, mut fx1) = (f0, f(a1)); let (mut fx0, mut fx1) = (f0, f(a1));
let mut iterfinite = 0; let mut iterfinite = 0;
while !fx1.is_finite() && iterfinite < self.max_infinity_iterations { while !fx1.is_finite() && iterfinite < self.max_infinity_iterations {
@@ -52,7 +64,7 @@ impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
fx1 = f(a2); fx1 = f(a2);
} }
let mut iteration = 0; let mut iteration = 0;
while fx1 > f0 + self.c1 * a2 * df0 { while fx1 > f0 + self.c1 * a2 * df0 {
if iteration > self.max_iterations { if iteration > self.max_iterations {
@@ -62,66 +74,61 @@ impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
let a_tmp; let a_tmp;
if self.order == FunctionOrder::SECOND || iteration == 0 { if self.order == FunctionOrder::SECOND || iteration == 0 {
a_tmp = -(df0 * a2.powf(two)) / (two * (fx1 - f0 - df0 * a2))
a_tmp = - (df0 * a2.powf(two)) / (two * (fx1 - f0 - df0*a2))
} else { } else {
let div = T::one() / (a1.powf(two) * a2.powf(two) * (a2 - a1)); let div = T::one() / (a1.powf(two) * a2.powf(two) * (a2 - a1));
let a = (a1.powf(two) * (fx1 - f0 - df0*a2) - a2.powf(two)*(fx0 - f0 - df0*a1))*div; let a = (a1.powf(two) * (fx1 - f0 - df0 * a2)
let b = (-a1.powf(three) * (fx1 - f0 - df0*a2) + a2.powf(three)*(fx0 - f0 - df0*a1))*div; - a2.powf(two) * (fx0 - f0 - df0 * a1))
* div;
let b = (-a1.powf(three) * (fx1 - f0 - df0 * a2)
+ a2.powf(three) * (fx0 - f0 - df0 * a1))
* div;
if (a - T::zero()).powf(two).sqrt() <= T::epsilon() { if (a - T::zero()).powf(two).sqrt() <= T::epsilon() {
a_tmp = df0 / (two * b); a_tmp = df0 / (two * b);
} else { } else {
let d = T::max(b.powf(two) - three * a * df0, T::zero()); let d = T::max(b.powf(two) - three * a * df0, T::zero());
a_tmp = (-b + d.sqrt()) / (three*a); //root of quadratic equation a_tmp = (-b + d.sqrt()) / (three * a); //root of quadratic equation
} }
} }
a1 = a2; a1 = a2;
a2 = T::max(T::min(a_tmp, a2*self.phi), a2*self.plo); a2 = T::max(T::min(a_tmp, a2 * self.phi), a2 * self.plo);
fx0 = fx1; fx0 = fx1;
fx1 = f(a2); fx1 = f(a2);
iteration += 1; iteration += 1;
} }
LineSearchResult { LineSearchResult {
alpha: a2, alpha: a2,
f_x: fx1 f_x: fx1,
} }
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn backtracking() { fn backtracking() {
let f = |x: f64| -> f64 { x.powf(2.) + x };
let f = |x: f64| -> f64 {
x.powf(2.) + x
};
let df = |x: f64| -> f64 { let df = |x: f64| -> f64 { 2. * x + 1. };
2. * x + 1.
};
let ls: Backtracking<f64> = Default::default(); let ls: Backtracking<f64> = Default::default();
let mut x = -3.; let mut x = -3.;
let mut alpha = 1.; let mut alpha = 1.;
for _ in 0..10 { for _ in 0..10 {
let result = ls.search(&f, &df, alpha, f(x), df(x)); let result = ls.search(&f, &df, alpha, f(x), df(x));
alpha = result.alpha; alpha = result.alpha;
x += alpha; x += alpha;
} }
assert!(f(x).abs() < 0.01); assert!(f(x).abs() < 0.01);
} }
} }
+2 -2
View File
@@ -8,5 +8,5 @@ pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a;
pub enum FunctionOrder { pub enum FunctionOrder {
FIRST, FIRST,
SECOND, SECOND,
THIRD THIRD,
} }
+264 -173
View File
@@ -1,111 +1,112 @@
use std::collections::LinkedList;
use std::default::Default; use std::default::Default;
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::collections::LinkedList;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt;
use crate::linalg::Matrix;
use crate::algorithm::sort::quick_sort::QuickArgSort; use crate::algorithm::sort::quick_sort::QuickArgSort;
use crate::linalg::Matrix;
use crate::math::num::FloatExt;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct DecisionTreeClassifierParameters { pub struct DecisionTreeClassifierParameters {
pub criterion: SplitCriterion, pub criterion: SplitCriterion,
pub max_depth: Option<u16>, pub max_depth: Option<u16>,
pub min_samples_leaf: usize, pub min_samples_leaf: usize,
pub min_samples_split: usize pub min_samples_split: usize,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct DecisionTreeClassifier<T: FloatExt> { pub struct DecisionTreeClassifier<T: FloatExt> {
nodes: Vec<Node<T>>, nodes: Vec<Node<T>>,
parameters: DecisionTreeClassifierParameters, parameters: DecisionTreeClassifierParameters,
num_classes: usize, num_classes: usize,
classes: Vec<T>, classes: Vec<T>,
depth: u16 depth: u16,
} }
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub enum SplitCriterion { pub enum SplitCriterion {
Gini, Gini,
Entropy, Entropy,
ClassificationError ClassificationError,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct Node<T: FloatExt> { pub struct Node<T: FloatExt> {
index: usize, index: usize,
output: usize, output: usize,
split_feature: usize, split_feature: usize,
split_value: Option<T>, split_value: Option<T>,
split_score: Option<T>, split_score: Option<T>,
true_child: Option<usize>, true_child: Option<usize>,
false_child: Option<usize>, false_child: Option<usize>,
} }
impl<T: FloatExt> PartialEq for DecisionTreeClassifier<T> { impl<T: FloatExt> PartialEq for DecisionTreeClassifier<T> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.depth != other.depth || if self.depth != other.depth
self.num_classes != other.num_classes || || self.num_classes != other.num_classes
self.nodes.len() != other.nodes.len(){ || self.nodes.len() != other.nodes.len()
return false {
return false;
} else { } else {
for i in 0..self.classes.len() { for i in 0..self.classes.len() {
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() { if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
return false return false;
} }
} }
for i in 0..self.nodes.len() { for i in 0..self.nodes.len() {
if self.nodes[i] != other.nodes[i] { if self.nodes[i] != other.nodes[i] {
return false return false;
} }
} }
return true return true;
} }
} }
} }
impl<T: FloatExt> PartialEq for Node<T> { impl<T: FloatExt> PartialEq for Node<T> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.output == other.output && self.output == other.output
self.split_feature == other.split_feature && && self.split_feature == other.split_feature
match (self.split_value, other.split_value) { && match (self.split_value, other.split_value) {
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(), (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
(None, None) => true, (None, None) => true,
_ => false, _ => false,
} && }
match (self.split_score, other.split_score) { && match (self.split_score, other.split_score) {
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(), (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
(None, None) => true, (None, None) => true,
_ => false, _ => false,
} }
} }
} }
impl Default for DecisionTreeClassifierParameters { impl Default for DecisionTreeClassifierParameters {
fn default() -> Self { fn default() -> Self {
DecisionTreeClassifierParameters { DecisionTreeClassifierParameters {
criterion: SplitCriterion::Gini, criterion: SplitCriterion::Gini,
max_depth: None, max_depth: None,
min_samples_leaf: 1, min_samples_leaf: 1,
min_samples_split: 2 min_samples_split: 2,
} }
} }
} }
impl<T: FloatExt> Node<T> { impl<T: FloatExt> Node<T> {
fn new(index: usize, output: usize) -> Self { fn new(index: usize, output: usize) -> Self {
Node { Node {
index: index, index: index,
output: output, output: output,
split_feature: 0, split_feature: 0,
split_value: Option::None, split_value: Option::None,
split_score: Option::None, split_score: Option::None,
true_child: Option::None, true_child: Option::None,
false_child: Option::None false_child: Option::None,
} }
} }
} }
struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> { struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
@@ -113,11 +114,11 @@ struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
y: &'a Vec<usize>, y: &'a Vec<usize>,
node: usize, node: usize,
samples: Vec<usize>, samples: Vec<usize>,
order: &'a Vec<Vec<usize>>, order: &'a Vec<Vec<usize>>,
true_child_output: usize, true_child_output: usize,
false_child_output: usize, false_child_output: usize,
level: u16, level: u16,
phantom: PhantomData<&'a T> phantom: PhantomData<&'a T>,
} }
fn impurity<T: FloatExt>(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> T { fn impurity<T: FloatExt>(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> T {
@@ -131,7 +132,7 @@ fn impurity<T: FloatExt>(criterion: &SplitCriterion, count: &Vec<usize>, n: usiz
let p = T::from(count[i]).unwrap() / T::from(n).unwrap(); let p = T::from(count[i]).unwrap() / T::from(n).unwrap();
impurity = impurity - p * p; impurity = impurity - p * p;
} }
} }
} }
SplitCriterion::Entropy => { SplitCriterion::Entropy => {
@@ -149,15 +150,21 @@ fn impurity<T: FloatExt>(criterion: &SplitCriterion, count: &Vec<usize>, n: usiz
} }
} }
impurity = (T::one() - impurity).abs(); impurity = (T::one() - impurity).abs();
} }
} }
return impurity; return impurity;
} }
impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> { impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
fn new(
fn new(node_id: usize, samples: Vec<usize>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a Vec<usize>, level: u16) -> Self { node_id: usize,
samples: Vec<usize>,
order: &'a Vec<Vec<usize>>,
x: &'a M,
y: &'a Vec<usize>,
level: u16,
) -> Self {
NodeVisitor { NodeVisitor {
x: x, x: x,
y: y, y: y,
@@ -167,10 +174,9 @@ impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
true_child_output: 0, true_child_output: 0,
false_child_output: 0, false_child_output: 0,
level: level, level: level,
phantom: PhantomData phantom: PhantomData,
} }
} }
} }
pub(in crate) fn which_max(x: &Vec<usize>) -> usize { pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
@@ -188,19 +194,28 @@ pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
} }
impl<T: FloatExt> DecisionTreeClassifier<T> { impl<T: FloatExt> DecisionTreeClassifier<T> {
pub fn fit<M: Matrix<T>>(
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier<T> { x: &M,
y: &M::RowVector,
parameters: DecisionTreeClassifierParameters,
) -> DecisionTreeClassifier<T> {
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
let samples = vec![1; x_nrows]; let samples = vec![1; x_nrows];
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters) DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
} }
pub fn fit_weak_learner<M: Matrix<T>>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier<T> { pub fn fit_weak_learner<M: Matrix<T>>(
x: &M,
y: &M::RowVector,
samples: Vec<usize>,
mtry: usize,
parameters: DecisionTreeClassifierParameters,
) -> DecisionTreeClassifier<T> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape(); let (_, y_ncols) = y_m.shape();
let (_, num_attributes) = x.shape(); let (_, num_attributes) = x.shape();
let classes = y_m.unique(); let classes = y_m.unique();
let k = classes.len(); let k = classes.len();
if k < 2 { if k < 2 {
panic!("Incorrect number of classes: {}. Should be >= 2.", k); panic!("Incorrect number of classes: {}. Should be >= 2.", k);
} }
@@ -208,31 +223,31 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
let mut yi: Vec<usize> = vec![0; y_ncols]; let mut yi: Vec<usize> = vec![0; y_ncols];
for i in 0..y_ncols { for i in 0..y_ncols {
let yc = y_m.get(0, i); let yc = y_m.get(0, i);
yi[i] = classes.iter().position(|c| yc == *c).unwrap(); yi[i] = classes.iter().position(|c| yc == *c).unwrap();
} }
let mut nodes: Vec<Node<T>> = Vec::new(); let mut nodes: Vec<Node<T>> = Vec::new();
let mut count = vec![0; k]; let mut count = vec![0; k];
for i in 0..y_ncols { for i in 0..y_ncols {
count[yi[i]] += samples[i]; count[yi[i]] += samples[i];
} }
let root = Node::new(0, which_max(&count)); let root = Node::new(0, which_max(&count));
nodes.push(root); nodes.push(root);
let mut order: Vec<Vec<usize>> = Vec::new(); let mut order: Vec<Vec<usize>> = Vec::new();
for i in 0..num_attributes { for i in 0..num_attributes {
order.push(x.get_col_as_vec(i).quick_argsort()); order.push(x.get_col_as_vec(i).quick_argsort());
} }
let mut tree = DecisionTreeClassifier{ let mut tree = DecisionTreeClassifier {
nodes: nodes, nodes: nodes,
parameters: parameters, parameters: parameters,
num_classes: k, num_classes: k,
classes: classes, classes: classes,
depth: 0 depth: 0,
}; };
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &yi, 1); let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &yi, 1);
@@ -243,12 +258,12 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
visitor_queue.push_back(visitor); visitor_queue.push_back(visitor);
} }
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) { while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
match visitor_queue.pop_front() { match visitor_queue.pop_front() {
Some(node) => tree.split(node, mtry, &mut visitor_queue,), Some(node) => tree.split(node, mtry, &mut visitor_queue),
None => break None => break,
}; };
} }
tree tree
} }
@@ -270,7 +285,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
let mut queue: LinkedList<usize> = LinkedList::new(); let mut queue: LinkedList<usize> = LinkedList::new();
queue.push_back(0); queue.push_back(0);
while !queue.is_empty() { while !queue.is_empty() {
match queue.pop_front() { match queue.pop_front() {
Some(node_id) => { Some(node_id) => {
@@ -284,18 +299,20 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
queue.push_back(node.false_child.unwrap()); queue.push_back(node.false_child.unwrap());
} }
} }
}, }
None => break None => break,
}; };
} }
return result return result;
}
}
fn find_best_cutoff<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, mtry: usize) -> bool {
let (n_rows, n_attr) = visitor.x.shape(); fn find_best_cutoff<M: Matrix<T>>(
&mut self,
visitor: &mut NodeVisitor<T, M>,
mtry: usize,
) -> bool {
let (n_rows, n_attr) = visitor.x.shape();
let mut label = Option::None; let mut label = Option::None;
let mut is_pure = true; let mut is_pure = true;
@@ -309,17 +326,17 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
} }
} }
} }
if is_pure { if is_pure {
return false; return false;
} }
let n = visitor.samples.iter().sum(); let n = visitor.samples.iter().sum();
if n <= self.parameters.min_samples_split { if n <= self.parameters.min_samples_split {
return false; return false;
} }
let mut count = vec![0; self.num_classes]; let mut count = vec![0; self.num_classes];
let mut false_count = vec![0; self.num_classes]; let mut false_count = vec![0; self.num_classes];
for i in 0..n_rows { for i in 0..n_rows {
@@ -329,25 +346,38 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
} }
let parent_impurity = impurity(&self.parameters.criterion, &count, n); let parent_impurity = impurity(&self.parameters.criterion, &count, n);
let mut variables = vec![0; n_attr]; let mut variables = vec![0; n_attr];
for i in 0..n_attr { for i in 0..n_attr {
variables[i] = i; variables[i] = i;
} }
for j in 0..mtry { for j in 0..mtry {
self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]); self.find_best_split(
} visitor,
n,
&count,
&mut false_count,
parent_impurity,
variables[j],
);
}
self.nodes[visitor.node].split_score != Option::None self.nodes[visitor.node].split_score != Option::None
}
} fn find_best_split<M: Matrix<T>>(
&mut self,
fn find_best_split<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, n: usize, count: &Vec<usize>, false_count: &mut Vec<usize>, parent_impurity: T, j: usize){ visitor: &mut NodeVisitor<T, M>,
n: usize,
count: &Vec<usize>,
false_count: &mut Vec<usize>,
parent_impurity: T,
j: usize,
) {
let mut true_count = vec![0; self.num_classes]; let mut true_count = vec![0; self.num_classes];
let mut prevx = T::nan(); let mut prevx = T::nan();
let mut prevy = 0; let mut prevy = 0;
for i in visitor.order[j].iter() { for i in visitor.order[j].iter() {
if visitor.samples[*i] > 0 { if visitor.samples[*i] > 0 {
@@ -360,7 +390,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
let tc = true_count.iter().sum(); let tc = true_count.iter().sum();
let fc = n - tc; let fc = n - tc;
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf { if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
prevx = visitor.x.get(*i, j); prevx = visitor.x.get(*i, j);
prevy = visitor.y[*i]; prevy = visitor.y[*i];
@@ -373,12 +403,19 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
} }
let true_label = which_max(&true_count); let true_label = which_max(&true_count);
let false_label = which_max(false_count); let false_label = which_max(false_count);
let gain = parent_impurity - T::from(tc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &true_count, tc) - T::from(fc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &false_count, fc); let gain = parent_impurity
- T::from(tc).unwrap() / T::from(n).unwrap()
* impurity(&self.parameters.criterion, &true_count, tc)
- T::from(fc).unwrap() / T::from(n).unwrap()
* impurity(&self.parameters.criterion, &false_count, fc);
if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() { if self.nodes[visitor.node].split_score == Option::None
|| gain > self.nodes[visitor.node].split_score.unwrap()
{
self.nodes[visitor.node].split_feature = j; self.nodes[visitor.node].split_feature = j;
self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two()); self.nodes[visitor.node].split_value =
Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
self.nodes[visitor.node].split_score = Option::Some(gain); self.nodes[visitor.node].split_score = Option::Some(gain);
visitor.true_child_output = true_label; visitor.true_child_output = true_label;
visitor.false_child_output = false_label; visitor.false_child_output = false_label;
@@ -389,22 +426,28 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
true_count[visitor.y[*i]] += visitor.samples[*i]; true_count[visitor.y[*i]] += visitor.samples[*i];
} }
} }
} }
fn split<'a, M: Matrix<T>>(&mut self, mut visitor: NodeVisitor<'a, T, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>) -> bool { fn split<'a, M: Matrix<T>>(
&mut self,
mut visitor: NodeVisitor<'a, T, M>,
mtry: usize,
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
) -> bool {
let (n, _) = visitor.x.shape(); let (n, _) = visitor.x.shape();
let mut tc = 0; let mut tc = 0;
let mut fc = 0; let mut fc = 0;
let mut true_samples: Vec<usize> = vec![0; n]; let mut true_samples: Vec<usize> = vec![0; n];
for i in 0..n { for i in 0..n {
if visitor.samples[i] > 0 { if visitor.samples[i] > 0 {
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) { if visitor.x.get(i, self.nodes[visitor.node].split_feature)
<= self.nodes[visitor.node].split_value.unwrap_or(T::nan())
{
true_samples[i] = visitor.samples[i]; true_samples[i] = visitor.samples[i];
tc += true_samples[i]; tc += true_samples[i];
visitor.samples[i] = 0; visitor.samples[i] = 0;
} else { } else {
fc += visitor.samples[i]; fc += visitor.samples[i];
} }
} }
@@ -415,50 +458,73 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
self.nodes[visitor.node].split_value = Option::None; self.nodes[visitor.node].split_value = Option::None;
self.nodes[visitor.node].split_score = Option::None; self.nodes[visitor.node].split_score = Option::None;
return false; return false;
} }
let true_child_idx = self.nodes.len(); let true_child_idx = self.nodes.len();
self.nodes.push(Node::new(true_child_idx, visitor.true_child_output)); self.nodes
.push(Node::new(true_child_idx, visitor.true_child_output));
let false_child_idx = self.nodes.len(); let false_child_idx = self.nodes.len();
self.nodes.push(Node::new(false_child_idx, visitor.false_child_output)); self.nodes
.push(Node::new(false_child_idx, visitor.false_child_output));
self.nodes[visitor.node].true_child = Some(true_child_idx); self.nodes[visitor.node].true_child = Some(true_child_idx);
self.nodes[visitor.node].false_child = Some(false_child_idx); self.nodes[visitor.node].false_child = Some(false_child_idx);
self.depth = u16::max(self.depth, visitor.level + 1); self.depth = u16::max(self.depth, visitor.level + 1);
let mut true_visitor = NodeVisitor::<T, M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1); let mut true_visitor = NodeVisitor::<T, M>::new(
true_child_idx,
true_samples,
visitor.order,
visitor.x,
visitor.y,
visitor.level + 1,
);
if self.find_best_cutoff(&mut true_visitor, mtry) { if self.find_best_cutoff(&mut true_visitor, mtry) {
visitor_queue.push_back(true_visitor); visitor_queue.push_back(true_visitor);
} }
let mut false_visitor = NodeVisitor::<T, M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1); let mut false_visitor = NodeVisitor::<T, M>::new(
false_child_idx,
visitor.samples,
visitor.order,
visitor.x,
visitor.y,
visitor.level + 1,
);
if self.find_best_cutoff(&mut false_visitor, mtry) { if self.find_best_cutoff(&mut false_visitor, mtry) {
visitor_queue.push_back(false_visitor); visitor_queue.push_back(false_visitor);
} }
true true
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test] #[test]
fn gini_impurity() { fn gini_impurity() {
assert!((impurity::<f64>(&SplitCriterion::Gini, &vec![7, 3], 10) - 0.42).abs() < std::f64::EPSILON); assert!(
assert!((impurity::<f64>(&SplitCriterion::Entropy, &vec![7, 3], 10) - 0.8812908992306927).abs() < std::f64::EPSILON); (impurity::<f64>(&SplitCriterion::Gini, &vec![7, 3], 10) - 0.42).abs()
assert!((impurity::<f64>(&SplitCriterion::ClassificationError, &vec![7, 3], 10) - 0.3).abs() < std::f64::EPSILON); < std::f64::EPSILON
);
assert!(
(impurity::<f64>(&SplitCriterion::Entropy, &vec![7, 3], 10) - 0.8812908992306927).abs()
< std::f64::EPSILON
);
assert!(
(impurity::<f64>(&SplitCriterion::ClassificationError, &vec![7, 3], 10) - 0.3).abs()
< std::f64::EPSILON
);
} }
#[test] #[test]
fn fit_predict_iris() { fn fit_predict_iris() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2],
@@ -479,75 +545,100 @@ mod tests {
&[6.3, 3.3, 4.7, 1.6], &[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0], &[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3], &[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4]]); &[5.2, 2.7, 3.9, 1.4],
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]; ]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x)); assert_eq!(
y,
assert_eq!(3, DecisionTreeClassifier::fit(&x, &y, DecisionTreeClassifierParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1, min_samples_split: 2}).depth); DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x)
);
assert_eq!(
3,
DecisionTreeClassifier::fit(
&x,
&y,
DecisionTreeClassifierParameters {
criterion: SplitCriterion::Entropy,
max_depth: Some(3),
min_samples_leaf: 1,
min_samples_split: 2
}
)
.depth
);
} }
#[test] #[test]
fn fit_predict_baloons() { fn fit_predict_baloons() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[1.,1.,1.,0.], &[1., 1., 1., 0.],
&[1.,1.,1.,0.], &[1., 1., 1., 0.],
&[1.,1.,1.,1.], &[1., 1., 1., 1.],
&[1.,1.,0.,0.], &[1., 1., 0., 0.],
&[1.,1.,0.,1.], &[1., 1., 0., 1.],
&[1.,0.,1.,0.], &[1., 0., 1., 0.],
&[1.,0.,1.,0.], &[1., 0., 1., 0.],
&[1.,0.,1.,1.], &[1., 0., 1., 1.],
&[1.,0.,0.,0.], &[1., 0., 0., 0.],
&[1.,0.,0.,1.], &[1., 0., 0., 1.],
&[0.,1.,1.,0.], &[0., 1., 1., 0.],
&[0.,1.,1.,0.], &[0., 1., 1., 0.],
&[0.,1.,1.,1.], &[0., 1., 1., 1.],
&[0.,1.,0.,0.], &[0., 1., 0., 0.],
&[0.,1.,0.,1.], &[0., 1., 0., 1.],
&[0.,0.,1.,0.], &[0., 0., 1., 0.],
&[0.,0.,1.,0.], &[0., 0., 1., 0.],
&[0.,0.,1.,1.], &[0., 0., 1., 1.],
&[0.,0.,0.,0.], &[0., 0., 0., 0.],
&[0.,0.,0.,1.]]); &[0., 0., 0., 1.],
let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.]; ]);
let y = vec![
1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
];
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x)); assert_eq!(
y,
DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x)
);
} }
#[test] #[test]
fn serde() { fn serde() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[1.,1.,1.,0.], &[1., 1., 1., 0.],
&[1.,1.,1.,0.], &[1., 1., 1., 0.],
&[1.,1.,1.,1.], &[1., 1., 1., 1.],
&[1.,1.,0.,0.], &[1., 1., 0., 0.],
&[1.,1.,0.,1.], &[1., 1., 0., 1.],
&[1.,0.,1.,0.], &[1., 0., 1., 0.],
&[1.,0.,1.,0.], &[1., 0., 1., 0.],
&[1.,0.,1.,1.], &[1., 0., 1., 1.],
&[1.,0.,0.,0.], &[1., 0., 0., 0.],
&[1.,0.,0.,1.], &[1., 0., 0., 1.],
&[0.,1.,1.,0.], &[0., 1., 1., 0.],
&[0.,1.,1.,0.], &[0., 1., 1., 0.],
&[0.,1.,1.,1.], &[0., 1., 1., 1.],
&[0.,1.,0.,0.], &[0., 1., 0., 0.],
&[0.,1.,0.,1.], &[0., 1., 0., 1.],
&[0.,0.,1.,0.], &[0., 0., 1., 0.],
&[0.,0.,1.,0.], &[0., 0., 1., 0.],
&[0.,0.,1.,1.], &[0., 0., 1., 1.],
&[0.,0.,0.,0.], &[0., 0., 0., 0.],
&[0.,0.,0.,1.]]); &[0., 0., 0., 1.],
let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.]; ]);
let y = vec![
1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
];
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()); let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
let deserialized_tree: DecisionTreeClassifier<f64> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap(); let deserialized_tree: DecisionTreeClassifier<f64> =
bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
assert_eq!(tree, deserialized_tree); assert_eq!(tree, deserialized_tree);
} }
} }
+247 -167
View File
@@ -1,91 +1,90 @@
use std::collections::LinkedList;
use std::default::Default; use std::default::Default;
use std::fmt::Debug; use std::fmt::Debug;
use std::collections::LinkedList;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::math::num::FloatExt;
use crate::linalg::Matrix;
use crate::algorithm::sort::quick_sort::QuickArgSort; use crate::algorithm::sort::quick_sort::QuickArgSort;
use crate::linalg::Matrix;
use crate::math::num::FloatExt;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct DecisionTreeRegressorParameters { pub struct DecisionTreeRegressorParameters {
pub max_depth: Option<u16>, pub max_depth: Option<u16>,
pub min_samples_leaf: usize, pub min_samples_leaf: usize,
pub min_samples_split: usize pub min_samples_split: usize,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct DecisionTreeRegressor<T: FloatExt> { pub struct DecisionTreeRegressor<T: FloatExt> {
nodes: Vec<Node<T>>, nodes: Vec<Node<T>>,
parameters: DecisionTreeRegressorParameters, parameters: DecisionTreeRegressorParameters,
depth: u16 depth: u16,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct Node<T: FloatExt> { pub struct Node<T: FloatExt> {
index: usize, index: usize,
output: T, output: T,
split_feature: usize, split_feature: usize,
split_value: Option<T>, split_value: Option<T>,
split_score: Option<T>, split_score: Option<T>,
true_child: Option<usize>, true_child: Option<usize>,
false_child: Option<usize>, false_child: Option<usize>,
} }
impl Default for DecisionTreeRegressorParameters { impl Default for DecisionTreeRegressorParameters {
fn default() -> Self { fn default() -> Self {
DecisionTreeRegressorParameters { DecisionTreeRegressorParameters {
max_depth: None, max_depth: None,
min_samples_leaf: 1, min_samples_leaf: 1,
min_samples_split: 2 min_samples_split: 2,
} }
} }
} }
impl<T: FloatExt> Node<T> { impl<T: FloatExt> Node<T> {
fn new(index: usize, output: T) -> Self { fn new(index: usize, output: T) -> Self {
Node { Node {
index: index, index: index,
output: output, output: output,
split_feature: 0, split_feature: 0,
split_value: Option::None, split_value: Option::None,
split_score: Option::None, split_score: Option::None,
true_child: Option::None, true_child: Option::None,
false_child: Option::None false_child: Option::None,
} }
}
}
impl<T: FloatExt> PartialEq for Node<T> {
fn eq(&self, other: &Self) -> bool {
(self.output - other.output).abs() < T::epsilon() &&
self.split_feature == other.split_feature &&
match (self.split_value, other.split_value) {
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
(None, None) => true,
_ => false,
} &&
match (self.split_score, other.split_score) {
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
(None, None) => true,
_ => false,
}
} }
} }
impl<T: FloatExt> PartialEq for DecisionTreeRegressor<T> { impl<T: FloatExt> PartialEq for Node<T> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.depth != other.depth || self.nodes.len() != other.nodes.len(){ (self.output - other.output).abs() < T::epsilon()
return false && self.split_feature == other.split_feature
} else { && match (self.split_value, other.split_value) {
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
(None, None) => true,
_ => false,
}
&& match (self.split_score, other.split_score) {
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
(None, None) => true,
_ => false,
}
}
}
impl<T: FloatExt> PartialEq for DecisionTreeRegressor<T> {
fn eq(&self, other: &Self) -> bool {
if self.depth != other.depth || self.nodes.len() != other.nodes.len() {
return false;
} else {
for i in 0..self.nodes.len() { for i in 0..self.nodes.len() {
if self.nodes[i] != other.nodes[i] { if self.nodes[i] != other.nodes[i] {
return false return false;
} }
} }
return true return true;
} }
} }
} }
@@ -95,15 +94,21 @@ struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
y: &'a M, y: &'a M,
node: usize, node: usize,
samples: Vec<usize>, samples: Vec<usize>,
order: &'a Vec<Vec<usize>>, order: &'a Vec<Vec<usize>>,
true_child_output: T, true_child_output: T,
false_child_output: T, false_child_output: T,
level: u16 level: u16,
} }
impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> { impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
fn new(
fn new(node_id: usize, samples: Vec<usize>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a M, level: u16) -> Self { node_id: usize,
samples: Vec<usize>,
order: &'a Vec<Vec<usize>>,
x: &'a M,
y: &'a M,
level: u16,
) -> Self {
NodeVisitor { NodeVisitor {
x: x, x: x,
y: y, y: y,
@@ -112,33 +117,41 @@ impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
order: order, order: order,
true_child_output: T::zero(), true_child_output: T::zero(),
false_child_output: T::zero(), false_child_output: T::zero(),
level: level level: level,
} }
} }
} }
impl<T: FloatExt> DecisionTreeRegressor<T> { impl<T: FloatExt> DecisionTreeRegressor<T> {
pub fn fit<M: Matrix<T>>(
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: DecisionTreeRegressorParameters) -> DecisionTreeRegressor<T> { x: &M,
y: &M::RowVector,
parameters: DecisionTreeRegressorParameters,
) -> DecisionTreeRegressor<T> {
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
let samples = vec![1; x_nrows]; let samples = vec![1; x_nrows];
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters) DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
} }
pub fn fit_weak_learner<M: Matrix<T>>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeRegressorParameters) -> DecisionTreeRegressor<T> { pub fn fit_weak_learner<M: Matrix<T>>(
x: &M,
y: &M::RowVector,
samples: Vec<usize>,
mtry: usize,
parameters: DecisionTreeRegressorParameters,
) -> DecisionTreeRegressor<T> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape(); let (_, y_ncols) = y_m.shape();
let (_, num_attributes) = x.shape(); let (_, num_attributes) = x.shape();
let classes = y_m.unique(); let classes = y_m.unique();
let k = classes.len(); let k = classes.len();
if k < 2 { if k < 2 {
panic!("Incorrect number of classes: {}. Should be >= 2.", k); panic!("Incorrect number of classes: {}. Should be >= 2.", k);
} }
let mut nodes: Vec<Node<T>> = Vec::new();
let mut nodes: Vec<Node<T>> = Vec::new();
let mut n = 0; let mut n = 0;
let mut sum = T::zero(); let mut sum = T::zero();
for i in 0..y_ncols { for i in 0..y_ncols {
@@ -146,18 +159,18 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
sum = sum + T::from(samples[i]).unwrap() * y_m.get(0, i); sum = sum + T::from(samples[i]).unwrap() * y_m.get(0, i);
} }
let root = Node::new(0, sum / T::from(n).unwrap()); let root = Node::new(0, sum / T::from(n).unwrap());
nodes.push(root); nodes.push(root);
let mut order: Vec<Vec<usize>> = Vec::new(); let mut order: Vec<Vec<usize>> = Vec::new();
for i in 0..num_attributes { for i in 0..num_attributes {
order.push(x.get_col_as_vec(i).quick_argsort()); order.push(x.get_col_as_vec(i).quick_argsort());
} }
let mut tree = DecisionTreeRegressor{ let mut tree = DecisionTreeRegressor {
nodes: nodes, nodes: nodes,
parameters: parameters, parameters: parameters,
depth: 0 depth: 0,
}; };
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &y_m, 1); let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &y_m, 1);
@@ -168,12 +181,12 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
visitor_queue.push_back(visitor); visitor_queue.push_back(visitor);
} }
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) { while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
match visitor_queue.pop_front() { match visitor_queue.pop_front() {
Some(node) => tree.split(node, mtry, &mut visitor_queue), Some(node) => tree.split(node, mtry, &mut visitor_queue),
None => break None => break,
}; };
} }
tree tree
} }
@@ -195,7 +208,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
let mut queue: LinkedList<usize> = LinkedList::new(); let mut queue: LinkedList<usize> = LinkedList::new();
queue.push_back(0); queue.push_back(0);
while !queue.is_empty() { while !queue.is_empty() {
match queue.pop_front() { match queue.pop_front() {
Some(node_id) => { Some(node_id) => {
@@ -209,100 +222,123 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
queue.push_back(node.false_child.unwrap()); queue.push_back(node.false_child.unwrap());
} }
} }
}, }
None => break None => break,
}; };
} }
return result return result;
}
}
fn find_best_cutoff<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, mtry: usize) -> bool {
let (_, n_attr) = visitor.x.shape(); fn find_best_cutoff<M: Matrix<T>>(
&mut self,
visitor: &mut NodeVisitor<T, M>,
mtry: usize,
) -> bool {
let (_, n_attr) = visitor.x.shape();
let n: usize = visitor.samples.iter().sum(); let n: usize = visitor.samples.iter().sum();
if n < self.parameters.min_samples_split { if n < self.parameters.min_samples_split {
return false; return false;
} }
let sum = self.nodes[visitor.node].output * T::from(n).unwrap(); let sum = self.nodes[visitor.node].output * T::from(n).unwrap();
let mut variables = vec![0; n_attr]; let mut variables = vec![0; n_attr];
for i in 0..n_attr { for i in 0..n_attr {
variables[i] = i; variables[i] = i;
} }
let parent_gain = T::from(n).unwrap() * self.nodes[visitor.node].output * self.nodes[visitor.node].output; let parent_gain =
T::from(n).unwrap() * self.nodes[visitor.node].output * self.nodes[visitor.node].output;
for j in 0..mtry { for j in 0..mtry {
self.find_best_split(visitor, n, sum, parent_gain, variables[j]); self.find_best_split(visitor, n, sum, parent_gain, variables[j]);
} }
self.nodes[visitor.node].split_score != Option::None self.nodes[visitor.node].split_score != Option::None
}
} fn find_best_split<M: Matrix<T>>(
&mut self,
fn find_best_split<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, n: usize, sum: T, parent_gain: T, j: usize){ visitor: &mut NodeVisitor<T, M>,
n: usize,
sum: T,
parent_gain: T,
j: usize,
) {
let mut true_sum = T::zero(); let mut true_sum = T::zero();
let mut true_count = 0; let mut true_count = 0;
let mut prevx = T::nan(); let mut prevx = T::nan();
for i in visitor.order[j].iter() { for i in visitor.order[j].iter() {
if visitor.samples[*i] > 0 { if visitor.samples[*i] > 0 {
if prevx.is_nan() || visitor.x.get(*i, j) == prevx { if prevx.is_nan() || visitor.x.get(*i, j) == prevx {
prevx = visitor.x.get(*i, j); prevx = visitor.x.get(*i, j);
true_count += visitor.samples[*i]; true_count += visitor.samples[*i];
true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i); true_sum =
true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
continue; continue;
} }
let false_count = n - true_count; let false_count = n - true_count;
if true_count < self.parameters.min_samples_leaf || false_count < self.parameters.min_samples_leaf { if true_count < self.parameters.min_samples_leaf
|| false_count < self.parameters.min_samples_leaf
{
prevx = visitor.x.get(*i, j); prevx = visitor.x.get(*i, j);
true_count += visitor.samples[*i]; true_count += visitor.samples[*i];
true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i); true_sum =
true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
continue; continue;
} }
let true_mean = true_sum / T::from(true_count).unwrap(); let true_mean = true_sum / T::from(true_count).unwrap();
let false_mean = (sum - true_sum) / T::from(false_count).unwrap(); let false_mean = (sum - true_sum) / T::from(false_count).unwrap();
let gain = (T::from(true_count).unwrap() * true_mean * true_mean + T::from(false_count).unwrap() * false_mean * false_mean) - parent_gain; let gain = (T::from(true_count).unwrap() * true_mean * true_mean
+ T::from(false_count).unwrap() * false_mean * false_mean)
if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() { - parent_gain;
if self.nodes[visitor.node].split_score == Option::None
|| gain > self.nodes[visitor.node].split_score.unwrap()
{
self.nodes[visitor.node].split_feature = j; self.nodes[visitor.node].split_feature = j;
self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two()); self.nodes[visitor.node].split_value =
Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
self.nodes[visitor.node].split_score = Option::Some(gain); self.nodes[visitor.node].split_score = Option::Some(gain);
visitor.true_child_output = true_mean; visitor.true_child_output = true_mean;
visitor.false_child_output = false_mean; visitor.false_child_output = false_mean;
} }
prevx = visitor.x.get(*i, j); prevx = visitor.x.get(*i, j);
true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i); true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
true_count += visitor.samples[*i]; true_count += visitor.samples[*i];
} }
} }
} }
fn split<'a, M: Matrix<T>>(&mut self, mut visitor: NodeVisitor<'a, T, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>) -> bool { fn split<'a, M: Matrix<T>>(
&mut self,
mut visitor: NodeVisitor<'a, T, M>,
mtry: usize,
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
) -> bool {
let (n, _) = visitor.x.shape(); let (n, _) = visitor.x.shape();
let mut tc = 0; let mut tc = 0;
let mut fc = 0; let mut fc = 0;
let mut true_samples: Vec<usize> = vec![0; n]; let mut true_samples: Vec<usize> = vec![0; n];
for i in 0..n { for i in 0..n {
if visitor.samples[i] > 0 { if visitor.samples[i] > 0 {
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) { if visitor.x.get(i, self.nodes[visitor.node].split_feature)
<= self.nodes[visitor.node].split_value.unwrap_or(T::nan())
{
true_samples[i] = visitor.samples[i]; true_samples[i] = visitor.samples[i];
tc += true_samples[i]; tc += true_samples[i];
visitor.samples[i] = 0; visitor.samples[i] = 0;
} else { } else {
fc += visitor.samples[i]; fc += visitor.samples[i];
} }
} }
@@ -313,111 +349,155 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
self.nodes[visitor.node].split_value = Option::None; self.nodes[visitor.node].split_value = Option::None;
self.nodes[visitor.node].split_score = Option::None; self.nodes[visitor.node].split_score = Option::None;
return false; return false;
} }
let true_child_idx = self.nodes.len(); let true_child_idx = self.nodes.len();
self.nodes.push(Node::new(true_child_idx, visitor.true_child_output)); self.nodes
.push(Node::new(true_child_idx, visitor.true_child_output));
let false_child_idx = self.nodes.len(); let false_child_idx = self.nodes.len();
self.nodes.push(Node::new(false_child_idx, visitor.false_child_output)); self.nodes
.push(Node::new(false_child_idx, visitor.false_child_output));
self.nodes[visitor.node].true_child = Some(true_child_idx); self.nodes[visitor.node].true_child = Some(true_child_idx);
self.nodes[visitor.node].false_child = Some(false_child_idx); self.nodes[visitor.node].false_child = Some(false_child_idx);
self.depth = u16::max(self.depth, visitor.level + 1); self.depth = u16::max(self.depth, visitor.level + 1);
let mut true_visitor = NodeVisitor::<T, M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1); let mut true_visitor = NodeVisitor::<T, M>::new(
true_child_idx,
true_samples,
visitor.order,
visitor.x,
visitor.y,
visitor.level + 1,
);
if self.find_best_cutoff(&mut true_visitor, mtry) { if self.find_best_cutoff(&mut true_visitor, mtry) {
visitor_queue.push_back(true_visitor); visitor_queue.push_back(true_visitor);
} }
let mut false_visitor = NodeVisitor::<T, M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1); let mut false_visitor = NodeVisitor::<T, M>::new(
false_child_idx,
visitor.samples,
visitor.order,
visitor.x,
visitor.y,
visitor.level + 1,
);
if self.find_best_cutoff(&mut false_visitor, mtry) { if self.find_best_cutoff(&mut false_visitor, mtry) {
visitor_queue.push_back(false_visitor); visitor_queue.push_back(false_visitor);
} }
true true
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test] #[test]
fn fit_longley() { fn fit_longley() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[ 234.289, 235.6, 159., 107.608, 1947., 60.323], &[234.289, 235.6, 159., 107.608, 1947., 60.323],
&[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122], &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171], &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[ 284.599, 335.1, 165., 110.929, 1950., 61.187], &[284.599, 335.1, 165., 110.929, 1950., 61.187],
&[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221], &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639], &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
&[ 365.385, 187., 354.7, 115.094, 1953., 64.989], &[365.385, 187., 354.7, 115.094, 1953., 64.989],
&[ 363.112, 357.8, 335., 116.219, 1954., 63.761], &[363.112, 357.8, 335., 116.219, 1954., 63.761],
&[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019], &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857], &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
&[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169], &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513], &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
&[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655], &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564], &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331], &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]; ]);
let y: Vec<f64> = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x); let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x);
for i in 0..y_hat.len() { for i in 0..y_hat.len() {
assert!((y_hat[i] - y[i]).abs() < 0.1); assert!((y_hat[i] - y[i]).abs() < 0.1);
} }
let expected_y = vec![87.3, 87.3, 87.3, 87.3, 98.9, 98.9, 98.9, 98.9, 98.9, 107.9, 107.9, 107.9, 114.85, 114.85, 114.85, 114.85]; let expected_y = vec![
let y_hat = DecisionTreeRegressor::fit(&x, &y, DecisionTreeRegressorParameters{max_depth: Option::None, min_samples_leaf: 2, min_samples_split: 6}).predict(&x); 87.3, 87.3, 87.3, 87.3, 98.9, 98.9, 98.9, 98.9, 98.9, 107.9, 107.9, 107.9, 114.85,
114.85, 114.85, 114.85,
];
let y_hat = DecisionTreeRegressor::fit(
&x,
&y,
DecisionTreeRegressorParameters {
max_depth: Option::None,
min_samples_leaf: 2,
min_samples_split: 6,
},
)
.predict(&x);
for i in 0..y_hat.len() { for i in 0..y_hat.len() {
assert!((y_hat[i] - expected_y[i]).abs() < 0.1); assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
} }
let expected_y = vec![83.0, 88.35, 88.35, 89.5, 97.15, 97.15, 99.5, 99.5, 101.2, 104.6, 109.6, 109.6, 113.4, 113.4, 116.30, 116.30]; let expected_y = vec![
let y_hat = DecisionTreeRegressor::fit(&x, &y, DecisionTreeRegressorParameters{max_depth: Option::None, min_samples_leaf: 1, min_samples_split: 3}).predict(&x); 83.0, 88.35, 88.35, 89.5, 97.15, 97.15, 99.5, 99.5, 101.2, 104.6, 109.6, 109.6, 113.4,
113.4, 116.30, 116.30,
];
let y_hat = DecisionTreeRegressor::fit(
&x,
&y,
DecisionTreeRegressorParameters {
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 3,
},
)
.predict(&x);
for i in 0..y_hat.len() { for i in 0..y_hat.len() {
assert!((y_hat[i] - expected_y[i]).abs() < 0.1); assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
} }
} }
#[test] #[test]
fn serde() { fn serde() {
let x = DenseMatrix::from_array(&[ let x = DenseMatrix::from_array(&[
&[ 234.289, 235.6, 159., 107.608, 1947., 60.323], &[234.289, 235.6, 159., 107.608, 1947., 60.323],
&[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122], &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171], &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[ 284.599, 335.1, 165., 110.929, 1950., 61.187], &[284.599, 335.1, 165., 110.929, 1950., 61.187],
&[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221], &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639], &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
&[ 365.385, 187., 354.7, 115.094, 1953., 64.989], &[365.385, 187., 354.7, 115.094, 1953., 64.989],
&[ 363.112, 357.8, 335., 116.219, 1954., 63.761], &[363.112, 357.8, 335., 116.219, 1954., 63.761],
&[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019], &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857], &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
&[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169], &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513], &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
&[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655], &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564], &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331], &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]; ]);
let y: Vec<f64> = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
let tree = DecisionTreeRegressor::fit(&x, &y, Default::default()); let tree = DecisionTreeRegressor::fit(&x, &y, Default::default());
let deserialized_tree: DecisionTreeRegressor<f64> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap(); let deserialized_tree: DecisionTreeRegressor<f64> =
bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
assert_eq!(tree, deserialized_tree); assert_eq!(tree, deserialized_tree);
} }
}
}
+1 -1
View File
@@ -1,2 +1,2 @@
pub mod decision_tree_classifier;
pub mod decision_tree_regressor; pub mod decision_tree_regressor;
pub mod decision_tree_classifier;