feat: extends interface of Matrix to support for broad range of types

This commit is contained in:
Volodymyr Orlov
2020-03-26 15:28:26 -07:00
parent 84ffd331cd
commit 02b85415d9
27 changed files with 1021 additions and 868 deletions
+47 -44
View File
@@ -1,42 +1,45 @@
use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::distance::euclidian; use crate::math::distance::euclidian;
#[derive(Debug)] #[derive(Debug)]
pub struct BBDTree { pub struct BBDTree<T: FloatExt + Debug> {
nodes: Vec<BBDTreeNode>, nodes: Vec<BBDTreeNode<T>>,
index: Vec<usize>, index: Vec<usize>,
root: usize root: usize
} }
#[derive(Debug)] #[derive(Debug)]
struct BBDTreeNode { struct BBDTreeNode<T: FloatExt + Debug> {
count: usize, count: usize,
index: usize, index: usize,
center: Vec<f64>, center: Vec<T>,
radius: Vec<f64>, radius: Vec<T>,
sum: Vec<f64>, sum: Vec<T>,
cost: f64, cost: T,
lower: Option<usize>, lower: Option<usize>,
upper: Option<usize> upper: Option<usize>
} }
impl BBDTreeNode { impl<T: FloatExt + Debug> BBDTreeNode<T> {
fn new(d: usize) -> BBDTreeNode { fn new(d: usize) -> BBDTreeNode<T> {
BBDTreeNode { BBDTreeNode {
count: 0, count: 0,
index: 0, index: 0,
center: vec![0f64; d], center: vec![T::zero(); d],
radius: vec![0f64; d], radius: vec![T::zero(); d],
sum: vec![0f64; d], sum: vec![T::zero(); d],
cost: 0f64, cost: T::zero(),
lower: Option::None, lower: Option::None,
upper: Option::None upper: Option::None
} }
} }
} }
impl BBDTree { impl<T: FloatExt + Debug> BBDTree<T> {
pub fn new<M: Matrix>(data: &M) -> BBDTree { pub fn new<M: Matrix<T>>(data: &M) -> BBDTree<T> {
let nodes = Vec::new(); let nodes = Vec::new();
let (n, _) = data.shape(); let (n, _) = data.shape();
@@ -59,20 +62,20 @@ impl BBDTree {
tree tree
} }
pub(in crate) fn clustering(&self, centroids: &Vec<Vec<f64>>, sums: &mut Vec<Vec<f64>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> f64 { pub(in crate) fn clustering(&self, centroids: &Vec<Vec<T>>, sums: &mut Vec<Vec<T>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> T {
let k = centroids.len(); let k = centroids.len();
counts.iter_mut().for_each(|x| *x = 0); counts.iter_mut().for_each(|x| *x = 0);
let mut candidates = vec![0; k]; let mut candidates = vec![0; k];
for i in 0..k { for i in 0..k {
candidates[i] = i; candidates[i] = i;
sums[i].iter_mut().for_each(|x| *x = 0f64); sums[i].iter_mut().for_each(|x| *x = T::zero());
} }
self.filter(self.root, centroids, &candidates, k, sums, counts, membership) self.filter(self.root, centroids, &candidates, k, sums, counts, membership)
} }
fn filter(&self, node: usize, centroids: &Vec<Vec<f64>>, candidates: &Vec<usize>, k: usize, sums: &mut Vec<Vec<f64>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> f64{ fn filter(&self, node: usize, centroids: &Vec<Vec<T>>, candidates: &Vec<usize>, k: usize, sums: &mut Vec<Vec<T>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> T{
let d = centroids[0].len(); let d = centroids[0].len();
// Determine which mean the node mean is closest to // Determine which mean the node mean is closest to
@@ -109,7 +112,7 @@ impl BBDTree {
// Assigns all data within this node to a single mean // Assigns all data within this node to a single mean
for i in 0..d { for i in 0..d {
sums[closest][i] += self.nodes[node].sum[i]; sums[closest][i] = sums[closest][i] + self.nodes[node].sum[i];
} }
counts[closest] += self.nodes[node].count; counts[closest] += self.nodes[node].count;
@@ -123,7 +126,7 @@ impl BBDTree {
} }
fn prune(center: &Vec<f64>, radius: &Vec<f64>, centroids: &Vec<Vec<f64>>, best_index: usize, test_index: usize) -> bool { fn prune(center: &Vec<T>, radius: &Vec<T>, centroids: &Vec<Vec<T>>, best_index: usize, test_index: usize) -> bool {
if best_index == test_index { if best_index == test_index {
return false; return false;
} }
@@ -132,22 +135,22 @@ impl BBDTree {
let best = &centroids[best_index]; let best = &centroids[best_index];
let test = &centroids[test_index]; let test = &centroids[test_index];
let mut lhs = 0f64; let mut lhs = T::zero();
let mut rhs = 0f64; let mut rhs = T::zero();
for i in 0..d { for i in 0..d {
let diff = test[i] - best[i]; let diff = test[i] - best[i];
lhs += diff * diff; lhs = lhs + diff * diff;
if diff > 0f64 { if diff > T::zero() {
rhs += (center[i] + radius[i] - best[i]) * diff; rhs = rhs + (center[i] + radius[i] - best[i]) * diff;
} else { } else {
rhs += (center[i] - radius[i] - best[i]) * diff; rhs = rhs + (center[i] - radius[i] - best[i]) * diff;
} }
} }
return lhs >= 2f64 * rhs; return lhs >= T::two() * rhs;
} }
fn build_node<M: Matrix>(&mut self, data: &M, begin: usize, end: usize) -> usize { fn build_node<M: Matrix<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize {
let (_, d) = data.shape(); let (_, d) = data.shape();
// Allocate the node // Allocate the node
@@ -158,8 +161,8 @@ impl BBDTree {
node.index = begin; node.index = begin;
// Calculate the bounding box // Calculate the bounding box
let mut lower_bound = vec![0f64; d]; let mut lower_bound = vec![T::zero(); d];
let mut upper_bound = vec![0f64; d]; let mut upper_bound = vec![T::zero(); d];
for i in 0..d { for i in 0..d {
lower_bound[i] = data.get(self.index[begin],i); lower_bound[i] = data.get(self.index[begin],i);
@@ -179,11 +182,11 @@ impl BBDTree {
} }
// Calculate bounding box stats // Calculate bounding box stats
let mut max_radius = -1.; let mut max_radius = T::from(-1.).unwrap();
let mut split_index = 0; let mut split_index = 0;
for i in 0..d { for i in 0..d {
node.center[i] = (lower_bound[i] + upper_bound[i]) / 2.; node.center[i] = (lower_bound[i] + upper_bound[i]) / T::two();
node.radius[i] = (upper_bound[i] - lower_bound[i]) / 2.; node.radius[i] = (upper_bound[i] - lower_bound[i]) / T::two();
if node.radius[i] > max_radius { if node.radius[i] > max_radius {
max_radius = node.radius[i]; max_radius = node.radius[i];
split_index = i; split_index = i;
@@ -191,7 +194,7 @@ impl BBDTree {
} }
// If the max spread is 0, make this a leaf node // If the max spread is 0, make this a leaf node
if max_radius < 1E-10 { if max_radius < T::from(1E-10).unwrap() {
node.lower = Option::None; node.lower = Option::None;
node.upper = Option::None; node.upper = Option::None;
for i in 0..d { for i in 0..d {
@@ -201,11 +204,11 @@ impl BBDTree {
if end > begin + 1 { if end > begin + 1 {
let len = end - begin; let len = end - begin;
for i in 0..d { for i in 0..d {
node.sum[i] *= len as f64; node.sum[i] = node.sum[i] * T::from(len).unwrap();
} }
} }
node.cost = 0f64; node.cost = T::zero();
return self.add_node(node); return self.add_node(node);
} }
@@ -247,9 +250,9 @@ impl BBDTree {
node.sum[i] = self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i]; node.sum[i] = self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i];
} }
let mut mean = vec![0f64; d]; let mut mean = vec![T::zero(); d];
for i in 0..d { for i in 0..d {
mean[i] = node.sum[i] / node.count as f64; mean[i] = node.sum[i] / T::from(node.count).unwrap();
} }
node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean) + BBDTree::node_cost(&self.nodes[node.upper.unwrap()], &mean); node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean) + BBDTree::node_cost(&self.nodes[node.upper.unwrap()], &mean);
@@ -257,17 +260,17 @@ impl BBDTree {
self.add_node(node) self.add_node(node)
} }
fn node_cost(node: &BBDTreeNode, center: &Vec<f64>) -> f64 { fn node_cost(node: &BBDTreeNode<T>, center: &Vec<T>) -> T {
let d = center.len(); let d = center.len();
let mut scatter = 0f64; let mut scatter = T::zero();
for i in 0..d { for i in 0..d {
let x = (node.sum[i] / node.count as f64) - center[i]; let x = (node.sum[i] / T::from(node.count).unwrap()) - center[i];
scatter += x * x; scatter = scatter + x * x;
} }
node.cost + node.count as f64 * scatter node.cost + T::from(node.count).unwrap() * scatter
} }
fn add_node(&mut self, new_node: BBDTreeNode) -> usize{ fn add_node(&mut self, new_node: BBDTreeNode<T>) -> usize{
let idx = self.nodes.len(); let idx = self.nodes.len();
self.nodes.push(new_node); self.nodes.push(new_node);
idx idx
+33 -33
View File
@@ -1,29 +1,29 @@
use crate::math;
use crate::algorithm::neighbour::KNNAlgorithm;
use crate::algorithm::sort::heap_select::HeapSelect;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::iter::FromIterator; use std::iter::FromIterator;
use std::fmt::Debug; use std::fmt::Debug;
use std::cmp::{PartialOrd};
use core::hash::{Hash, Hasher}; use core::hash::{Hash, Hasher};
pub struct CoverTree<'a, T> use crate::math::num::FloatExt;
use crate::algorithm::neighbour::KNNAlgorithm;
use crate::algorithm::sort::heap_select::HeapSelect;
pub struct CoverTree<'a, T, F: FloatExt>
where T: Debug where T: Debug
{ {
base: f64, base: F,
max_level: i8, max_level: i8,
min_level: i8, min_level: i8,
distance: &'a dyn Fn(&T, &T) -> f64, distance: &'a dyn Fn(&T, &T) -> F,
nodes: Vec<Node<T>> nodes: Vec<Node<T>>
} }
impl<'a, T> CoverTree<'a, T> impl<'a, T, F: FloatExt> CoverTree<'a, T, F>
where T: Debug where T: Debug
{ {
pub fn new(mut data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> f64) -> CoverTree<T> { pub fn new(mut data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> F) -> CoverTree<T, F> {
let mut tree = CoverTree { let mut tree = CoverTree {
base: 2f64, base: F::two(),
max_level: 100, max_level: 100,
min_level: 100, min_level: 100,
distance: distance, distance: distance,
@@ -46,15 +46,15 @@ where T: Debug
let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data))); let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data)));
let mut i = self.max_level; let mut i = self.max_level;
loop { loop {
let i_d = self.base.powf(i as f64); let i_d = self.base.powf(F::from(i).unwrap());
let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i); let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
let d_p_q = self.min_by_distance(&q_p_ds); let d_p_q = self.min_by_distance(&q_p_ds);
if d_p_q < math::EPSILON { if d_p_q < F::epsilon() {
return return
} else if d_p_q > i_d { } else if d_p_q > i_d {
break; break;
} }
if self.min_by_distance(&qi_p_ds) <= self.base.powf(i as f64){ if self.min_by_distance(&qi_p_ds) <= self.base.powf(F::from(i).unwrap()){
parent = q_p_ds.iter().find(|(_, d)| d <= &i_d).map(|(n, _)| n.index); parent = q_p_ds.iter().find(|(_, d)| d <= &i_d).map(|(n, _)| n.index);
p_i = i; p_i = i;
} }
@@ -82,7 +82,7 @@ where T: Debug
node_id node_id
} }
fn split(&self, p_id: NodeId, r: f64, s1: &mut Vec<T>, s2: Option<&mut Vec<T>>) -> (Vec<T>, Vec<T>){ fn split(&self, p_id: NodeId, r: F, s1: &mut Vec<T>, s2: Option<&mut Vec<T>>) -> (Vec<T>, Vec<T>){
let mut my_near = (Vec::new(), Vec::new()); let mut my_near = (Vec::new(), Vec::new());
@@ -96,7 +96,7 @@ where T: Debug
} }
fn split_remove_s(&self, p_id: NodeId, r: f64, s: &mut Vec<T>, mut my_near: (Vec<T>, Vec<T>)) -> (Vec<T>, Vec<T>){ fn split_remove_s(&self, p_id: NodeId, r: F, s: &mut Vec<T>, mut my_near: (Vec<T>, Vec<T>)) -> (Vec<T>, Vec<T>){
if s.len() > 0 { if s.len() > 0 {
let p = &self.nodes.get(p_id.index).unwrap().data; let p = &self.nodes.get(p_id.index).unwrap().data;
@@ -105,7 +105,7 @@ where T: Debug
let d = (self.distance)(p, &s[i]); let d = (self.distance)(p, &s[i]);
if d <= r { if d <= r {
my_near.0.push(s.remove(i)); my_near.0.push(s.remove(i));
} else if d > r && d <= 2f64 * r{ } else if d > r && d <= F::two() * r{
my_near.1.push(s.remove(i)); my_near.1.push(s.remove(i));
} else { } else {
i += 1; i += 1;
@@ -122,15 +122,15 @@ where T: Debug
self.min_level = std::cmp::min(self.min_level, i); self.min_level = std::cmp::min(self.min_level, i);
return (p, far); return (p, far);
} else { } else {
let (my, n) = self.split(p, self.base.powf((i-1) as f64), &mut near, None); let (my, n) = self.split(p, self.base.powf(F::from(i-1).unwrap()), &mut near, None);
let (pi, mut near) = self.construct(p, my, n, i-1); let (pi, mut near) = self.construct(p, my, n, i-1);
while near.len() > 0 { while near.len() > 0 {
let q_data = near.remove(0); let q_data = near.remove(0);
let nn = self.new_node(Some(p), q_data); let nn = self.new_node(Some(p), q_data);
let (my, n) = self.split(nn, self.base.powf((i-1) as f64), &mut near, Some(&mut far)); let (my, n) = self.split(nn, self.base.powf(F::from(i-1).unwrap()), &mut near, Some(&mut far));
let (child, mut unused) = self.construct(nn, my, n, i-1); let (child, mut unused) = self.construct(nn, my, n, i-1);
self.add_child(pi, child, i); self.add_child(pi, child, i);
let new_near_far = self.split(p, self.base.powf(i as f64), &mut unused, None); let new_near_far = self.split(p, self.base.powf(F::from(i).unwrap()), &mut unused, None);
near.extend(new_near_far.0); near.extend(new_near_far.0);
far.extend(new_near_far.1); far.extend(new_near_far.1);
} }
@@ -148,9 +148,9 @@ where T: Debug
self.nodes.first().unwrap() self.nodes.first().unwrap()
} }
fn get_children_dist<'b>(&'b self, p: &T, qi_p_ds: &Vec<(&'b Node<T>, f64)>, i: i8) -> Vec<(&'b Node<T>, f64)> { fn get_children_dist<'b>(&'b self, p: &T, qi_p_ds: &Vec<(&'b Node<T>, F)>, i: i8) -> Vec<(&'b Node<T>, F)> {
let mut children = Vec::<(&'b Node<T>, f64)>::new(); let mut children = Vec::<(&'b Node<T>, F)>::new();
children.extend(qi_p_ds.iter().cloned()); children.extend(qi_p_ds.iter().cloned());
@@ -162,7 +162,7 @@ where T: Debug
} }
fn min_k_by_distance(&self, q_p_ds: &mut Vec<(&Node<T>, f64)>, k: usize) -> f64 { fn min_k_by_distance(&self, q_p_ds: &mut Vec<(&Node<T>, F)>, k: usize) -> F {
let mut heap = HeapSelect::with_capacity(k); let mut heap = HeapSelect::with_capacity(k);
for (_, d) in q_p_ds { for (_, d) in q_p_ds {
heap.add(d); heap.add(d);
@@ -171,7 +171,7 @@ where T: Debug
*heap.get().pop().unwrap() *heap.get().pop().unwrap()
} }
fn min_by_distance(&self, q_p_ds: &Vec<(&Node<T>, f64)>) -> f64 { fn min_by_distance(&self, q_p_ds: &Vec<(&Node<T>, F)>) -> F {
q_p_ds.into_iter().min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap()).unwrap().1 q_p_ds.into_iter().min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap()).unwrap().1
} }
@@ -180,7 +180,7 @@ where T: Debug
} }
#[allow(dead_code)] #[allow(dead_code)]
fn check_invariant(&self, invariant: fn(&CoverTree<T>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) { fn check_invariant(&self, invariant: fn(&CoverTree<T, F>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) {
let mut current_nodes: Vec<&Node<T>> = Vec::new(); let mut current_nodes: Vec<&Node<T>> = Vec::new();
current_nodes.push(self.root()); current_nodes.push(self.root());
for i in (self.min_level..self.max_level+1).rev() { for i in (self.min_level..self.max_level+1).rev() {
@@ -193,7 +193,7 @@ where T: Debug
} }
#[allow(dead_code)] #[allow(dead_code)]
fn nesting_invariant(_: &CoverTree<T>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) { fn nesting_invariant(_: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) {
let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n)); let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n));
let next_nodes_set: HashSet<&Node<T>> = HashSet::from_iter(next_nodes.into_iter().map(|n| *n)); let next_nodes_set: HashSet<&Node<T>> = HashSet::from_iter(next_nodes.into_iter().map(|n| *n));
for n in nodes_set.iter() { for n in nodes_set.iter() {
@@ -202,11 +202,11 @@ where T: Debug
} }
#[allow(dead_code)] #[allow(dead_code)]
fn covering_tree(tree: &CoverTree<T>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) { fn covering_tree(tree: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) {
let mut p_selected: Vec<&Node<T>> = Vec::new(); let mut p_selected: Vec<&Node<T>> = Vec::new();
for p in next_nodes { for p in next_nodes {
for q in nodes { for q in nodes {
if (tree.distance)(&p.data, &q.data) <= tree.base.powf(i as f64) { if (tree.distance)(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) {
p_selected.push(*p); p_selected.push(*p);
} }
} }
@@ -216,11 +216,11 @@ where T: Debug
} }
#[allow(dead_code)] #[allow(dead_code)]
fn separation(tree: &CoverTree<T>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) { fn separation(tree: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
for p in nodes { for p in nodes {
for q in nodes { for q in nodes {
if p != q { if p != q {
assert!((tree.distance)(&p.data, &q.data) > tree.base.powf(i as f64)); assert!((tree.distance)(&p.data, &q.data) > tree.base.powf(F::from(i).unwrap()));
} }
} }
} }
@@ -228,13 +228,13 @@ where T: Debug
} }
impl<'a, T> KNNAlgorithm<T> for CoverTree<'a, T> impl<'a, T, F: FloatExt> KNNAlgorithm<T> for CoverTree<'a, T, F>
where T: Debug where T: Debug
{ {
fn find(&self, p: &T, k: usize) -> Vec<usize>{ fn find(&self, p: &T, k: usize) -> Vec<usize>{
let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data))); let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data)));
for i in (self.min_level..self.max_level+1).rev() { for i in (self.min_level..self.max_level+1).rev() {
let i_d = self.base.powf(i as f64); let i_d = self.base.powf(F::from(i).unwrap());
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i); let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k); let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &(d_p_q + i_d)).collect(); qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &(d_p_q + i_d)).collect();
@@ -286,7 +286,7 @@ mod tests {
let distance = |a: &i32, b: &i32| -> f64 { let distance = |a: &i32, b: &i32| -> f64 {
(a - b).abs() as f64 (a - b).abs() as f64
}; };
let mut tree = CoverTree::<i32>::new(data, &distance); let mut tree = CoverTree::<i32, f64>::new(data, &distance);
for d in vec!(10, 11, 12, 13, 14, 15, 16, 17, 18, 19) { for d in vec!(10, 11, 12, 13, 14, 15, 16, 17, 18, 19) {
tree.insert(d); tree.insert(d);
} }
@@ -309,7 +309,7 @@ mod tests {
let distance = |a: &i32, b: &i32| -> f64 { let distance = |a: &i32, b: &i32| -> f64 {
(a - b).abs() as f64 (a - b).abs() as f64
}; };
let tree = CoverTree::<i32>::new(data, &distance); let tree = CoverTree::<i32, f64>::new(data, &distance);
tree.check_invariant(CoverTree::nesting_invariant); tree.check_invariant(CoverTree::nesting_invariant);
tree.check_invariant(CoverTree::covering_tree); tree.check_invariant(CoverTree::covering_tree);
tree.check_invariant(CoverTree::separation); tree.check_invariant(CoverTree::separation);
+11 -11
View File
@@ -3,19 +3,19 @@ use crate::algorithm::sort::heap_select::HeapSelect;
use std::cmp::{Ordering, PartialOrd}; use std::cmp::{Ordering, PartialOrd};
use num_traits::Float; use num_traits::Float;
pub struct LinearKNNSearch<'a, T> { pub struct LinearKNNSearch<'a, T, F: Float> {
distance: Box<dyn Fn(&T, &T) -> f64 + 'a>, distance: Box<dyn Fn(&T, &T) -> F + 'a>,
data: Vec<T> data: Vec<T>
} }
impl<'a, T> KNNAlgorithm<T> for LinearKNNSearch<'a, T> impl<'a, T, F: Float> KNNAlgorithm<T> for LinearKNNSearch<'a, T, F>
{ {
fn find(&self, from: &T, k: usize) -> Vec<usize> { fn find(&self, from: &T, k: usize) -> Vec<usize> {
if k < 1 || k > self.data.len() { if k < 1 || k > self.data.len() {
panic!("k should be >= 1 and <= length(data)"); panic!("k should be >= 1 and <= length(data)");
} }
let mut heap = HeapSelect::<KNNPoint>::with_capacity(k); let mut heap = HeapSelect::<KNNPoint<F>>::with_capacity(k);
for _ in 0..k { for _ in 0..k {
heap.add(KNNPoint{ heap.add(KNNPoint{
@@ -41,8 +41,8 @@ impl<'a, T> KNNAlgorithm<T> for LinearKNNSearch<'a, T>
} }
} }
impl<'a, T> LinearKNNSearch<'a, T> { impl<'a, T, F: Float> LinearKNNSearch<'a, T, F> {
pub fn new(data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> f64) -> LinearKNNSearch<T>{ pub fn new(data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> F) -> LinearKNNSearch<T, F>{
LinearKNNSearch{ LinearKNNSearch{
data: data, data: data,
distance: Box::new(distance) distance: Box::new(distance)
@@ -51,24 +51,24 @@ impl<'a, T> LinearKNNSearch<'a, T> {
} }
#[derive(Debug)] #[derive(Debug)]
struct KNNPoint { struct KNNPoint<F: Float> {
distance: f64, distance: F,
index: Option<usize> index: Option<usize>
} }
impl PartialOrd for KNNPoint { impl<F: Float> PartialOrd for KNNPoint<F> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.distance.partial_cmp(&other.distance) self.distance.partial_cmp(&other.distance)
} }
} }
impl PartialEq for KNNPoint { impl<F: Float> PartialEq for KNNPoint<F> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.distance == other.distance self.distance == other.distance
} }
} }
impl Eq for KNNPoint {} impl<F: Float> Eq for KNNPoint<F> {}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
+3 -1
View File
@@ -1,8 +1,10 @@
use num_traits::Float;
pub trait QuickArgSort { pub trait QuickArgSort {
fn quick_argsort(&mut self) -> Vec<usize>; fn quick_argsort(&mut self) -> Vec<usize>;
} }
impl QuickArgSort for Vec<f64> { impl<T: Float> QuickArgSort for Vec<T> {
fn quick_argsort(&mut self) -> Vec<usize> { fn quick_argsort(&mut self) -> Vec<usize> {
let stack_size = 64; let stack_size = 64;
+26 -20
View File
@@ -1,18 +1,21 @@
extern crate rand; extern crate rand;
use rand::Rng; use rand::Rng;
use std::iter::Sum;
use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::distance::euclidian; use crate::math::distance::euclidian;
use crate::algorithm::neighbour::bbd_tree::BBDTree; use crate::algorithm::neighbour::bbd_tree::BBDTree;
#[derive(Debug)] #[derive(Debug)]
pub struct KMeans { pub struct KMeans<T: FloatExt> {
k: usize, k: usize,
y: Vec<usize>, y: Vec<usize>,
size: Vec<usize>, size: Vec<usize>,
distortion: f64, distortion: T,
centroids: Vec<Vec<f64>> centroids: Vec<Vec<T>>
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -28,8 +31,8 @@ impl Default for KMeansParameters {
} }
} }
impl KMeans{ impl<T: FloatExt + Debug + Sum> KMeans<T>{
pub fn new<M: Matrix>(data: &M, k: usize, parameters: KMeansParameters) -> KMeans { pub fn new<M: Matrix<T>>(data: &M, k: usize, parameters: KMeansParameters) -> KMeans<T> {
let bbd = BBDTree::new(data); let bbd = BBDTree::new(data);
@@ -43,10 +46,10 @@ impl KMeans{
let (n, d) = data.shape(); let (n, d) = data.shape();
let mut distortion = std::f64::MAX; let mut distortion = T::max_value();
let mut y = KMeans::kmeans_plus_plus(data, k); let mut y = KMeans::kmeans_plus_plus(data, k);
let mut size = vec![0; k]; let mut size = vec![0; k];
let mut centroids = vec![vec![0f64; d]; k]; let mut centroids = vec![vec![T::zero(); d]; k];
for i in 0..n { for i in 0..n {
size[y[i]] += 1; size[y[i]] += 1;
@@ -54,23 +57,23 @@ impl KMeans{
for i in 0..n { for i in 0..n {
for j in 0..d { for j in 0..d {
centroids[y[i]][j] += data.get(i, j); centroids[y[i]][j] = centroids[y[i]][j] + data.get(i, j);
} }
} }
for i in 0..k { for i in 0..k {
for j in 0..d { for j in 0..d {
centroids[i][j] /= size[i] as f64; centroids[i][j] = centroids[i][j] / T::from(size[i]).unwrap();
} }
} }
let mut sums = vec![vec![0f64; d]; k]; let mut sums = vec![vec![T::zero(); d]; k];
for _ in 1..= parameters.max_iter { for _ in 1..= parameters.max_iter {
let dist = bbd.clustering(&centroids, &mut sums, &mut size, &mut y); let dist = bbd.clustering(&centroids, &mut sums, &mut size, &mut y);
for i in 0..k { for i in 0..k {
if size[i] > 0 { if size[i] > 0 {
for j in 0..d { for j in 0..d {
centroids[i][j] = sums[i][j] as f64 / size[i] as f64; centroids[i][j] = T::from(sums[i][j]).unwrap() / T::from(size[i]).unwrap();
} }
} }
} }
@@ -92,13 +95,13 @@ impl KMeans{
} }
} }
pub fn predict<M: Matrix>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let (n, _) = x.shape(); let (n, _) = x.shape();
let mut result = M::zeros(1, n); let mut result = M::zeros(1, n);
for i in 0..n { for i in 0..n {
let mut min_dist = std::f64::MAX; let mut min_dist = T::max_value();
let mut best_cluster = 0; let mut best_cluster = 0;
for j in 0..self.k { for j in 0..self.k {
@@ -108,19 +111,19 @@ impl KMeans{
best_cluster = j; best_cluster = j;
} }
} }
result.set(0, i, best_cluster as f64); result.set(0, i, T::from(best_cluster).unwrap());
} }
result.to_row_vector() result.to_row_vector()
} }
fn kmeans_plus_plus<M: Matrix>(data: &M, k: usize) -> Vec<usize>{ fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize>{
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let (n, _) = data.shape(); let (n, _) = data.shape();
let mut y = vec![0; n]; let mut y = vec![0; n];
let mut centroid = data.get_row_as_vec(rng.gen_range(0, n)); let mut centroid = data.get_row_as_vec(rng.gen_range(0, n));
let mut d = vec![std::f64::MAX; n]; let mut d = vec![T::max_value(); n];
// pick the next center // pick the next center
for j in 1..k { for j in 1..k {
@@ -136,12 +139,15 @@ impl KMeans{
} }
} }
let sum: f64 = d.iter().sum(); let mut sum: T = T::zero();
let cutoff = rng.gen::<f64>() * sum; for i in d.iter(){
let mut cost = 0f64; sum = sum + *i;
}
let cutoff = T::from(rng.gen::<f64>()).unwrap() * sum;
let mut cost = T::zero();
let index = 0; let index = 0;
for index in 0..n { for index in 0..n {
cost += d[index]; cost = cost + d[index];
if cost >= cutoff { if cost >= cutoff {
break; break;
} }
+14 -12
View File
@@ -1,12 +1,14 @@
use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::{Matrix}; use crate::linalg::{Matrix};
#[derive(Debug)] #[derive(Debug)]
pub struct PCA<M: Matrix> { pub struct PCA<T: FloatExt + Debug, M: Matrix<T>> {
eigenvectors: M, eigenvectors: M,
eigenvalues: Vec<f64>, eigenvalues: Vec<T>,
projection: M, projection: M,
mu: Vec<f64>, mu: Vec<T>,
pmu: Vec<f64> pmu: Vec<T>
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -22,9 +24,9 @@ impl Default for PCAParameters {
} }
} }
impl<M: Matrix> PCA<M> { impl<T: FloatExt + Debug, M: Matrix<T>> PCA<T, M> {
pub fn new(data: &M, n_components: usize, parameters: PCAParameters) -> PCA<M> { pub fn new(data: &M, n_components: usize, parameters: PCAParameters) -> PCA<T, M> {
let (m, n) = data.shape(); let (m, n) = data.shape();
@@ -46,7 +48,7 @@ impl<M: Matrix> PCA<M> {
let svd = x.svd(); let svd = x.svd();
eigenvalues = svd.s; eigenvalues = svd.s;
for i in 0..eigenvalues.len() { for i in 0..eigenvalues.len() {
eigenvalues[i] *= eigenvalues[i]; eigenvalues[i] = eigenvalues[i] * eigenvalues[i];
} }
eigenvectors = svd.V; eigenvectors = svd.V;
@@ -63,13 +65,13 @@ impl<M: Matrix> PCA<M> {
for i in 0..n { for i in 0..n {
for j in 0..=i { for j in 0..=i {
cov.div_element_mut(i, j, m as f64); cov.div_element_mut(i, j, T::from(m).unwrap());
cov.set(j, i, cov.get(i, j)); cov.set(j, i, cov.get(i, j));
} }
} }
if parameters.use_correlation_matrix { if parameters.use_correlation_matrix {
let mut sd = vec![0f64; n]; let mut sd = vec![T::zero(); n];
for i in 0..n { for i in 0..n {
sd[i] = cov.get(i, i).sqrt(); sd[i] = cov.get(i, i).sqrt();
} }
@@ -110,10 +112,10 @@ impl<M: Matrix> PCA<M> {
} }
} }
let mut pmu = vec![0f64; n_components]; let mut pmu = vec![T::zero(); n_components];
for k in 0..n { for k in 0..n {
for i in 0..n_components { for i in 0..n_components {
pmu[i] += projection.get(i, k) * mu[k]; pmu[i] = pmu[i] + projection.get(i, k) * mu[k];
} }
} }
@@ -149,7 +151,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
fn us_arrests_data() -> DenseMatrix { fn us_arrests_data() -> DenseMatrix<f64> {
DenseMatrix::from_array(&[ DenseMatrix::from_array(&[
&[13.2, 236.0, 58.0, 21.2], &[13.2, 236.0, 58.0, 21.2],
&[10.0, 263.0, 48.0, 44.5], &[10.0, 263.0, 48.0, 44.5],
+24 -13
View File
@@ -1,7 +1,11 @@
extern crate rand; extern crate rand;
use rand::Rng;
use std::default::Default; use std::default::Default;
use std::fmt::Debug;
use rand::Rng;
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max}; use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max};
@@ -16,10 +20,10 @@ pub struct RandomForestClassifierParameters {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct RandomForestClassifier { pub struct RandomForestClassifier<T: FloatExt> {
parameters: RandomForestClassifierParameters, parameters: RandomForestClassifierParameters,
trees: Vec<DecisionTreeClassifier>, trees: Vec<DecisionTreeClassifier<T>>,
classes: Vec<f64> classes: Vec<T>
} }
impl Default for RandomForestClassifierParameters { impl Default for RandomForestClassifierParameters {
@@ -35,9 +39,9 @@ impl Default for RandomForestClassifierParameters {
} }
} }
impl RandomForestClassifier { impl<T: FloatExt + Debug> RandomForestClassifier<T> {
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: RandomForestClassifierParameters) -> RandomForestClassifier { pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: RandomForestClassifierParameters) -> RandomForestClassifier<T> {
let (_, num_attributes) = x.shape(); let (_, num_attributes) = x.shape();
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape(); let (_, y_ncols) = y_m.shape();
@@ -49,14 +53,14 @@ impl RandomForestClassifier {
yi[i] = classes.iter().position(|c| yc == *c).unwrap(); yi[i] = classes.iter().position(|c| yc == *c).unwrap();
} }
let mtry = parameters.mtry.unwrap_or((num_attributes as f64).sqrt().floor() as usize); let mtry = parameters.mtry.unwrap_or((T::from(num_attributes).unwrap()).sqrt().floor().to_usize().unwrap());
let classes = y_m.unique(); let classes = y_m.unique();
let k = classes.len(); let k = classes.len();
let mut trees: Vec<DecisionTreeClassifier> = Vec::new(); let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
for _ in 0..parameters.n_trees { for _ in 0..parameters.n_trees {
let samples = RandomForestClassifier::sample_with_replacement(&yi, k); let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k);
let params = DecisionTreeClassifierParameters{ let params = DecisionTreeClassifierParameters{
criterion: parameters.criterion.clone(), criterion: parameters.criterion.clone(),
max_depth: parameters.max_depth, max_depth: parameters.max_depth,
@@ -74,7 +78,7 @@ impl RandomForestClassifier {
} }
} }
pub fn predict<M: Matrix>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let mut result = M::zeros(1, x.shape().0); let mut result = M::zeros(1, x.shape().0);
let (n, _) = x.shape(); let (n, _) = x.shape();
@@ -86,7 +90,7 @@ impl RandomForestClassifier {
result.to_row_vector() result.to_row_vector()
} }
fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> usize { fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
let mut result = vec![0; self.classes.len()]; let mut result = vec![0; self.classes.len()];
for tree in self.trees.iter() { for tree in self.trees.iter() {
@@ -154,9 +158,16 @@ mod tests {
&[5.2, 2.7, 3.9, 1.4]]); &[5.2, 2.7, 3.9, 1.4]]);
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]; let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
RandomForestClassifier::fit(&x, &y, Default::default()); let classifier = RandomForestClassifier::fit(&x, &y, RandomForestClassifierParameters{
criterion: SplitCriterion::Gini,
max_depth: None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 1000,
mtry: Option::None
});
assert_eq!(y, RandomForestClassifier::fit(&x, &y, Default::default()).predict(&x)); assert_eq!(y, classifier.predict(&x));
} }
+17 -13
View File
@@ -1,7 +1,11 @@
extern crate rand; extern crate rand;
use rand::Rng;
use std::default::Default; use std::default::Default;
use std::fmt::Debug;
use rand::Rng;
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::tree::decision_tree_regressor::{DecisionTreeRegressor, DecisionTreeRegressorParameters}; use crate::tree::decision_tree_regressor::{DecisionTreeRegressor, DecisionTreeRegressorParameters};
@@ -15,9 +19,9 @@ pub struct RandomForestRegressorParameters {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct RandomForestRegressor { pub struct RandomForestRegressor<T: FloatExt> {
parameters: RandomForestRegressorParameters, parameters: RandomForestRegressorParameters,
trees: Vec<DecisionTreeRegressor> trees: Vec<DecisionTreeRegressor<T>>
} }
impl Default for RandomForestRegressorParameters { impl Default for RandomForestRegressorParameters {
@@ -32,17 +36,17 @@ impl Default for RandomForestRegressorParameters {
} }
} }
impl RandomForestRegressor { impl<T: FloatExt + Debug> RandomForestRegressor<T> {
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: RandomForestRegressorParameters) -> RandomForestRegressor { pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: RandomForestRegressorParameters) -> RandomForestRegressor<T> {
let (n_rows, num_attributes) = x.shape(); let (n_rows, num_attributes) = x.shape();
let mtry = parameters.mtry.unwrap_or((num_attributes as f64).sqrt().floor() as usize); let mtry = parameters.mtry.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
let mut trees: Vec<DecisionTreeRegressor> = Vec::new(); let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
for _ in 0..parameters.n_trees { for _ in 0..parameters.n_trees {
let samples = RandomForestRegressor::sample_with_replacement(n_rows); let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows);
let params = DecisionTreeRegressorParameters{ let params = DecisionTreeRegressorParameters{
max_depth: parameters.max_depth, max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf, min_samples_leaf: parameters.min_samples_leaf,
@@ -58,7 +62,7 @@ impl RandomForestRegressor {
} }
} }
pub fn predict<M: Matrix>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let mut result = M::zeros(1, x.shape().0); let mut result = M::zeros(1, x.shape().0);
let (n, _) = x.shape(); let (n, _) = x.shape();
@@ -70,17 +74,17 @@ impl RandomForestRegressor {
result.to_row_vector() result.to_row_vector()
} }
fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> f64 { fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
let n_trees = self.trees.len(); let n_trees = self.trees.len();
let mut result = 0f64; let mut result = T::zero();
for tree in self.trees.iter() { for tree in self.trees.iter() {
result += tree.predict_for_row(x, row); result = result + tree.predict_for_row(x, row);
} }
result / n_trees as f64 result / T::from(n_trees).unwrap()
} }
@@ -123,7 +127,7 @@ mod tests {
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); &[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
let y = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]; let y = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
let expected_y = vec![85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.]; let expected_y: Vec<f64> = vec![85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.];
let y_hat = RandomForestRegressor::fit(&x, &y, let y_hat = RandomForestRegressor::fit(&x, &y,
RandomForestRegressorParameters{max_depth: None, RandomForestRegressorParameters{max_depth: None,
+156 -150
View File
@@ -2,16 +2,18 @@
use num::complex::Complex; use num::complex::Complex;
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
use crate::math::num::FloatExt;
use std::fmt::Debug;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct EVD<M: BaseMatrix> { pub struct EVD<T: FloatExt + Debug, M: BaseMatrix<T>> {
pub d: Vec<f64>, pub d: Vec<T>,
pub e: Vec<f64>, pub e: Vec<T>,
pub V: M pub V: M
} }
impl<M: BaseMatrix> EVD<M> { impl<T: FloatExt + Debug, M: BaseMatrix<T>> EVD<T, M> {
pub fn new(V: M, d: Vec<f64>, e: Vec<f64>) -> EVD<M> { pub fn new(V: M, d: Vec<T>, e: Vec<T>) -> EVD<T, M> {
EVD { EVD {
d: d, d: d,
e: e, e: e,
@@ -20,21 +22,21 @@ impl<M: BaseMatrix> EVD<M> {
} }
} }
pub trait EVDDecomposableMatrix: BaseMatrix { pub trait EVDDecomposableMatrix<T: FloatExt + Debug>: BaseMatrix<T> {
fn evd(&self, symmetric: bool) -> EVD<Self>{ fn evd(&self, symmetric: bool) -> EVD<T, Self>{
self.clone().evd_mut(symmetric) self.clone().evd_mut(symmetric)
} }
fn evd_mut(mut self, symmetric: bool) -> EVD<Self>{ fn evd_mut(mut self, symmetric: bool) -> EVD<T, Self>{
let(nrows, ncols) = self.shape(); let(nrows, ncols) = self.shape();
if ncols != nrows { if ncols != nrows {
panic!("Matrix is not square: {} x {}", nrows, ncols); panic!("Matrix is not square: {} x {}", nrows, ncols);
} }
let n = nrows; let n = nrows;
let mut d = vec![0f64; n]; let mut d = vec![T::zero(); n];
let mut e = vec![0f64; n]; let mut e = vec![T::zero(); n];
let mut V; let mut V;
if symmetric { if symmetric {
@@ -66,7 +68,7 @@ pub trait EVDDecomposableMatrix: BaseMatrix {
} }
} }
fn tred2<M: BaseMatrix>(V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) { fn tred2<T: FloatExt + Debug, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
let (n, _) = V.shape(); let (n, _) = V.shape();
for i in 0..n { for i in 0..n {
@@ -76,34 +78,34 @@ fn tred2<M: BaseMatrix>(V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) {
// Householder reduction to tridiagonal form. // Householder reduction to tridiagonal form.
for i in (1..n).rev() { for i in (1..n).rev() {
// Scale to avoid under/overflow. // Scale to avoid under/overflow.
let mut scale = 0f64; let mut scale = T::zero();
let mut h = 0f64; let mut h = T::zero();
for k in 0..i { for k in 0..i {
scale = scale + d[k].abs(); scale = scale + d[k].abs();
} }
if scale == 0f64 { if scale == T::zero() {
e[i] = d[i - 1]; e[i] = d[i - 1];
for j in 0..i { for j in 0..i {
d[j] = V.get(i - 1, j); d[j] = V.get(i - 1, j);
V.set(i, j, 0.0); V.set(i, j, T::zero());
V.set(j, i, 0.0); V.set(j, i, T::zero());
} }
} else { } else {
// Generate Householder vector. // Generate Householder vector.
for k in 0..i { for k in 0..i {
d[k] /= scale; d[k] = d[k] / scale;
h += d[k] * d[k]; h = h + d[k] * d[k];
} }
let mut f = d[i - 1]; let mut f = d[i - 1];
let mut g = h.sqrt(); let mut g = h.sqrt();
if f > 0f64 { if f > T::zero() {
g = -g; g = -g;
} }
e[i] = scale * g; e[i] = scale * g;
h = h - f * g; h = h - f * g;
d[i - 1] = f - g; d[i - 1] = f - g;
for j in 0..i { for j in 0..i {
e[j] = 0f64; e[j] = T::zero();
} }
// Apply similarity transformation to remaining columns. // Apply similarity transformation to remaining columns.
@@ -112,19 +114,19 @@ fn tred2<M: BaseMatrix>(V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) {
V.set(j, i, f); V.set(j, i, f);
g = e[j] + V.get(j, j) * f; g = e[j] + V.get(j, j) * f;
for k in j + 1..=i - 1 { for k in j + 1..=i - 1 {
g += V.get(k, j) * d[k]; g = g + V.get(k, j) * d[k];
e[k] += V.get(k, j) * f; e[k] = e[k] + V.get(k, j) * f;
} }
e[j] = g; e[j] = g;
} }
f = 0.0; f = T::zero();
for j in 0..i { for j in 0..i {
e[j] /= h; e[j] = e[j] / h;
f += e[j] * d[j]; f = f + e[j] * d[j];
} }
let hh = f / (h + h); let hh = f / (h + h);
for j in 0..i { for j in 0..i {
e[j] -= hh * d[j]; e[j] = e[j] - hh * d[j];
} }
for j in 0..i { for j in 0..i {
f = d[j]; f = d[j];
@@ -133,7 +135,7 @@ fn tred2<M: BaseMatrix>(V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) {
V.sub_element_mut(k, j, f * e[k] + g * d[k]); V.sub_element_mut(k, j, f * e[k] + g * d[k]);
} }
d[j] = V.get(i - 1, j); d[j] = V.get(i - 1, j);
V.set(i, j, 0.0); V.set(i, j, T::zero());
} }
} }
d[i] = h; d[i] = h;
@@ -142,16 +144,16 @@ fn tred2<M: BaseMatrix>(V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) {
// Accumulate transformations. // Accumulate transformations.
for i in 0..n-1 { for i in 0..n-1 {
V.set(n - 1, i, V.get(i, i)); V.set(n - 1, i, V.get(i, i));
V.set(i, i, 1.0); V.set(i, i, T::one());
let h = d[i + 1]; let h = d[i + 1];
if h != 0f64 { if h != T::zero() {
for k in 0..=i { for k in 0..=i {
d[k] = V.get(k, i + 1) / h; d[k] = V.get(k, i + 1) / h;
} }
for j in 0..=i { for j in 0..=i {
let mut g = 0f64; let mut g = T::zero();
for k in 0..=i { for k in 0..=i {
g += V.get(k, i + 1) * V.get(k, j); g = g + V.get(k, i + 1) * V.get(k, j);
} }
for k in 0..=i { for k in 0..=i {
V.sub_element_mut(k, j, g * d[k]); V.sub_element_mut(k, j, g * d[k]);
@@ -159,35 +161,35 @@ fn tred2<M: BaseMatrix>(V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) {
} }
} }
for k in 0..=i { for k in 0..=i {
V.set(k, i + 1, 0.0); V.set(k, i + 1, T::zero());
} }
} }
for j in 0..n { for j in 0..n {
d[j] = V.get(n - 1, j); d[j] = V.get(n - 1, j);
V.set(n - 1, j, 0.0); V.set(n - 1, j, T::zero());
} }
V.set(n - 1, n - 1, 1.0); V.set(n - 1, n - 1, T::one());
e[0] = 0.0; e[0] = T::zero();
} }
fn tql2<M: BaseMatrix>(V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) { fn tql2<T: FloatExt + Debug, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
let (n, _) = V.shape(); let (n, _) = V.shape();
for i in 1..n { for i in 1..n {
e[i - 1] = e[i]; e[i - 1] = e[i];
} }
e[n - 1] = 0f64; e[n - 1] = T::zero();
let mut f = 0f64; let mut f = T::zero();
let mut tst1 = 0f64; let mut tst1 = T::zero();
for l in 0..n { for l in 0..n {
// Find small subdiagonal element // Find small subdiagonal element
tst1 = f64::max(tst1, d[l].abs() + e[l].abs()); tst1 = T::max(tst1, d[l].abs() + e[l].abs());
let mut m = l; let mut m = l;
loop { loop {
if m < n { if m < n {
if e[m].abs() <= tst1 * std::f64::EPSILON { if e[m].abs() <= tst1 * T::epsilon() {
break; break;
} }
m += 1; m += 1;
@@ -208,9 +210,9 @@ fn tql2<M: BaseMatrix>(V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) {
// Compute implicit shift // Compute implicit shift
let mut g = d[l]; let mut g = d[l];
let mut p = (d[l + 1] - g) / (2.0 * e[l]); let mut p = (d[l + 1] - g) / (T::two() * e[l]);
let mut r = p.hypot(1.0); let mut r = p.hypot(T::one());
if p < 0f64 { if p < T::zero() {
r = -r; r = -r;
} }
d[l] = e[l] / (p + r); d[l] = e[l] / (p + r);
@@ -218,18 +220,18 @@ fn tql2<M: BaseMatrix>(V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) {
let dl1 = d[l + 1]; let dl1 = d[l + 1];
let mut h = g - d[l]; let mut h = g - d[l];
for i in l+2..n { for i in l+2..n {
d[i] -= h; d[i] = d[i] - h;
} }
f = f + h; f = f + h;
// Implicit QL transformation. // Implicit QL transformation.
p = d[m]; p = d[m];
let mut c = 1.0; let mut c = T::one();
let mut c2 = c; let mut c2 = c;
let mut c3 = c; let mut c3 = c;
let el1 = e[l + 1]; let el1 = e[l + 1];
let mut s = 0.0; let mut s = T::zero();
let mut s2 = 0.0; let mut s2 = T::zero();
for i in (l..m).rev() { for i in (l..m).rev() {
c3 = c2; c3 = c2;
c2 = c; c2 = c;
@@ -255,13 +257,13 @@ fn tql2<M: BaseMatrix>(V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) {
d[l] = c * p; d[l] = c * p;
// Check for convergence. // Check for convergence.
if e[l].abs() <= tst1 * std::f64::EPSILON { if e[l].abs() <= tst1 * T::epsilon() {
break; break;
} }
} }
} }
d[l] = d[l] + f; d[l] = d[l] + f;
e[l] = 0f64; e[l] = T::zero();
} }
// Sort eigenvalues and corresponding vectors. // Sort eigenvalues and corresponding vectors.
@@ -286,43 +288,45 @@ fn tql2<M: BaseMatrix>(V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) {
} }
} }
fn balance<M: BaseMatrix>(A: &mut M) -> Vec<f64> { fn balance<T: FloatExt + Debug, M: BaseMatrix<T>>(A: &mut M) -> Vec<T> {
let radix = 2f64; let radix = T::two();
let sqrdx = radix * radix; let sqrdx = radix * radix;
let (n, _) = A.shape(); let (n, _) = A.shape();
let mut scale = vec![1f64; n]; let mut scale = vec![T::one(); n];
let t = T::from(0.95).unwrap();
let mut done = false; let mut done = false;
while !done { while !done {
done = true; done = true;
for i in 0..n { for i in 0..n {
let mut r = 0f64; let mut r = T::zero();
let mut c = 0f64; let mut c = T::zero();
for j in 0..n { for j in 0..n {
if j != i { if j != i {
c += A.get(j, i).abs(); c = c + A.get(j, i).abs();
r += A.get(i, j).abs(); r = r + A.get(i, j).abs();
} }
} }
if c != 0f64 && r != 0f64 { if c != T::zero() && r != T::zero() {
let mut g = r / radix; let mut g = r / radix;
let mut f = 1.0; let mut f = T::one();
let s = c + r; let s = c + r;
while c < g { while c < g {
f *= radix; f = f * radix;
c *= sqrdx; c = c * sqrdx;
} }
g = r * radix; g = r * radix;
while c > g { while c > g {
f /= radix; f = f / radix;
c /= sqrdx; c = c / sqrdx;
} }
if (c + r) / f < 0.95 * s { if (c + r) / f < t * s {
done = false; done = false;
g = 1.0 / f; g = T::one() / f;
scale[i] *= f; scale[i] = scale[i] * f;
for j in 0..n { for j in 0..n {
A.mul_element_mut(i, j, g); A.mul_element_mut(i, j, g);
} }
@@ -337,12 +341,12 @@ fn balance<M: BaseMatrix>(A: &mut M) -> Vec<f64> {
return scale; return scale;
} }
fn elmhes<M: BaseMatrix>(A: &mut M) -> Vec<usize> { fn elmhes<T: FloatExt + Debug, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
let (n, _) = A.shape(); let (n, _) = A.shape();
let mut perm = vec![0; n]; let mut perm = vec![0; n];
for m in 1..n-1 { for m in 1..n-1 {
let mut x = 0f64; let mut x = T::zero();
let mut i = m; let mut i = m;
for j in m..n { for j in m..n {
if A.get(j, m - 1).abs() > x.abs() { if A.get(j, m - 1).abs() > x.abs() {
@@ -363,11 +367,11 @@ fn elmhes<M: BaseMatrix>(A: &mut M) -> Vec<usize> {
A.set(j, m, swap); A.set(j, m, swap);
} }
} }
if x != 0f64 { if x != T::zero() {
for i in (m + 1)..n { for i in (m + 1)..n {
let mut y = A.get(i, m - 1); let mut y = A.get(i, m - 1);
if y != 0f64 { if y != T::zero() {
y /= x; y = y / x;
A.set(i, m - 1, y); A.set(i, m - 1, y);
for j in m..n { for j in m..n {
A.sub_element_mut(i, j, y * A.get(m, j)); A.sub_element_mut(i, j, y * A.get(m, j));
@@ -383,7 +387,7 @@ fn elmhes<M: BaseMatrix>(A: &mut M) -> Vec<usize> {
return perm; return perm;
} }
fn eltran<M: BaseMatrix>(A: &M, V: &mut M, perm: &Vec<usize>) { fn eltran<T: FloatExt + Debug, M: BaseMatrix<T>>(A: &M, V: &mut M, perm: &Vec<usize>) {
let (n, _) = A.shape(); let (n, _) = A.shape();
for mp in (1..n - 1).rev() { for mp in (1..n - 1).rev() {
for k in mp + 1..n { for k in mp + 1..n {
@@ -393,41 +397,41 @@ fn eltran<M: BaseMatrix>(A: &M, V: &mut M, perm: &Vec<usize>) {
if i != mp { if i != mp {
for j in mp..n { for j in mp..n {
V.set(mp, j, V.get(i, j)); V.set(mp, j, V.get(i, j));
V.set(i, j, 0.0); V.set(i, j, T::zero());
} }
V.set(i, mp, 1.0); V.set(i, mp, T::one());
} }
} }
} }
fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) { fn hqr2<T: FloatExt + Debug, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
let (n, _) = A.shape(); let (n, _) = A.shape();
let mut z = 0f64; let mut z = T::zero();
let mut s = 0f64; let mut s = T::zero();
let mut r = 0f64; let mut r = T::zero();
let mut q = 0f64; let mut q = T::zero();
let mut p = 0f64; let mut p = T::zero();
let mut anorm = 0f64; let mut anorm = T::zero();
for i in 0..n { for i in 0..n {
for j in i32::max(i as i32 - 1, 0)..n as i32 { for j in i32::max(i as i32 - 1, 0)..n as i32 {
anorm += A.get(i, j as usize).abs(); anorm = anorm + A.get(i, j as usize).abs();
} }
} }
let mut nn = n - 1; let mut nn = n - 1;
let mut t = 0.0; let mut t = T::zero();
'outer: loop { 'outer: loop {
let mut its = 0; let mut its = 0;
loop { loop {
let mut l = nn; let mut l = nn;
while l > 0 { while l > 0 {
s = A.get(l - 1, l - 1).abs() + A.get(l, l).abs(); s = A.get(l - 1, l - 1).abs() + A.get(l, l).abs();
if s == 0.0 { if s == T::zero() {
s = anorm; s = anorm;
} }
if A.get(l, l - 1).abs() <= std::f64::EPSILON * s { if A.get(l, l - 1).abs() <= T::epsilon() * s {
A.set(l, l - 1, 0.0); A.set(l, l - 1, T::zero());
break; break;
} }
l -= 1; l -= 1;
@@ -445,17 +449,17 @@ fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>)
let mut y = A.get(nn - 1, nn - 1); let mut y = A.get(nn - 1, nn - 1);
let mut w = A.get(nn, nn - 1) * A.get(nn - 1, nn); let mut w = A.get(nn, nn - 1) * A.get(nn - 1, nn);
if l == nn - 1 { if l == nn - 1 {
p = 0.5 * (y - x); p = T::half() * (y - x);
q = p * p + w; q = p * p + w;
z = q.abs().sqrt(); z = q.abs().sqrt();
x += t; x = x + t;
A.set(nn, nn, x ); A.set(nn, nn, x );
A.set(nn - 1, nn - 1, y + t); A.set(nn - 1, nn - 1, y + t);
if q >= 0.0 { if q >= T::zero() {
z = p + z.copysign(p); z = p + z.copysign(p);
d[nn - 1] = x + z; d[nn - 1] = x + z;
d[nn] = x + z; d[nn] = x + z;
if z != 0.0 { if z != T::zero() {
d[nn] = x - w / z; d[nn] = x - w / z;
} }
x = A.get(nn, nn - 1); x = A.get(nn, nn - 1);
@@ -463,8 +467,8 @@ fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>)
p = x / s; p = x / s;
q = z / s; q = z / s;
r = (p * p + q * q).sqrt(); r = (p * p + q * q).sqrt();
p /= r; p = p / r;
q /= r; q = q / r;
for j in nn-1..n { for j in nn-1..n {
z = A.get(nn - 1, j); z = A.get(nn - 1, j);
A.set(nn - 1, j, q * z + p * A.get(nn, j)); A.set(nn - 1, j, q * z + p * A.get(nn, j));
@@ -497,14 +501,14 @@ fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>)
panic!("Too many iterations in hqr"); panic!("Too many iterations in hqr");
} }
if its == 10 || its == 20 { if its == 10 || its == 20 {
t += x; t = t + x;
for i in 0..nn+1 { for i in 0..nn+1 {
A.sub_element_mut(i, i, x); A.sub_element_mut(i, i, x);
} }
s = A.get(nn, nn - 1).abs() + A.get(nn - 1, nn - 2).abs(); s = A.get(nn, nn - 1).abs() + A.get(nn - 1, nn - 2).abs();
y = 0.75 * s; y = T::from(0.75).unwrap() * s;
x = 0.75 * s; x = T::from(0.75).unwrap() * s;
w = -0.4375 * s * s; w = T::from(-0.4375).unwrap() * s * s;
} }
its += 1; its += 1;
let mut m = nn - 2; let mut m = nn - 2;
@@ -516,42 +520,42 @@ fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>)
q = A.get(m + 1, m + 1) - z - r - s; q = A.get(m + 1, m + 1) - z - r - s;
r = A.get(m + 2, m + 1); r = A.get(m + 2, m + 1);
s = p.abs() + q.abs() + r.abs(); s = p.abs() + q.abs() + r.abs();
p /= s; p = p / s;
q /= s; q = q / s;
r /= s; r = r / s;
if m == l { if m == l {
break; break;
} }
let u = A.get(m, m - 1).abs() * (q.abs() + r.abs()); let u = A.get(m, m - 1).abs() * (q.abs() + r.abs());
let v = p.abs() * (A.get(m - 1, m - 1).abs() + z.abs() + A.get(m + 1, m + 1).abs()); let v = p.abs() * (A.get(m - 1, m - 1).abs() + z.abs() + A.get(m + 1, m + 1).abs());
if u <= std::f64::EPSILON * v { if u <= T::epsilon() * v {
break; break;
} }
m -= 1; m -= 1;
} }
for i in m..nn-1 { for i in m..nn-1 {
A.set(i + 2, i , 0.0); A.set(i + 2, i , T::zero());
if i != m { if i != m {
A.set(i + 2, i - 1, 0.0); A.set(i + 2, i - 1, T::zero());
} }
} }
for k in m..nn { for k in m..nn {
if k != m { if k != m {
p = A.get(k, k - 1); p = A.get(k, k - 1);
q = A.get(k + 1, k - 1); q = A.get(k + 1, k - 1);
r = 0.0; r = T::zero();
if k + 1 != nn { if k + 1 != nn {
r = A.get(k + 2, k - 1); r = A.get(k + 2, k - 1);
} }
x = p.abs() + q.abs() +r.abs(); x = p.abs() + q.abs() +r.abs();
if x != 0.0 { if x != T::zero() {
p /= x; p = p / x;
q /= x; q = q / x;
r /= x; r = r / x;
} }
} }
let s = (p * p + q * q + r * r).sqrt().copysign(p); let s = (p * p + q * q + r * r).sqrt().copysign(p);
if s != 0.0 { if s != T::zero() {
if k == m { if k == m {
if l != m { if l != m {
A.set(k, k - 1, -A.get(k, k - 1)); A.set(k, k - 1, -A.get(k, k - 1));
@@ -559,16 +563,16 @@ fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>)
} else { } else {
A.set(k, k - 1, -s * x); A.set(k, k - 1, -s * x);
} }
p += s; p = p + s;
x = p / s; x = p / s;
y = q / s; y = q / s;
z = r / s; z = r / s;
q /= p; q = q / p;
r /= p; r = r / p;
for j in k..n { for j in k..n {
p = A.get(k, j) + q * A.get(k + 1, j); p = A.get(k, j) + q * A.get(k + 1, j);
if k + 1 != nn { if k + 1 != nn {
p += r * A.get(k + 2, j); p = p + r * A.get(k + 2, j);
A.sub_element_mut(k + 2, j, p * z); A.sub_element_mut(k + 2, j, p * z);
} }
A.sub_element_mut(k + 1, j, p * y); A.sub_element_mut(k + 1, j, p * y);
@@ -583,7 +587,7 @@ fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>)
for i in 0..mmin+1 { for i in 0..mmin+1 {
p = x * A.get(i, k) + y * A.get(i, k + 1); p = x * A.get(i, k) + y * A.get(i, k + 1);
if k + 1 != nn { if k + 1 != nn {
p += z * A.get(i, k + 2); p = p + z * A.get(i, k + 2);
A.sub_element_mut(i, k + 2, p * r); A.sub_element_mut(i, k + 2, p * r);
} }
A.sub_element_mut(i, k + 1, p * q); A.sub_element_mut(i, k + 1, p * q);
@@ -592,7 +596,7 @@ fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>)
for i in 0..n { for i in 0..n {
p = x * V.get(i, k) + y * V.get(i, k + 1); p = x * V.get(i, k) + y * V.get(i, k + 1);
if k + 1 != nn { if k + 1 != nn {
p += z * V.get(i, k + 2); p = p + z * V.get(i, k + 2);
V.sub_element_mut(i, k + 2, p * r); V.sub_element_mut(i, k + 2, p * r);
} }
V.sub_element_mut(i, k + 1, p * q); V.sub_element_mut(i, k + 1, p * q);
@@ -608,38 +612,38 @@ fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>)
}; };
} }
if anorm != 0f64 { if anorm != T::zero() {
for nn in (0..n).rev() { for nn in (0..n).rev() {
p = d[nn]; p = d[nn];
q = e[nn]; q = e[nn];
let na = nn.wrapping_sub(1); let na = nn.wrapping_sub(1);
if q == 0f64 { if q == T::zero() {
let mut m = nn; let mut m = nn;
A.set(nn, nn, 1.0); A.set(nn, nn, T::one());
if nn > 0 { if nn > 0 {
let mut i = nn - 1; let mut i = nn - 1;
loop { loop {
let w = A.get(i, i) - p; let w = A.get(i, i) - p;
r = 0.0; r = T::zero();
for j in m..=nn { for j in m..=nn {
r += A.get(i, j) * A.get(j, nn); r = r + A.get(i, j) * A.get(j, nn);
} }
if e[i] < 0.0 { if e[i] < T::zero() {
z = w; z = w;
s = r; s = r;
} else { } else {
m = i; m = i;
if e[i] == 0.0 { if e[i] == T::zero() {
t = w; t = w;
if t == 0.0 { if t == T::zero() {
t = std::f64::EPSILON * anorm; t = T::epsilon() * anorm;
} }
A.set(i, nn, -r / t); A.set(i, nn, -r / t);
} else { } else {
let x = A.get(i, i + 1); let x = A.get(i, i + 1);
let y = A.get(i + 1, i); let y = A.get(i + 1, i);
q = (d[i] - p).powf(2f64) + e[i].powf(2f64); q = (d[i] - p).powf(T::two()) + e[i].powf(T::two());
t = (x * s - z * r) / q; t = (x * s - z * r) / q;
A.set(i, nn, t); A.set(i, nn, t);
if x.abs() > z.abs() { if x.abs() > z.abs() {
@@ -649,7 +653,7 @@ fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>)
} }
} }
t = A.get(i, nn).abs(); t = A.get(i, nn).abs();
if std::f64::EPSILON * t * t > 1f64 { if T::epsilon() * t * t > T::one() {
for j in i..=nn { for j in i..=nn {
A.div_element_mut(j, nn, t); A.div_element_mut(j, nn, t);
} }
@@ -662,44 +666,44 @@ fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>)
} }
} }
} }
} else if q < 0f64 { } else if q < T::zero() {
let mut m = na; let mut m = na;
if A.get(nn, na).abs() > A.get(na, nn).abs() { if A.get(nn, na).abs() > A.get(na, nn).abs() {
A.set(na, na, q / A.get(nn, na)); A.set(na, na, q / A.get(nn, na));
A.set(na, nn, -(A.get(nn, nn) - p) / A.get(nn, na)); A.set(na, nn, -(A.get(nn, nn) - p) / A.get(nn, na));
} else { } else {
let temp = Complex::new(0.0, -A.get(na, nn)) / Complex::new(A.get(na, na) - p, q); let temp = Complex::new(T::zero(), -A.get(na, nn)) / Complex::new(A.get(na, na) - p, q);
A.set(na, na, temp.re); A.set(na, na, temp.re);
A.set(na, nn, temp.im); A.set(na, nn, temp.im);
} }
A.set(nn, na, 0.0); A.set(nn, na, T::zero());
A.set(nn, nn, 1.0); A.set(nn, nn, T::one());
if nn >= 2 { if nn >= 2 {
for i in (0..nn - 1).rev() { for i in (0..nn - 1).rev() {
let w = A.get(i, i) - p; let w = A.get(i, i) - p;
let mut ra = 0f64; let mut ra = T::zero();
let mut sa = 0f64; let mut sa = T::zero();
for j in m..=nn { for j in m..=nn {
ra += A.get(i, j) * A.get(j, na); ra = ra + A.get(i, j) * A.get(j, na);
sa += A.get(i, j) * A.get(j, nn); sa = sa + A.get(i, j) * A.get(j, nn);
} }
if e[i] < 0.0 { if e[i] < T::zero() {
z = w; z = w;
r = ra; r = ra;
s = sa; s = sa;
} else { } else {
m = i; m = i;
if e[i] == 0.0 { if e[i] == T::zero() {
let temp = Complex::new(-ra, -sa) / Complex::new(w, q); let temp = Complex::new(-ra, -sa) / Complex::new(w, q);
A.set(i, na, temp.re); A.set(i, na, temp.re);
A.set(i, nn, temp.im); A.set(i, nn, temp.im);
} else { } else {
let x = A.get(i, i + 1); let x = A.get(i, i + 1);
let y = A.get(i + 1, i); let y = A.get(i + 1, i);
let mut vr = (d[i] - p).powf(2f64) + (e[i]).powf(2.0) - q * q; let mut vr = (d[i] - p).powf(T::two()) + (e[i]).powf(T::two()) - q * q;
let vi = 2.0 * q * (d[i] - p); let vi = T::two() * q * (d[i] - p);
if vr == 0.0 && vi == 0.0 { if vr == T::zero() && vi == T::zero() {
vr = std::f64::EPSILON * anorm * (w.abs() + q.abs() + x.abs() + y.abs() + z.abs()); vr = T::epsilon() * anorm * (w.abs() + q.abs() + x.abs() + y.abs() + z.abs());
} }
let temp = Complex::new(x * r - z * ra + q * sa, x * s - z * sa - q * ra) / Complex::new(vr, vi); let temp = Complex::new(x * r - z * ra + q * sa, x * s - z * sa - q * ra) / Complex::new(vr, vi);
A.set(i, na, temp.re); A.set(i, na, temp.re);
@@ -714,8 +718,8 @@ fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>)
} }
} }
} }
t = f64::max(A.get(i, na).abs(), A.get(i, nn).abs()); t = T::max(A.get(i, na).abs(), A.get(i, nn).abs());
if std::f64::EPSILON * t * t > 1f64 { if T::epsilon() * t * t > T::one() {
for j in i..=nn { for j in i..=nn {
A.div_element_mut(j, na, t); A.div_element_mut(j, na, t);
A.div_element_mut(j, nn, t); A.div_element_mut(j, nn, t);
@@ -728,9 +732,9 @@ fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>)
for j in (0..n).rev() { for j in (0..n).rev() {
for i in 0..n { for i in 0..n {
z = 0f64; z = T::zero();
for k in 0..=j { for k in 0..=j {
z += V.get(i, k) * A.get(k, j); z = z + V.get(i, k) * A.get(k, j);
} }
V.set(i, j, z); V.set(i, j, z);
} }
@@ -738,7 +742,7 @@ fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>)
} }
} }
fn balbak<M: BaseMatrix>(V: &mut M, scale: &Vec<f64>) { fn balbak<T: FloatExt + Debug, M: BaseMatrix<T>>(V: &mut M, scale: &Vec<T>) {
let (n, _) = V.shape(); let (n, _) = V.shape();
for i in 0..n { for i in 0..n {
for j in 0..n { for j in 0..n {
@@ -747,9 +751,9 @@ fn balbak<M: BaseMatrix>(V: &mut M, scale: &Vec<f64>) {
} }
} }
fn sort<M: BaseMatrix>(d: &mut Vec<f64>, e: &mut Vec<f64>, V: &mut M) { fn sort<T: FloatExt + Debug, M: BaseMatrix<T>>(d: &mut Vec<T>, e: &mut Vec<T>, V: &mut M) {
let n = d.len(); let n = d.len();
let mut temp = vec![0f64; n]; let mut temp = vec![T::zero(); n];
for j in 1..n { for j in 1..n {
let real = d[j]; let real = d[j];
let img = e[j]; let img = e[j];
@@ -789,7 +793,7 @@ mod tests {
&[0.4000, 0.5000, 0.3000], &[0.4000, 0.5000, 0.3000],
&[0.7000, 0.3000, 0.8000]]); &[0.7000, 0.3000, 0.8000]]);
let eigen_values = vec![1.7498382, 0.3165784, 0.1335834]; let eigen_values: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
let eigen_vectors = DenseMatrix::from_array(&[ let eigen_vectors = DenseMatrix::from_array(&[
&[0.6881997, -0.07121225, 0.7220180], &[0.6881997, -0.07121225, 0.7220180],
@@ -817,7 +821,7 @@ mod tests {
&[0.4000, 0.5000, 0.3000], &[0.4000, 0.5000, 0.3000],
&[0.8000, 0.3000, 0.8000]]); &[0.8000, 0.3000, 0.8000]]);
let eigen_values = vec![1.79171122, 0.31908143, 0.08920735]; let eigen_values: Vec<f64> = vec![1.79171122, 0.31908143, 0.08920735];
let eigen_vectors = DenseMatrix::from_array(&[ let eigen_vectors = DenseMatrix::from_array(&[
&[0.7178958, 0.05322098, 0.6812010], &[0.7178958, 0.05322098, 0.6812010],
@@ -826,6 +830,8 @@ mod tests {
]); ]);
let evd = A.evd(false); let evd = A.evd(false);
println!("{}", &evd.V.abs());
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4)); assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values.len() { for i in 0..eigen_values.len() {
@@ -846,8 +852,8 @@ mod tests {
&[1.0, 1.0, 3.0, -2.0], &[1.0, 1.0, 3.0, -2.0],
&[1.0, 1.0, 4.0, -1.0]]); &[1.0, 1.0, 4.0, -1.0]]);
let eigen_values_d = vec![0.0, 2.0, 2.0, 0.0]; let eigen_values_d: Vec<f64> = vec![0.0, 2.0, 2.0, 0.0];
let eigen_values_e = vec![2.2361, 0.9999, -0.9999, -2.2361]; let eigen_values_e: Vec<f64> = vec![2.2361, 0.9999, -0.9999, -2.2361];
let eigen_vectors = DenseMatrix::from_array(&[ let eigen_vectors = DenseMatrix::from_array(&[
&[-0.9159, -0.1378, 0.3816, -0.0806], &[-0.9159, -0.1378, 0.3816, -0.0806],
+42 -37
View File
@@ -6,11 +6,14 @@ pub mod ndarray_bindings;
use std::ops::Range; use std::ops::Range;
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData;
use crate::math::num::FloatExt;
use svd::SVDDecomposableMatrix; use svd::SVDDecomposableMatrix;
use evd::EVDDecomposableMatrix; use evd::EVDDecomposableMatrix;
use qr::QRDecomposableMatrix; use qr::QRDecomposableMatrix;
pub trait BaseMatrix: Clone + Debug { pub trait BaseMatrix<T: FloatExt + Debug>: Clone + Debug {
type RowVector: Clone + Debug; type RowVector: Clone + Debug;
@@ -18,13 +21,13 @@ pub trait BaseMatrix: Clone + Debug {
fn to_row_vector(self) -> Self::RowVector; fn to_row_vector(self) -> Self::RowVector;
fn get(&self, row: usize, col: usize) -> f64; fn get(&self, row: usize, col: usize) -> T;
fn get_row_as_vec(&self, row: usize) -> Vec<f64>; fn get_row_as_vec(&self, row: usize) -> Vec<T>;
fn get_col_as_vec(&self, col: usize) -> Vec<f64>; fn get_col_as_vec(&self, col: usize) -> Vec<T>;
fn set(&mut self, row: usize, col: usize, x: f64); fn set(&mut self, row: usize, col: usize, x: T);
fn eye(size: usize) -> Self; fn eye(size: usize) -> Self;
@@ -32,9 +35,9 @@ pub trait BaseMatrix: Clone + Debug {
fn ones(nrows: usize, ncols: usize) -> Self; fn ones(nrows: usize, ncols: usize) -> Self;
fn to_raw_vector(&self) -> Vec<f64>; fn to_raw_vector(&self) -> Vec<T>;
fn fill(nrows: usize, ncols: usize, value: f64) -> Self; fn fill(nrows: usize, ncols: usize, value: T) -> Self;
fn shape(&self) -> (usize, usize); fn shape(&self) -> (usize, usize);
@@ -44,11 +47,11 @@ pub trait BaseMatrix: Clone + Debug {
fn dot(&self, other: &Self) -> Self; fn dot(&self, other: &Self) -> Self;
fn vector_dot(&self, other: &Self) -> f64; fn vector_dot(&self, other: &Self) -> T;
fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self; fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self;
fn approximate_eq(&self, other: &Self, error: f64) -> bool; fn approximate_eq(&self, other: &Self, error: T) -> bool;
fn add_mut(&mut self, other: &Self) -> &Self; fn add_mut(&mut self, other: &Self) -> &Self;
@@ -58,13 +61,13 @@ pub trait BaseMatrix: Clone + Debug {
fn div_mut(&mut self, other: &Self) -> &Self; fn div_mut(&mut self, other: &Self) -> &Self;
fn div_element_mut(&mut self, row: usize, col: usize, x: f64); fn div_element_mut(&mut self, row: usize, col: usize, x: T);
fn mul_element_mut(&mut self, row: usize, col: usize, x: f64); fn mul_element_mut(&mut self, row: usize, col: usize, x: T);
fn add_element_mut(&mut self, row: usize, col: usize, x: f64); fn add_element_mut(&mut self, row: usize, col: usize, x: T);
fn sub_element_mut(&mut self, row: usize, col: usize, x: f64); fn sub_element_mut(&mut self, row: usize, col: usize, x: T);
fn add(&self, other: &Self) -> Self { fn add(&self, other: &Self) -> Self {
let mut r = self.clone(); let mut r = self.clone();
@@ -90,33 +93,33 @@ pub trait BaseMatrix: Clone + Debug {
r r
} }
fn add_scalar_mut(&mut self, scalar: f64) -> &Self; fn add_scalar_mut(&mut self, scalar: T) -> &Self;
fn sub_scalar_mut(&mut self, scalar: f64) -> &Self; fn sub_scalar_mut(&mut self, scalar: T) -> &Self;
fn mul_scalar_mut(&mut self, scalar: f64) -> &Self; fn mul_scalar_mut(&mut self, scalar: T) -> &Self;
fn div_scalar_mut(&mut self, scalar: f64) -> &Self; fn div_scalar_mut(&mut self, scalar: T) -> &Self;
fn add_scalar(&self, scalar: f64) -> Self{ fn add_scalar(&self, scalar: T) -> Self{
let mut r = self.clone(); let mut r = self.clone();
r.add_scalar_mut(scalar); r.add_scalar_mut(scalar);
r r
} }
fn sub_scalar(&self, scalar: f64) -> Self{ fn sub_scalar(&self, scalar: T) -> Self{
let mut r = self.clone(); let mut r = self.clone();
r.sub_scalar_mut(scalar); r.sub_scalar_mut(scalar);
r r
} }
fn mul_scalar(&self, scalar: f64) -> Self{ fn mul_scalar(&self, scalar: T) -> Self{
let mut r = self.clone(); let mut r = self.clone();
r.mul_scalar_mut(scalar); r.mul_scalar_mut(scalar);
r r
} }
fn div_scalar(&self, scalar: f64) -> Self{ fn div_scalar(&self, scalar: T) -> Self{
let mut r = self.clone(); let mut r = self.clone();
r.div_scalar_mut(scalar); r.div_scalar_mut(scalar);
r r
@@ -126,11 +129,11 @@ pub trait BaseMatrix: Clone + Debug {
fn rand(nrows: usize, ncols: usize) -> Self; fn rand(nrows: usize, ncols: usize) -> Self;
fn norm2(&self) -> f64; fn norm2(&self) -> T;
fn norm(&self, p:f64) -> f64; fn norm(&self, p:T) -> T;
fn column_mean(&self) -> Vec<f64>; fn column_mean(&self) -> Vec<T>;
fn negative_mut(&mut self); fn negative_mut(&mut self);
@@ -152,15 +155,15 @@ pub trait BaseMatrix: Clone + Debug {
result result
} }
fn sum(&self) -> f64; fn sum(&self) -> T;
fn max_diff(&self, other: &Self) -> f64; fn max_diff(&self, other: &Self) -> T;
fn softmax_mut(&mut self); fn softmax_mut(&mut self);
fn pow_mut(&mut self, p: f64) -> &Self; fn pow_mut(&mut self, p: T) -> &Self;
fn pow(&mut self, p: f64) -> Self { fn pow(&mut self, p: T) -> Self {
let mut result = self.clone(); let mut result = self.clone();
result.pow_mut(p); result.pow_mut(p);
result result
@@ -168,31 +171,33 @@ pub trait BaseMatrix: Clone + Debug {
fn argmax(&self) -> Vec<usize>; fn argmax(&self) -> Vec<usize>;
fn unique(&self) -> Vec<f64>; fn unique(&self) -> Vec<T>;
} }
pub trait Matrix: BaseMatrix + SVDDecomposableMatrix + EVDDecomposableMatrix + QRDecomposableMatrix {} pub trait Matrix<T: FloatExt + Debug>: BaseMatrix<T> + SVDDecomposableMatrix<T> + EVDDecomposableMatrix<T> + QRDecomposableMatrix<T> {}
pub fn row_iter<M: Matrix>(m: &M) -> RowIter<M> { pub fn row_iter<F: FloatExt + Debug, M: Matrix<F>>(m: &M) -> RowIter<F, M> {
RowIter{ RowIter{
m: m, m: m,
pos: 0, pos: 0,
max_pos: m.shape().0 max_pos: m.shape().0,
phantom: PhantomData
} }
} }
pub struct RowIter<'a, M: Matrix> { pub struct RowIter<'a, T: FloatExt + Debug, M: Matrix<T>> {
m: &'a M, m: &'a M,
pos: usize, pos: usize,
max_pos: usize max_pos: usize,
phantom: PhantomData<&'a T>
} }
impl<'a, M: Matrix> Iterator for RowIter<'a, M> { impl<'a, T: FloatExt + Debug, M: Matrix<T>> Iterator for RowIter<'a, T, M> {
type Item = Vec<f64>; type Item = Vec<T>;
fn next(&mut self) -> Option<Vec<f64>> { fn next(&mut self) -> Option<Vec<T>> {
let res; let res;
if self.pos < self.max_pos { if self.pos < self.max_pos {
res = Some(self.m.get_row_as_vec(self.pos)) res = Some(self.m.get_row_as_vec(self.pos))
+98 -98
View File
@@ -1,36 +1,37 @@
extern crate num; extern crate num;
use std::ops::Range; use std::ops::Range;
use std::fmt; use std::fmt;
use std::fmt::Debug;
use crate::linalg::Matrix; use crate::linalg::Matrix;
pub use crate::linalg::BaseMatrix; pub use crate::linalg::BaseMatrix;
use crate::linalg::svd::SVDDecomposableMatrix; use crate::linalg::svd::SVDDecomposableMatrix;
use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix; use crate::linalg::qr::QRDecomposableMatrix;
use crate::math; use crate::math::num::FloatExt;
use rand::prelude::*;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct DenseMatrix { pub struct DenseMatrix<T: FloatExt + Debug> {
ncols: usize, ncols: usize,
nrows: usize, nrows: usize,
values: Vec<f64> values: Vec<T>
} }
impl fmt::Display for DenseMatrix { impl<T: FloatExt + Debug> fmt::Display for DenseMatrix<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut rows: Vec<Vec<f64>> = Vec::new(); let mut rows: Vec<Vec<f64>> = Vec::new();
for r in 0..self.nrows { for r in 0..self.nrows {
rows.push(self.get_row_as_vec(r).iter().map(|x| (x * 1e4).round() / 1e4 ).collect()); rows.push(self.get_row_as_vec(r).iter().map(|x| (x.to_f64().unwrap() * 1e4).round() / 1e4 ).collect());
} }
write!(f, "{:?}", rows) write!(f, "{:?}", rows)
} }
} }
impl DenseMatrix { impl<T: FloatExt + Debug> DenseMatrix<T> {
fn new(nrows: usize, ncols: usize, values: Vec<f64>) -> DenseMatrix { fn new(nrows: usize, ncols: usize, values: Vec<T>) -> Self {
DenseMatrix { DenseMatrix {
ncols: ncols, ncols: ncols,
nrows: nrows, nrows: nrows,
@@ -38,17 +39,17 @@ impl DenseMatrix {
} }
} }
pub fn from_array(values: &[&[f64]]) -> DenseMatrix { pub fn from_array(values: &[&[T]]) -> Self {
DenseMatrix::from_vec(&values.into_iter().map(|row| Vec::from(*row)).collect()) DenseMatrix::from_vec(&values.into_iter().map(|row| Vec::from(*row)).collect())
} }
pub fn from_vec(values: &Vec<Vec<f64>>) -> DenseMatrix { pub fn from_vec(values: &Vec<Vec<T>>) -> DenseMatrix<T> {
let nrows = values.len(); let nrows = values.len();
let ncols = values.first().unwrap_or_else(|| panic!("Cannot create 2d matrix from an empty vector")).len(); let ncols = values.first().unwrap_or_else(|| panic!("Cannot create 2d matrix from an empty vector")).len();
let mut m = DenseMatrix { let mut m = DenseMatrix {
ncols: ncols, ncols: ncols,
nrows: nrows, nrows: nrows,
values: vec![0f64; ncols*nrows] values: vec![T::zero(); ncols*nrows]
}; };
for row in 0..nrows { for row in 0..nrows {
for col in 0..ncols { for col in 0..ncols {
@@ -58,11 +59,11 @@ impl DenseMatrix {
m m
} }
pub fn vector_from_array(values: &[f64]) -> DenseMatrix { pub fn vector_from_array(values: &[T]) -> Self {
DenseMatrix::vector_from_vec(Vec::from(values)) DenseMatrix::vector_from_vec(Vec::from(values))
} }
pub fn vector_from_vec(values: Vec<f64>) -> DenseMatrix { pub fn vector_from_vec(values: Vec<T>) -> Self {
DenseMatrix { DenseMatrix {
ncols: values.len(), ncols: values.len(),
nrows: 1, nrows: 1,
@@ -70,31 +71,31 @@ impl DenseMatrix {
} }
} }
pub fn div_mut(&mut self, b: DenseMatrix) -> () { pub fn div_mut(&mut self, b: Self) -> () {
if self.nrows != b.nrows || self.ncols != b.ncols { if self.nrows != b.nrows || self.ncols != b.ncols {
panic!("Can't divide matrices of different sizes."); panic!("Can't divide matrices of different sizes.");
} }
for i in 0..self.values.len() { for i in 0..self.values.len() {
self.values[i] /= b.values[i]; self.values[i] = self.values[i] / b.values[i];
} }
} }
pub fn get_raw_values(&self) -> &Vec<f64> { pub fn get_raw_values(&self) -> &Vec<T> {
&self.values &self.values
} }
} }
impl SVDDecomposableMatrix for DenseMatrix {} impl<T: FloatExt + Debug> SVDDecomposableMatrix<T> for DenseMatrix<T> {}
impl EVDDecomposableMatrix for DenseMatrix {} impl<T: FloatExt + Debug> EVDDecomposableMatrix<T> for DenseMatrix<T> {}
impl QRDecomposableMatrix for DenseMatrix {} impl<T: FloatExt + Debug> QRDecomposableMatrix<T> for DenseMatrix<T> {}
impl Matrix for DenseMatrix {} impl<T: FloatExt + Debug> Matrix<T> for DenseMatrix<T> {}
impl PartialEq for DenseMatrix { impl<T: FloatExt + Debug> PartialEq for DenseMatrix<T> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.ncols != other.ncols || self.nrows != other.nrows { if self.ncols != other.ncols || self.nrows != other.nrows {
return false return false
@@ -108,7 +109,7 @@ impl PartialEq for DenseMatrix {
} }
for i in 0..len { for i in 0..len {
if (self.values[i] - other.values[i]).abs() > math::EPSILON { if (self.values[i] - other.values[i]).abs() > T::epsilon() {
return false; return false;
} }
} }
@@ -117,15 +118,15 @@ impl PartialEq for DenseMatrix {
} }
} }
impl Into<Vec<f64>> for DenseMatrix { impl<T: FloatExt + Debug> Into<Vec<T>> for DenseMatrix<T> {
fn into(self) -> Vec<f64> { fn into(self) -> Vec<T> {
self.values self.values
} }
} }
impl BaseMatrix for DenseMatrix { impl<T: FloatExt + Debug> BaseMatrix<T> for DenseMatrix<T> {
type RowVector = Vec<f64>; type RowVector = Vec<T>;
fn from_row_vector(vec: Self::RowVector) -> Self{ fn from_row_vector(vec: Self::RowVector) -> Self{
DenseMatrix::new(1, vec.len(), vec) DenseMatrix::new(1, vec.len(), vec)
@@ -135,50 +136,50 @@ impl BaseMatrix for DenseMatrix {
self.to_raw_vector() self.to_raw_vector()
} }
fn get(&self, row: usize, col: usize) -> f64 { fn get(&self, row: usize, col: usize) -> T {
self.values[col*self.nrows + row] self.values[col*self.nrows + row]
} }
fn get_row_as_vec(&self, row: usize) -> Vec<f64>{ fn get_row_as_vec(&self, row: usize) -> Vec<T>{
let mut result = vec![0f64; self.ncols]; let mut result = vec![T::zero(); self.ncols];
for c in 0..self.ncols { for c in 0..self.ncols {
result[c] = self.get(row, c); result[c] = self.get(row, c);
} }
result result
} }
fn get_col_as_vec(&self, col: usize) -> Vec<f64>{ fn get_col_as_vec(&self, col: usize) -> Vec<T>{
let mut result = vec![0f64; self.nrows]; let mut result = vec![T::zero(); self.nrows];
for r in 0..self.nrows { for r in 0..self.nrows {
result[r] = self.get(r, col); result[r] = self.get(r, col);
} }
result result
} }
fn set(&mut self, row: usize, col: usize, x: f64) { fn set(&mut self, row: usize, col: usize, x: T) {
self.values[col*self.nrows + row] = x; self.values[col*self.nrows + row] = x;
} }
fn zeros(nrows: usize, ncols: usize) -> DenseMatrix { fn zeros(nrows: usize, ncols: usize) -> Self {
DenseMatrix::fill(nrows, ncols, 0f64) DenseMatrix::fill(nrows, ncols, T::zero())
} }
fn ones(nrows: usize, ncols: usize) -> DenseMatrix { fn ones(nrows: usize, ncols: usize) -> Self {
DenseMatrix::fill(nrows, ncols, 1f64) DenseMatrix::fill(nrows, ncols, T::one())
} }
fn eye(size: usize) -> Self { fn eye(size: usize) -> Self {
let mut matrix = Self::zeros(size, size); let mut matrix = Self::zeros(size, size);
for i in 0..size { for i in 0..size {
matrix.set(i, i, 1.0); matrix.set(i, i, T::one());
} }
return matrix; return matrix;
} }
fn to_raw_vector(&self) -> Vec<f64>{ fn to_raw_vector(&self) -> Vec<T>{
let mut v = vec![0.; self.nrows * self.ncols]; let mut v = vec![T::zero(); self.nrows * self.ncols];
for r in 0..self.nrows{ for r in 0..self.nrows{
for c in 0..self.ncols { for c in 0..self.ncols {
@@ -197,7 +198,7 @@ impl BaseMatrix for DenseMatrix {
if self.ncols != other.ncols { if self.ncols != other.ncols {
panic!("Number of columns in both matrices should be equal"); panic!("Number of columns in both matrices should be equal");
} }
let mut result = DenseMatrix::zeros(self.nrows + other.nrows, self.ncols); let mut result = Self::zeros(self.nrows + other.nrows, self.ncols);
for c in 0..self.ncols { for c in 0..self.ncols {
for r in 0..self.nrows+other.nrows { for r in 0..self.nrows+other.nrows {
if r < self.nrows { if r < self.nrows {
@@ -214,7 +215,7 @@ impl BaseMatrix for DenseMatrix {
if self.nrows != other.nrows { if self.nrows != other.nrows {
panic!("Number of rows in both matrices should be equal"); panic!("Number of rows in both matrices should be equal");
} }
let mut result = DenseMatrix::zeros(self.nrows, self.ncols + other.ncols); let mut result = Self::zeros(self.nrows, self.ncols + other.ncols);
for r in 0..self.nrows { for r in 0..self.nrows {
for c in 0..self.ncols+other.ncols { for c in 0..self.ncols+other.ncols {
if c < self.ncols { if c < self.ncols {
@@ -233,13 +234,13 @@ impl BaseMatrix for DenseMatrix {
panic!("Number of rows of A should equal number of columns of B"); panic!("Number of rows of A should equal number of columns of B");
} }
let inner_d = self.ncols; let inner_d = self.ncols;
let mut result = DenseMatrix::zeros(self.nrows, other.ncols); let mut result = Self::zeros(self.nrows, other.ncols);
for r in 0..self.nrows { for r in 0..self.nrows {
for c in 0..other.ncols { for c in 0..other.ncols {
let mut s = 0f64; let mut s = T::zero();
for i in 0..inner_d { for i in 0..inner_d {
s += self.get(r, i) * other.get(i, c); s = s + self.get(r, i) * other.get(i, c);
} }
result.set(r, c, s); result.set(r, c, s);
} }
@@ -248,7 +249,7 @@ impl BaseMatrix for DenseMatrix {
result result
} }
fn vector_dot(&self, other: &Self) -> f64 { fn vector_dot(&self, other: &Self) -> T {
if (self.nrows != 1 || self.nrows != 1) && (other.nrows != 1 || other.ncols != 1) { if (self.nrows != 1 || self.nrows != 1) && (other.nrows != 1 || other.ncols != 1) {
panic!("A and B should both be 1-dimentional vectors."); panic!("A and B should both be 1-dimentional vectors.");
} }
@@ -256,20 +257,20 @@ impl BaseMatrix for DenseMatrix {
panic!("A and B should have the same size"); panic!("A and B should have the same size");
} }
let mut result = 0f64; let mut result = T::zero();
for i in 0..(self.nrows * self.ncols) { for i in 0..(self.nrows * self.ncols) {
result += self.values[i] * other.values[i]; result = result + self.values[i] * other.values[i];
} }
result result
} }
fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> DenseMatrix { fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self {
let ncols = cols.len(); let ncols = cols.len();
let nrows = rows.len(); let nrows = rows.len();
let mut m = DenseMatrix::new(nrows, ncols, vec![0f64; nrows * ncols]); let mut m = DenseMatrix::new(nrows, ncols, vec![T::zero(); nrows * ncols]);
for r in rows.start..rows.end { for r in rows.start..rows.end {
for c in cols.start..cols.end { for c in cols.start..cols.end {
@@ -280,7 +281,7 @@ impl BaseMatrix for DenseMatrix {
m m
} }
fn approximate_eq(&self, other: &Self, error: f64) -> bool { fn approximate_eq(&self, other: &Self, error: T) -> bool {
if self.ncols != other.ncols || self.nrows != other.nrows { if self.ncols != other.ncols || self.nrows != other.nrows {
return false return false
} }
@@ -296,7 +297,7 @@ impl BaseMatrix for DenseMatrix {
true true
} }
fn fill(nrows: usize, ncols: usize, value: f64) -> Self { fn fill(nrows: usize, ncols: usize, value: T) -> Self {
DenseMatrix::new(nrows, ncols, vec![value; ncols * nrows]) DenseMatrix::new(nrows, ncols, vec![value; ncols * nrows])
} }
@@ -352,27 +353,27 @@ impl BaseMatrix for DenseMatrix {
self self
} }
fn div_element_mut(&mut self, row: usize, col: usize, x: f64) { fn div_element_mut(&mut self, row: usize, col: usize, x: T) {
self.values[col*self.nrows + row] /= x; self.values[col*self.nrows + row] = self.values[col*self.nrows + row] / x;
} }
fn mul_element_mut(&mut self, row: usize, col: usize, x: f64) { fn mul_element_mut(&mut self, row: usize, col: usize, x: T) {
self.values[col*self.nrows + row] *= x; self.values[col*self.nrows + row] = self.values[col*self.nrows + row] * x;
} }
fn add_element_mut(&mut self, row: usize, col: usize, x: f64) { fn add_element_mut(&mut self, row: usize, col: usize, x: T) {
self.values[col*self.nrows + row] += x self.values[col*self.nrows + row] = self.values[col*self.nrows + row] + x
} }
fn sub_element_mut(&mut self, row: usize, col: usize, x: f64) { fn sub_element_mut(&mut self, row: usize, col: usize, x: T) {
self.values[col*self.nrows + row] -= x; self.values[col*self.nrows + row] = self.values[col*self.nrows + row] - x;
} }
fn transpose(&self) -> Self { fn transpose(&self) -> Self {
let mut m = DenseMatrix { let mut m = DenseMatrix {
ncols: self.nrows, ncols: self.nrows,
nrows: self.ncols, nrows: self.ncols,
values: vec![0f64; self.ncols * self.nrows] values: vec![T::zero(); self.ncols * self.nrows]
}; };
for c in 0..self.ncols { for c in 0..self.ncols {
for r in 0..self.nrows { for r in 0..self.nrows {
@@ -383,10 +384,9 @@ impl BaseMatrix for DenseMatrix {
} }
fn rand(nrows: usize, ncols: usize) -> Self { fn rand(nrows: usize, ncols: usize) -> Self {
let mut rng = rand::thread_rng(); let values: Vec<T> = (0..nrows*ncols).map(|_| {
let values: Vec<f64> = (0..nrows*ncols).map(|_| { T::rand()
rng.gen()
}).collect(); }).collect();
DenseMatrix { DenseMatrix {
ncols: ncols, ncols: ncols,
@@ -395,74 +395,74 @@ impl BaseMatrix for DenseMatrix {
} }
} }
fn norm2(&self) -> f64 { fn norm2(&self) -> T {
let mut norm = 0f64; let mut norm = T::zero();
for xi in self.values.iter() { for xi in self.values.iter() {
norm += xi * xi; norm = norm + *xi * *xi;
} }
norm.sqrt() norm.sqrt()
} }
fn norm(&self, p:f64) -> f64 { fn norm(&self, p:T) -> T {
if p.is_infinite() && p.is_sign_positive() { if p.is_infinite() && p.is_sign_positive() {
self.values.iter().map(|x| x.abs()).fold(std::f64::NEG_INFINITY, |a, b| a.max(b)) self.values.iter().map(|x| x.abs()).fold(T::neg_infinity(), |a, b| a.max(b))
} else if p.is_infinite() && p.is_sign_negative() { } else if p.is_infinite() && p.is_sign_negative() {
self.values.iter().map(|x| x.abs()).fold(std::f64::INFINITY, |a, b| a.min(b)) self.values.iter().map(|x| x.abs()).fold(T::infinity(), |a, b| a.min(b))
} else { } else {
let mut norm = 0f64; let mut norm = T::zero();
for xi in self.values.iter() { for xi in self.values.iter() {
norm += xi.abs().powf(p); norm = norm + xi.abs().powf(p);
} }
norm.powf(1.0/p) norm.powf(T::one()/p)
} }
} }
fn column_mean(&self) -> Vec<f64> { fn column_mean(&self) -> Vec<T> {
let mut mean = vec![0f64; self.ncols]; let mut mean = vec![T::zero(); self.ncols];
for r in 0..self.nrows { for r in 0..self.nrows {
for c in 0..self.ncols { for c in 0..self.ncols {
mean[c] += self.get(r, c); mean[c] = mean[c] + self.get(r, c);
} }
} }
for i in 0..mean.len() { for i in 0..mean.len() {
mean[i] /= self.nrows as f64; mean[i] = mean[i] / T::from(self.nrows).unwrap();
} }
mean mean
} }
fn add_scalar_mut(&mut self, scalar: f64) -> &Self { fn add_scalar_mut(&mut self, scalar: T) -> &Self {
for i in 0..self.values.len() { for i in 0..self.values.len() {
self.values[i] += scalar; self.values[i] = self.values[i] + scalar;
} }
self self
} }
fn sub_scalar_mut(&mut self, scalar: f64) -> &Self { fn sub_scalar_mut(&mut self, scalar: T) -> &Self {
for i in 0..self.values.len() { for i in 0..self.values.len() {
self.values[i] -= scalar; self.values[i] = self.values[i] - scalar;
} }
self self
} }
fn mul_scalar_mut(&mut self, scalar: f64) -> &Self { fn mul_scalar_mut(&mut self, scalar: T) -> &Self {
for i in 0..self.values.len() { for i in 0..self.values.len() {
self.values[i] *= scalar; self.values[i] = self.values[i] * scalar;
} }
self self
} }
fn div_scalar_mut(&mut self, scalar: f64) -> &Self { fn div_scalar_mut(&mut self, scalar: T) -> &Self {
for i in 0..self.values.len() { for i in 0..self.values.len() {
self.values[i] /= scalar; self.values[i] = self.values[i] / scalar;
} }
self self
} }
@@ -512,8 +512,8 @@ impl BaseMatrix for DenseMatrix {
self self
} }
fn max_diff(&self, other: &Self) -> f64{ fn max_diff(&self, other: &Self) -> T{
let mut max_diff = 0f64; let mut max_diff = T::zero();
for i in 0..self.values.len() { for i in 0..self.values.len() {
max_diff = max_diff.max((self.values[i] - other.values[i]).abs()); max_diff = max_diff.max((self.values[i] - other.values[i]).abs());
} }
@@ -521,22 +521,22 @@ impl BaseMatrix for DenseMatrix {
} }
fn sum(&self) -> f64 { fn sum(&self) -> T {
let mut sum = 0.; let mut sum = T::zero();
for i in 0..self.values.len() { for i in 0..self.values.len() {
sum += self.values[i]; sum = sum + self.values[i];
} }
sum sum
} }
fn softmax_mut(&mut self) { fn softmax_mut(&mut self) {
let max = self.values.iter().map(|x| x.abs()).fold(std::f64::NEG_INFINITY, |a, b| a.max(b)); let max = self.values.iter().map(|x| x.abs()).fold(T::neg_infinity(), |a, b| a.max(b));
let mut z = 0.; let mut z = T::zero();
for r in 0..self.nrows { for r in 0..self.nrows {
for c in 0..self.ncols { for c in 0..self.ncols {
let p = (self.get(r, c) - max).exp(); let p = (self.get(r, c) - max).exp();
self.set(r, c, p); self.set(r, c, p);
z += p; z = z + p;
} }
} }
for r in 0..self.nrows { for r in 0..self.nrows {
@@ -546,7 +546,7 @@ impl BaseMatrix for DenseMatrix {
} }
} }
fn pow_mut(&mut self, p: f64) -> &Self { fn pow_mut(&mut self, p: T) -> &Self {
for i in 0..self.values.len() { for i in 0..self.values.len() {
self.values[i] = self.values[i].powf(p); self.values[i] = self.values[i].powf(p);
} }
@@ -558,7 +558,7 @@ impl BaseMatrix for DenseMatrix {
let mut res = vec![0usize; self.nrows]; let mut res = vec![0usize; self.nrows];
for r in 0..self.nrows { for r in 0..self.nrows {
let mut max = std::f64::NEG_INFINITY; let mut max = T::neg_infinity();
let mut max_pos = 0usize; let mut max_pos = 0usize;
for c in 0..self.ncols { for c in 0..self.ncols {
let v = self.get(r, c); let v = self.get(r, c);
@@ -574,7 +574,7 @@ impl BaseMatrix for DenseMatrix {
} }
fn unique(&self) -> Vec<f64> { fn unique(&self) -> Vec<T> {
let mut result = self.values.clone(); let mut result = self.values.clone();
result.sort_by(|a, b| a.partial_cmp(b).unwrap()); result.sort_by(|a, b| a.partial_cmp(b).unwrap());
result.dedup(); result.dedup();
@@ -698,7 +698,7 @@ mod tests {
#[test] #[test]
fn rand() { fn rand() {
let m = DenseMatrix::rand(3, 3); let m: DenseMatrix<f64> = DenseMatrix::rand(3, 3);
for c in 0..3 { for c in 0..3 {
for r in 0..3 { for r in 0..3 {
assert!(m.get(r, c) != 0f64); assert!(m.get(r, c) != 0f64);
@@ -742,7 +742,7 @@ mod tests {
#[test] #[test]
fn softmax_mut() { fn softmax_mut() {
let mut prob = DenseMatrix::vector_from_array(&[1., 2., 3.]); let mut prob: DenseMatrix<f64> = DenseMatrix::vector_from_array(&[1., 2., 3.]);
prob.softmax_mut(); prob.softmax_mut();
assert!((prob.get(0, 0) - 0.09).abs() < 0.01); assert!((prob.get(0, 0) - 0.09).abs() < 0.01);
assert!((prob.get(0, 1) - 0.24).abs() < 0.01); assert!((prob.get(0, 1) - 0.24).abs() < 0.01);
+61 -52
View File
@@ -1,15 +1,25 @@
use std::ops::Range; use std::ops::Range;
use std::fmt::Debug;
use std::iter::Sum;
use std::ops::AddAssign;
use std::ops::SubAssign;
use std::ops::MulAssign;
use std::ops::DivAssign;
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
use ndarray::ScalarOperand;
use crate::math::num::FloatExt;
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::linalg::svd::SVDDecomposableMatrix; use crate::linalg::svd::SVDDecomposableMatrix;
use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix; use crate::linalg::qr::QRDecomposableMatrix;
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
use rand::prelude::*;
impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> BaseMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
{ {
type RowVector = ArrayBase<OwnedRepr<f64>, Ix1>; type RowVector = ArrayBase<OwnedRepr<T>, Ix1>;
fn from_row_vector(vec: Self::RowVector) -> Self{ fn from_row_vector(vec: Self::RowVector) -> Self{
let vec_size = vec.len(); let vec_size = vec.len();
@@ -21,19 +31,19 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
self.into_shape(vec_size).unwrap() self.into_shape(vec_size).unwrap()
} }
fn get(&self, row: usize, col: usize) -> f64 { fn get(&self, row: usize, col: usize) -> T {
self[[row, col]] self[[row, col]]
} }
fn get_row_as_vec(&self, row: usize) -> Vec<f64> { fn get_row_as_vec(&self, row: usize) -> Vec<T> {
self.row(row).to_vec() self.row(row).to_vec()
} }
fn get_col_as_vec(&self, col: usize) -> Vec<f64> { fn get_col_as_vec(&self, col: usize) -> Vec<T> {
self.column(col).to_vec() self.column(col).to_vec()
} }
fn set(&mut self, row: usize, col: usize, x: f64) { fn set(&mut self, row: usize, col: usize, x: T) {
self[[row, col]] = x; self[[row, col]] = x;
} }
@@ -49,11 +59,11 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
Array::ones((nrows, ncols)) Array::ones((nrows, ncols))
} }
fn to_raw_vector(&self) -> Vec<f64> { fn to_raw_vector(&self) -> Vec<T> {
self.to_owned().iter().map(|v| *v).collect() self.to_owned().iter().map(|v| *v).collect()
} }
fn fill(nrows: usize, ncols: usize, value: f64) -> Self { fn fill(nrows: usize, ncols: usize, value: T) -> Self {
Array::from_elem((nrows, ncols), value) Array::from_elem((nrows, ncols), value)
} }
@@ -73,7 +83,7 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
self.dot(other) self.dot(other)
} }
fn vector_dot(&self, other: &Self) -> f64 { fn vector_dot(&self, other: &Self) -> T {
self.dot(&other.view().reversed_axes())[[0, 0]] self.dot(&other.view().reversed_axes())[[0, 0]]
} }
@@ -81,7 +91,7 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
self.slice(s![rows, cols]).to_owned() self.slice(s![rows, cols]).to_owned()
} }
fn approximate_eq(&self, other: &Self, error: f64) -> bool { fn approximate_eq(&self, other: &Self, error: T) -> bool {
(self - other).iter().all(|v| v.abs() <= error) (self - other).iter().all(|v| v.abs() <= error)
} }
@@ -105,22 +115,22 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
self self
} }
fn add_scalar_mut(&mut self, scalar: f64) -> &Self{ fn add_scalar_mut(&mut self, scalar: T) -> &Self{
*self += scalar; *self += scalar;
self self
} }
fn sub_scalar_mut(&mut self, scalar: f64) -> &Self{ fn sub_scalar_mut(&mut self, scalar: T) -> &Self{
*self -= scalar; *self -= scalar;
self self
} }
fn mul_scalar_mut(&mut self, scalar: f64) -> &Self{ fn mul_scalar_mut(&mut self, scalar: T) -> &Self{
*self *= scalar; *self *= scalar;
self self
} }
fn div_scalar_mut(&mut self, scalar: f64) -> &Self{ fn div_scalar_mut(&mut self, scalar: T) -> &Self{
*self /= scalar; *self /= scalar;
self self
} }
@@ -129,21 +139,20 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
self.clone().reversed_axes() self.clone().reversed_axes()
} }
fn rand(nrows: usize, ncols: usize) -> Self{ fn rand(nrows: usize, ncols: usize) -> Self{
let mut rng = rand::thread_rng(); let values: Vec<T> = (0..nrows*ncols).map(|_| {
let values: Vec<f64> = (0..nrows*ncols).map(|_| { T::rand()
rng.gen()
}).collect(); }).collect();
Array::from_shape_vec((nrows, ncols), values).unwrap() Array::from_shape_vec((nrows, ncols), values).unwrap()
} }
fn norm2(&self) -> f64{ fn norm2(&self) -> T{
self.iter().map(|x| x * x).sum::<f64>().sqrt() self.iter().map(|x| *x * *x).sum::<T>().sqrt()
} }
fn norm(&self, p:f64) -> f64 { fn norm(&self, p:T) -> T {
if p.is_infinite() && p.is_sign_positive() { if p.is_infinite() && p.is_sign_positive() {
self.iter().fold(std::f64::NEG_INFINITY, |f, &val| { self.iter().fold(T::neg_infinity(), |f, &val| {
let v = val.abs(); let v = val.abs();
if f > v { if f > v {
f f
@@ -152,7 +161,7 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
} }
}) })
} else if p.is_infinite() && p.is_sign_negative() { } else if p.is_infinite() && p.is_sign_negative() {
self.iter().fold(std::f64::INFINITY, |f, &val| { self.iter().fold(T::infinity(), |f, &val| {
let v = val.abs(); let v = val.abs();
if f < v { if f < v {
f f
@@ -162,38 +171,38 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
}) })
} else { } else {
let mut norm = 0f64; let mut norm = T::zero();
for xi in self.iter() { for xi in self.iter() {
norm += xi.abs().powf(p); norm = norm + xi.abs().powf(p);
} }
norm.powf(1.0/p) norm.powf(T::one()/p)
} }
} }
fn column_mean(&self) -> Vec<f64> { fn column_mean(&self) -> Vec<T> {
self.mean_axis(Axis(0)).unwrap().to_vec() self.mean_axis(Axis(0)).unwrap().to_vec()
} }
fn div_element_mut(&mut self, row: usize, col: usize, x: f64){ fn div_element_mut(&mut self, row: usize, col: usize, x: T){
self[[row, col]] /= x; self[[row, col]] = self[[row, col]] / x;
} }
fn mul_element_mut(&mut self, row: usize, col: usize, x: f64){ fn mul_element_mut(&mut self, row: usize, col: usize, x: T){
self[[row, col]] *= x; self[[row, col]] = self[[row, col]] * x;
} }
fn add_element_mut(&mut self, row: usize, col: usize, x: f64){ fn add_element_mut(&mut self, row: usize, col: usize, x: T){
self[[row, col]] += x; self[[row, col]] = self[[row, col]] + x;
} }
fn sub_element_mut(&mut self, row: usize, col: usize, x: f64){ fn sub_element_mut(&mut self, row: usize, col: usize, x: T){
self[[row, col]] -= x; self[[row, col]] = self[[row, col]] - x;
} }
fn negative_mut(&mut self){ fn negative_mut(&mut self){
*self *= -1.; *self *= -T::one();
} }
fn reshape(&self, nrows: usize, ncols: usize) -> Self{ fn reshape(&self, nrows: usize, ncols: usize) -> Self{
@@ -208,12 +217,12 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
self self
} }
fn sum(&self) -> f64{ fn sum(&self) -> T{
self.sum() self.sum()
} }
fn max_diff(&self, other: &Self) -> f64{ fn max_diff(&self, other: &Self) -> T{
let mut max_diff = 0f64; let mut max_diff = T::zero();
for r in 0..self.nrows() { for r in 0..self.nrows() {
for c in 0..self.ncols() { for c in 0..self.ncols() {
max_diff = max_diff.max((self[(r, c)] - other[(r, c)]).abs()); max_diff = max_diff.max((self[(r, c)] - other[(r, c)]).abs());
@@ -223,13 +232,13 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
} }
fn softmax_mut(&mut self){ fn softmax_mut(&mut self){
let max = self.iter().map(|x| x.abs()).fold(std::f64::NEG_INFINITY, |a, b| a.max(b)); let max = self.iter().map(|x| x.abs()).fold(T::neg_infinity(), |a, b| a.max(b));
let mut z = 0.; let mut z = T::zero();
for r in 0..self.nrows() { for r in 0..self.nrows() {
for c in 0..self.ncols() { for c in 0..self.ncols() {
let p = (self[(r, c)] - max).exp(); let p = (self[(r, c)] - max).exp();
self.set(r, c, p); self.set(r, c, p);
z += p; z = z + p;
} }
} }
for r in 0..self.nrows() { for r in 0..self.nrows() {
@@ -239,7 +248,7 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
} }
} }
fn pow_mut(&mut self, p: f64) -> &Self{ fn pow_mut(&mut self, p: T) -> &Self{
for r in 0..self.nrows() { for r in 0..self.nrows() {
for c in 0..self.ncols() { for c in 0..self.ncols() {
self.set(r, c, self[(r, c)].powf(p)); self.set(r, c, self[(r, c)].powf(p));
@@ -252,7 +261,7 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
let mut res = vec![0usize; self.nrows()]; let mut res = vec![0usize; self.nrows()];
for r in 0..self.nrows() { for r in 0..self.nrows() {
let mut max = std::f64::NEG_INFINITY; let mut max = T::neg_infinity();
let mut max_pos = 0usize; let mut max_pos = 0usize;
for c in 0..self.ncols() { for c in 0..self.ncols() {
let v = self[(r, c)]; let v = self[(r, c)];
@@ -268,7 +277,7 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
} }
fn unique(&self) -> Vec<f64> { fn unique(&self) -> Vec<T> {
let mut result = self.clone().into_raw_vec(); let mut result = self.clone().into_raw_vec();
result.sort_by(|a, b| a.partial_cmp(b).unwrap()); result.sort_by(|a, b| a.partial_cmp(b).unwrap());
result.dedup(); result.dedup();
@@ -277,13 +286,13 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
} }
impl SVDDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {} impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> SVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl EVDDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {} impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> EVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl QRDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {} impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> QRDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2> {} impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@@ -541,7 +550,7 @@ mod tests {
#[test] #[test]
fn softmax_mut(){ fn softmax_mut(){
let mut prob = arr2(&[[1., 2., 3.]]); let mut prob: Array2<f64> = arr2(&[[1., 2., 3.]]);
prob.softmax_mut(); prob.softmax_mut();
assert!((BaseMatrix::get(&prob, 0, 0) - 0.09).abs() < 0.01); assert!((BaseMatrix::get(&prob, 0, 0) - 0.09).abs() < 0.01);
assert!((BaseMatrix::get(&prob, 0, 1) - 0.24).abs() < 0.01); assert!((BaseMatrix::get(&prob, 0, 1) - 0.24).abs() < 0.01);
+24 -21
View File
@@ -1,20 +1,23 @@
#![allow(non_snake_case)] #![allow(non_snake_case)]
use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct QR<M: BaseMatrix> { pub struct QR<T: FloatExt + Debug, M: BaseMatrix<T>> {
QR: M, QR: M,
tau: Vec<f64>, tau: Vec<T>,
singular: bool singular: bool
} }
impl<M: BaseMatrix> QR<M> { impl<T: FloatExt + Debug, M: BaseMatrix<T>> QR<T, M> {
pub fn new(QR: M, tau: Vec<f64>) -> QR<M> { pub fn new(QR: M, tau: Vec<T>) -> QR<T, M> {
let mut singular = false; let mut singular = false;
for j in 0..tau.len() { for j in 0..tau.len() {
if tau[j] == 0. { if tau[j] == T::zero() {
singular = true; singular = true;
break; break;
} }
@@ -44,12 +47,12 @@ impl<M: BaseMatrix> QR<M> {
let mut Q = M::zeros(m, n); let mut Q = M::zeros(m, n);
let mut k = n - 1; let mut k = n - 1;
loop { loop {
Q.set(k, k, 1.0); Q.set(k, k, T::one());
for j in k..n { for j in k..n {
if self.QR.get(k, k) != 0f64 { if self.QR.get(k, k) != T::zero() {
let mut s = 0f64; let mut s = T::zero();
for i in k..m { for i in k..m {
s += self.QR.get(i, k) * Q.get(i, j); s = s + self.QR.get(i, k) * Q.get(i, j);
} }
s = -s / self.QR.get(k, k); s = -s / self.QR.get(k, k);
for i in k..m { for i in k..m {
@@ -81,9 +84,9 @@ impl<M: BaseMatrix> QR<M> {
for k in 0..n { for k in 0..n {
for j in 0..b_ncols { for j in 0..b_ncols {
let mut s = 0f64; let mut s = T::zero();
for i in k..m { for i in k..m {
s += self.QR.get(i, k) * b.get(i, j); s = s + self.QR.get(i, k) * b.get(i, j);
} }
s = -s / self.QR.get(k, k); s = -s / self.QR.get(k, k);
for i in k..m { for i in k..m {
@@ -109,38 +112,38 @@ impl<M: BaseMatrix> QR<M> {
} }
} }
pub trait QRDecomposableMatrix: BaseMatrix { pub trait QRDecomposableMatrix<T: FloatExt + Debug>: BaseMatrix<T> {
fn qr(&self) -> QR<Self> { fn qr(&self) -> QR<T, Self> {
self.clone().qr_mut() self.clone().qr_mut()
} }
fn qr_mut(mut self) -> QR<Self> { fn qr_mut(mut self) -> QR<T, Self> {
let (m, n) = self.shape(); let (m, n) = self.shape();
let mut r_diagonal: Vec<f64> = vec![0f64; n]; let mut r_diagonal: Vec<T> = vec![T::zero(); n];
for k in 0..n { for k in 0..n {
let mut nrm = 0f64; let mut nrm = T::zero();
for i in k..m { for i in k..m {
nrm = nrm.hypot(self.get(i, k)); nrm = nrm.hypot(self.get(i, k));
} }
if nrm.abs() > std::f64::EPSILON { if nrm.abs() > T::epsilon() {
if self.get(k, k) < 0f64 { if self.get(k, k) < T::zero() {
nrm = -nrm; nrm = -nrm;
} }
for i in k..m { for i in k..m {
self.div_element_mut(i, k, nrm); self.div_element_mut(i, k, nrm);
} }
self.add_element_mut(k, k, 1f64); self.add_element_mut(k, k, T::one());
for j in k+1..n { for j in k+1..n {
let mut s = 0f64; let mut s = T::zero();
for i in k..m { for i in k..m {
s += self.get(i, k) * self.get(i, j); s = s + self.get(i, k) * self.get(i, j);
} }
s = -s / self.get(k, k); s = -s / self.get(k, k);
for i in k..m { for i in k..m {
+72 -70
View File
@@ -1,19 +1,21 @@
#![allow(non_snake_case)] #![allow(non_snake_case)]
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
use crate::math::num::FloatExt;
use std::fmt::Debug;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct SVD<M: SVDDecomposableMatrix> { pub struct SVD<T: FloatExt + Debug, M: SVDDecomposableMatrix<T>> {
pub U: M, pub U: M,
pub V: M, pub V: M,
pub s: Vec<f64>, pub s: Vec<T>,
full: bool, full: bool,
m: usize, m: usize,
n: usize, n: usize,
tol: f64 tol: T
} }
pub trait SVDDecomposableMatrix: BaseMatrix { pub trait SVDDecomposableMatrix<T: FloatExt + Debug>: BaseMatrix<T> {
fn svd_solve_mut(self, b: Self) -> Self { fn svd_solve_mut(self, b: Self) -> Self {
self.svd_mut().solve(b) self.svd_mut().solve(b)
@@ -23,50 +25,50 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
self.svd().solve(b) self.svd().solve(b)
} }
fn svd(&self) -> SVD<Self> { fn svd(&self) -> SVD<T, Self> {
self.clone().svd_mut() self.clone().svd_mut()
} }
fn svd_mut(self) -> SVD<Self> { fn svd_mut(self) -> SVD<T, Self> {
let mut U = self; let mut U = self;
let (m, n) = U.shape(); let (m, n) = U.shape();
let (mut l, mut nm) = (0usize, 0usize); let (mut l, mut nm) = (0usize, 0usize);
let (mut anorm, mut g, mut scale) = (0f64, 0f64, 0f64); let (mut anorm, mut g, mut scale) = (T::zero(), T::zero(), T::zero());
let mut v = Self::zeros(n, n); let mut v = Self::zeros(n, n);
let mut w = vec![0f64; n]; let mut w = vec![T::zero(); n];
let mut rv1 = vec![0f64; n]; let mut rv1 = vec![T::zero(); n];
for i in 0..n { for i in 0..n {
l = i + 2; l = i + 2;
rv1[i] = scale * g; rv1[i] = scale * g;
g = 0f64; g = T::zero();
let mut s = 0f64; let mut s = T::zero();
scale = 0f64; scale = T::zero();
if i < m { if i < m {
for k in i..m { for k in i..m {
scale += U.get(k, i).abs(); scale = scale + U.get(k, i).abs();
} }
if scale.abs() > std::f64::EPSILON { if scale.abs() > T::epsilon() {
for k in i..m { for k in i..m {
U.div_element_mut(k, i, scale); U.div_element_mut(k, i, scale);
s += U.get(k, i) * U.get(k, i); s = s + U.get(k, i) * U.get(k, i);
} }
let mut f = U.get(i, i); let mut f = U.get(i, i);
g = -s.sqrt().copysign(f); g = -s.sqrt().copysign(f);
let h = f * g - s; let h = f * g - s;
U.set(i, i, f - g); U.set(i, i, f - g);
for j in l - 1..n { for j in l - 1..n {
s = 0f64; s = T::zero();
for k in i..m { for k in i..m {
s += U.get(k, i) * U.get(k, j); s = s + U.get(k, i) * U.get(k, j);
} }
f = s / h; f = s / h;
for k in i..m { for k in i..m {
@@ -80,19 +82,19 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
} }
w[i] = scale * g; w[i] = scale * g;
g = 0f64; g = T::zero();
let mut s = 0f64; let mut s = T::zero();
scale = 0f64; scale = T::zero();
if i + 1 <= m && i + 1 != n { if i + 1 <= m && i + 1 != n {
for k in l - 1..n { for k in l - 1..n {
scale += U.get(i, k).abs(); scale = scale + U.get(i, k).abs();
} }
if scale.abs() > std::f64::EPSILON { if scale.abs() > T::epsilon() {
for k in l - 1..n { for k in l - 1..n {
U.div_element_mut(i, k, scale); U.div_element_mut(i, k, scale);
s += U.get(i, k) * U.get(i, k); s = s + U.get(i, k) * U.get(i, k);
} }
let f = U.get(i, l - 1); let f = U.get(i, l - 1);
@@ -105,9 +107,9 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
} }
for j in l - 1..m { for j in l - 1..m {
s = 0f64; s = T::zero();
for k in l - 1..n { for k in l - 1..n {
s += U.get(j, k) * U.get(i, k); s = s + U.get(j, k) * U.get(i, k);
} }
for k in l - 1..n { for k in l - 1..n {
@@ -122,19 +124,19 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
} }
anorm = f64::max(anorm, w[i].abs() + rv1[i].abs()); anorm = T::max(anorm, w[i].abs() + rv1[i].abs());
} }
for i in (0..n).rev() { for i in (0..n).rev() {
if i < n - 1 { if i < n - 1 {
if g != 0.0 { if g != T::zero() {
for j in l..n { for j in l..n {
v.set(j, i, (U.get(i, j) / U.get(i, l)) / g); v.set(j, i, (U.get(i, j) / U.get(i, l)) / g);
} }
for j in l..n { for j in l..n {
let mut s = 0f64; let mut s = T::zero();
for k in l..n { for k in l..n {
s += U.get(i, k) * v.get(k, j); s = s + U.get(i, k) * v.get(k, j);
} }
for k in l..n { for k in l..n {
v.add_element_mut(k, j, s * v.get(k, i)); v.add_element_mut(k, j, s * v.get(k, i));
@@ -142,11 +144,11 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
} }
} }
for j in l..n { for j in l..n {
v.set(i, j, 0f64); v.set(i, j, T::zero());
v.set(j, i, 0f64); v.set(j, i, T::zero());
} }
} }
v.set(i, i, 1.0); v.set(i, i, T::one());
g = rv1[i]; g = rv1[i];
l = i; l = i;
} }
@@ -155,15 +157,15 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
l = i + 1; l = i + 1;
g = w[i]; g = w[i];
for j in l..n { for j in l..n {
U.set(i, j, 0f64); U.set(i, j, T::zero());
} }
if g.abs() > std::f64::EPSILON { if g.abs() > T::epsilon() {
g = 1f64 / g; g = T::one() / g;
for j in l..n { for j in l..n {
let mut s = 0f64; let mut s = T::zero();
for k in l..m { for k in l..m {
s += U.get(k, i) * U.get(k, j); s = s + U.get(k, i) * U.get(k, j);
} }
let f = (s / U.get(i, i)) * g; let f = (s / U.get(i, i)) * g;
for k in i..m { for k in i..m {
@@ -175,11 +177,11 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
} }
} else { } else {
for j in i..m { for j in i..m {
U.set(j, i, 0f64); U.set(j, i, T::zero());
} }
} }
U.add_element_mut(i, i, 1f64); U.add_element_mut(i, i, T::one());
} }
for k in (0..n).rev() { for k in (0..n).rev() {
@@ -187,30 +189,30 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
let mut flag = true; let mut flag = true;
l = k; l = k;
while l != 0 { while l != 0 {
if l == 0 || rv1[l].abs() <= std::f64::EPSILON * anorm { if l == 0 || rv1[l].abs() <= T::epsilon() * anorm {
flag = false; flag = false;
break; break;
} }
nm = l - 1; nm = l - 1;
if w[nm].abs() <= std::f64::EPSILON * anorm { if w[nm].abs() <= T::epsilon() * anorm {
break; break;
} }
l -= 1; l -= 1;
} }
if flag { if flag {
let mut c = 0.0; let mut c = T::zero();
let mut s = 1.0; let mut s = T::one();
for i in l..k+1 { for i in l..k+1 {
let f = s * rv1[i]; let f = s * rv1[i];
rv1[i] = c * rv1[i]; rv1[i] = c * rv1[i];
if f.abs() <= std::f64::EPSILON * anorm { if f.abs() <= T::epsilon() * anorm {
break; break;
} }
g = w[i]; g = w[i];
let mut h = f.hypot(g); let mut h = f.hypot(g);
w[i] = h; w[i] = h;
h = 1.0 / h; h = T::one() / h;
c = g * h; c = g * h;
s = -f * h; s = -f * h;
for j in 0..m { for j in 0..m {
@@ -224,7 +226,7 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
let z = w[k]; let z = w[k];
if l == k { if l == k {
if z < 0f64 { if z < T::zero() {
w[k] = -z; w[k] = -z;
for j in 0..n { for j in 0..n {
v.set(j, k, -v.get(j, k)); v.set(j, k, -v.get(j, k));
@@ -242,11 +244,11 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
let mut y = w[nm]; let mut y = w[nm];
g = rv1[nm]; g = rv1[nm];
let mut h = rv1[k]; let mut h = rv1[k];
let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2.0 * h * y); let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (T::two() * h * y);
g = f.hypot(1.0); g = f.hypot(T::one());
f = ((x - z) * (x + z) + h * ((y / (f + g.copysign(f))) - h)) / x; f = ((x - z) * (x + z) + h * ((y / (f + g.copysign(f))) - h)) / x;
let mut c = 1f64; let mut c = T::one();
let mut s = 1f64; let mut s = T::one();
for j in l..=nm { for j in l..=nm {
let i = j + 1; let i = j + 1;
@@ -261,7 +263,7 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
f = x * c + g * s; f = x * c + g * s;
g = g * c - x * s; g = g * c - x * s;
h = y * s; h = y * s;
y *= c; y = y * c;
for jj in 0..n { for jj in 0..n {
x = v.get(jj, j); x = v.get(jj, j);
@@ -272,8 +274,8 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
z = f.hypot(h); z = f.hypot(h);
w[j] = z; w[j] = z;
if z.abs() > std::f64::EPSILON { if z.abs() > T::epsilon() {
z = 1.0 / z; z = T::one() / z;
c = f * z; c = f * z;
s = h * z; s = h * z;
} }
@@ -288,15 +290,15 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
} }
} }
rv1[l] = 0.0; rv1[l] = T::zero();
rv1[k] = f; rv1[k] = f;
w[k] = x; w[k] = x;
} }
} }
let mut inc = 1usize; let mut inc = 1usize;
let mut su = vec![0f64; m]; let mut su = vec![T::zero(); m];
let mut sv = vec![0f64; n]; let mut sv = vec![T::zero(); n];
loop { loop {
inc *= 3; inc *= 3;
@@ -347,12 +349,12 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
for k in 0..n { for k in 0..n {
let mut s = 0.; let mut s = 0.;
for i in 0..m { for i in 0..m {
if U.get(i, k) < 0. { if U.get(i, k) < T::zero() {
s += 1.; s += 1.;
} }
} }
for j in 0..n { for j in 0..n {
if v.get(j, k) < 0. { if v.get(j, k) < T::zero() {
s += 1.; s += 1.;
} }
} }
@@ -371,12 +373,12 @@ pub trait SVDDecomposableMatrix: BaseMatrix {
} }
} }
impl<M: SVDDecomposableMatrix> SVD<M> { impl<T: FloatExt + Debug, M: SVDDecomposableMatrix<T>> SVD<T, M> {
pub fn new(U: M, V: M, s: Vec<f64>) -> SVD<M> { pub fn new(U: M, V: M, s: Vec<T>) -> SVD<T, M> {
let m = U.shape().0; let m = U.shape().0;
let n = V.shape().0; let n = V.shape().0;
let full = s.len() == m.min(n); let full = s.len() == m.min(n);
let tol = 0.5 * ((m + n) as f64 + 1.).sqrt() * s[0] * std::f64::EPSILON; let tol = T::half() * (T::from(m + n).unwrap() + T::one()).sqrt() * s[0] * T::epsilon();
SVD { SVD {
U: U, U: U,
V: V, V: V,
@@ -396,22 +398,22 @@ impl<M: SVDDecomposableMatrix> SVD<M> {
} }
for k in 0..p { for k in 0..p {
let mut tmp = vec![0f64; self.n]; let mut tmp = vec![T::zero(); self.n];
for j in 0..self.n { for j in 0..self.n {
let mut r = 0f64; let mut r = T::zero();
if self.s[j] > self.tol { if self.s[j] > self.tol {
for i in 0..self.m { for i in 0..self.m {
r += self.U.get(i, j) * b.get(i, k); r = r + self.U.get(i, j) * b.get(i, k);
} }
r /= self.s[j]; r = r / self.s[j];
} }
tmp[j] = r; tmp[j] = r;
} }
for j in 0..self.n { for j in 0..self.n {
let mut r = 0.0; let mut r = T::zero();
for jj in 0..self.n { for jj in 0..self.n {
r += self.V.get(j, jj) * tmp[jj]; r = r + self.V.get(j, jj) * tmp[jj];
} }
b.set(j, k, r); b.set(j, k, r);
} }
@@ -434,7 +436,7 @@ mod tests {
&[0.4000, 0.5000, 0.3000], &[0.4000, 0.5000, 0.3000],
&[0.7000, 0.3000, 0.8000]]); &[0.7000, 0.3000, 0.8000]]);
let s = vec![1.7498382, 0.3165784, 0.1335834]; let s: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
let U = DenseMatrix::from_array(&[ let U = DenseMatrix::from_array(&[
&[0.6881997, -0.07121225, 0.7220180], &[0.6881997, -0.07121225, 0.7220180],
@@ -471,7 +473,7 @@ mod tests {
&[0.12641406, -0.8710055, -0.2712301, 0.2296515, 1.1781535, -0.2158704, -0.27529472] &[0.12641406, -0.8710055, -0.2712301, 0.2296515, 1.1781535, -0.2158704, -0.27529472]
]); ]);
let s = vec![3.8589375, 3.4396766, 2.6487176, 2.2317399, 1.5165054, 0.8109055, 0.2706515]; let s: Vec<f64> = vec![3.8589375, 3.4396766, 2.6487176, 2.2317399, 1.5165054, 0.8109055, 0.2706515];
let U = DenseMatrix::from_array(&[ let U = DenseMatrix::from_array(&[
&[-0.3082776, 0.77676231, 0.01330514, 0.23231424, -0.47682758, 0.13927109, 0.02640713], &[-0.3082776, 0.77676231, 0.01330514, 0.23231424, -0.47682758, 0.13927109, 0.02640713],
+7 -5
View File
@@ -1,6 +1,8 @@
use crate::linalg::Matrix;
use std::fmt::Debug; use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::Matrix;
#[derive(Debug)] #[derive(Debug)]
pub enum LinearRegressionSolver { pub enum LinearRegressionSolver {
QR, QR,
@@ -8,15 +10,15 @@ pub enum LinearRegressionSolver {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct LinearRegression<M: Matrix> { pub struct LinearRegression<T: FloatExt + Debug, M: Matrix<T>> {
coefficients: M, coefficients: M,
intercept: f64, intercept: T,
solver: LinearRegressionSolver solver: LinearRegressionSolver
} }
impl<M: Matrix> LinearRegression<M> { impl<T: FloatExt + Debug, M: Matrix<T>> LinearRegression<T, M> {
pub fn fit(x: &M, y: &M, solver: LinearRegressionSolver) -> LinearRegression<M>{ pub fn fit(x: &M, y: &M, solver: LinearRegressionSolver) -> LinearRegression<T, M>{
let b = y.transpose(); let b = y.transpose();
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
+46 -37
View File
@@ -1,4 +1,7 @@
use crate::math::NumericExt; use std::fmt::Debug;
use std::marker::PhantomData;
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::optimization::FunctionOrder; use crate::optimization::FunctionOrder;
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult}; use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
@@ -6,42 +9,43 @@ use crate::optimization::line_search::Backtracking;
use crate::optimization::first_order::lbfgs::LBFGS; use crate::optimization::first_order::lbfgs::LBFGS;
#[derive(Debug)] #[derive(Debug)]
pub struct LogisticRegression<M: Matrix> { pub struct LogisticRegression<T: FloatExt + Debug, M: Matrix<T>> {
weights: M, weights: M,
classes: Vec<f64>, classes: Vec<T>,
num_attributes: usize, num_attributes: usize,
num_classes: usize num_classes: usize
} }
trait ObjectiveFunction<M: Matrix> { trait ObjectiveFunction<T: FloatExt + Debug, M: Matrix<T>> {
fn f(&self, w_bias: &M) -> f64; fn f(&self, w_bias: &M) -> T;
fn df(&self, g: &mut M, w_bias: &M); fn df(&self, g: &mut M, w_bias: &M);
fn partial_dot(w: &M, x: &M, v_col: usize, m_row: usize) -> f64 { fn partial_dot(w: &M, x: &M, v_col: usize, m_row: usize) -> T {
let mut sum = 0f64; let mut sum = T::zero();
let p = x.shape().1; let p = x.shape().1;
for i in 0..p { for i in 0..p {
sum += x.get(m_row, i) * w.get(0, i + v_col); sum = sum + x.get(m_row, i) * w.get(0, i + v_col);
} }
sum + w.get(0, p + v_col) sum + w.get(0, p + v_col)
} }
} }
struct BinaryObjectiveFunction<'a, M: Matrix> { struct BinaryObjectiveFunction<'a, T: FloatExt + Debug, M: Matrix<T>> {
x: &'a M, x: &'a M,
y: Vec<usize> y: Vec<usize>,
phantom: PhantomData<&'a T>
} }
impl<'a, M: Matrix> ObjectiveFunction<M> for BinaryObjectiveFunction<'a, M> { impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveFunction<'a, T, M> {
fn f(&self, w_bias: &M) -> f64 { fn f(&self, w_bias: &M) -> T {
let mut f = 0.; let mut f = T::zero();
let (n, _) = self.x.shape(); let (n, _) = self.x.shape();
for i in 0..n { for i in 0..n {
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i); let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
f += wx.ln_1pe() - (self.y[i] as f64) * wx; f = f + (wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx);
} }
f f
@@ -57,7 +61,7 @@ impl<'a, M: Matrix> ObjectiveFunction<M> for BinaryObjectiveFunction<'a, M> {
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i); let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
let dyi = (self.y[i] as f64) - wx.sigmoid(); let dyi = (T::from(self.y[i]).unwrap()) - wx.sigmoid();
for j in 0..p { for j in 0..p {
g.set(0, j, g.get(0, j) - dyi * self.x.get(i, j)); g.set(0, j, g.get(0, j) - dyi * self.x.get(i, j));
} }
@@ -68,16 +72,17 @@ impl<'a, M: Matrix> ObjectiveFunction<M> for BinaryObjectiveFunction<'a, M> {
} }
struct MultiClassObjectiveFunction<'a, M: Matrix> { struct MultiClassObjectiveFunction<'a, T: FloatExt + Debug, M: Matrix<T>> {
x: &'a M, x: &'a M,
y: Vec<usize>, y: Vec<usize>,
k: usize k: usize,
phantom: PhantomData<&'a T>
} }
impl<'a, M: Matrix> ObjectiveFunction<M> for MultiClassObjectiveFunction<'a, M> { impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for MultiClassObjectiveFunction<'a, T, M> {
fn f(&self, w_bias: &M) -> f64 { fn f(&self, w_bias: &M) -> T {
let mut f = 0.; let mut f = T::zero();
let mut prob = M::zeros(1, self.k); let mut prob = M::zeros(1, self.k);
let (n, p) = self.x.shape(); let (n, p) = self.x.shape();
for i in 0..n { for i in 0..n {
@@ -85,7 +90,7 @@ impl<'a, M: Matrix> ObjectiveFunction<M> for MultiClassObjectiveFunction<'a, M>
prob.set(0, j, MultiClassObjectiveFunction::partial_dot(w_bias, self.x, j * (p + 1), i)); prob.set(0, j, MultiClassObjectiveFunction::partial_dot(w_bias, self.x, j * (p + 1), i));
} }
prob.softmax_mut(); prob.softmax_mut();
f -= prob.get(0, self.y[i]).ln(); f = f - prob.get(0, self.y[i]).ln();
} }
f f
@@ -106,7 +111,7 @@ impl<'a, M: Matrix> ObjectiveFunction<M> for MultiClassObjectiveFunction<'a, M>
prob.softmax_mut(); prob.softmax_mut();
for j in 0..self.k { for j in 0..self.k {
let yi =(if self.y[i] == j { 1.0 } else { 0.0 }) - prob.get(0, j); let yi =(if self.y[i] == j { T::one() } else { T::zero() }) - prob.get(0, j);
for l in 0..p { for l in 0..p {
let pos = j * (p + 1); let pos = j * (p + 1);
@@ -120,9 +125,9 @@ impl<'a, M: Matrix> ObjectiveFunction<M> for MultiClassObjectiveFunction<'a, M>
} }
impl<M: Matrix> LogisticRegression<M> { impl<T: FloatExt + Debug, M: Matrix<T>> LogisticRegression<T, M> {
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<M>{ pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M>{
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
@@ -153,7 +158,8 @@ impl<M: Matrix> LogisticRegression<M> {
let objective = BinaryObjectiveFunction{ let objective = BinaryObjectiveFunction{
x: x, x: x,
y: yi y: yi,
phantom: PhantomData
}; };
let result = LogisticRegression::minimize(x0, objective); let result = LogisticRegression::minimize(x0, objective);
@@ -172,7 +178,8 @@ impl<M: Matrix> LogisticRegression<M> {
let objective = MultiClassObjectiveFunction{ let objective = MultiClassObjectiveFunction{
x: x, x: x,
y: yi, y: yi,
k: k k: k,
phantom: PhantomData
}; };
let result = LogisticRegression::minimize(x0, objective); let result = LogisticRegression::minimize(x0, objective);
@@ -196,9 +203,9 @@ impl<M: Matrix> LogisticRegression<M> {
if self.num_classes == 2 { if self.num_classes == 2 {
let (nrows, _) = x.shape(); let (nrows, _) = x.shape();
let x_and_bias = x.v_stack(&M::ones(nrows, 1)); let x_and_bias = x.v_stack(&M::ones(nrows, 1));
let y_hat: Vec<f64> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector(); let y_hat: Vec<T> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
for i in 0..n { for i in 0..n {
result.set(0, i, self.classes[if y_hat[i].sigmoid() > 0.5 { 1 } else { 0 }]); result.set(0, i, self.classes[if y_hat[i].sigmoid() > T::half() { 1 } else { 0 }]);
} }
} else { } else {
@@ -221,8 +228,8 @@ impl<M: Matrix> LogisticRegression<M> {
self.weights.slice(0..self.num_classes, self.num_attributes..self.num_attributes+1) self.weights.slice(0..self.num_classes, self.num_attributes..self.num_attributes+1)
} }
fn minimize(x0: M, objective: impl ObjectiveFunction<M>) -> OptimizerResult<M> { fn minimize(x0: M, objective: impl ObjectiveFunction<T, M>) -> OptimizerResult<T, M> {
let f = |w: &M| -> f64 { let f = |w: &M| -> T {
objective.f(w) objective.f(w)
}; };
@@ -230,9 +237,9 @@ impl<M: Matrix> LogisticRegression<M> {
objective.df(g, w) objective.df(g, w)
}; };
let mut ls: Backtracking = Default::default(); let mut ls: Backtracking<T> = Default::default();
ls.order = FunctionOrder::THIRD; ls.order = FunctionOrder::THIRD;
let optimizer: LBFGS = Default::default(); let optimizer: LBFGS<T> = Default::default();
optimizer.optimize(&f, &df, &x0, &ls) optimizer.optimize(&f, &df, &x0, &ls)
} }
@@ -270,10 +277,11 @@ mod tests {
let objective = MultiClassObjectiveFunction{ let objective = MultiClassObjectiveFunction{
x: &x, x: &x,
y: y, y: y,
k: 3 k: 3,
phantom: PhantomData
}; };
let mut g = DenseMatrix::zeros(1, 9); let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 9);
objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.])); objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]));
objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.])); objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]));
@@ -309,10 +317,11 @@ mod tests {
let objective = BinaryObjectiveFunction{ let objective = BinaryObjectiveFunction{
x: &x, x: &x,
y: y y: y,
phantom: PhantomData
}; };
let mut g = DenseMatrix::zeros(1, 3); let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 3);
objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3.])); objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3.]));
objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3.])); objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3.]));
@@ -345,7 +354,7 @@ mod tests {
&[10., -2.], &[10., -2.],
&[ 8., 2.], &[ 8., 2.],
&[ 9., 0.]]); &[ 9., 0.]]);
let y = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.]; let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
let lr = LogisticRegression::fit(&x, &y); let lr = LogisticRegression::fit(&x, &y);
+7 -5
View File
@@ -1,15 +1,17 @@
pub fn distance(x: &Vec<f64>, y: &Vec<f64>) -> f64 { use crate::math::num::FloatExt;
pub fn distance<T: FloatExt>(x: &Vec<T>, y: &Vec<T>) -> T {
return squared_distance(x, y).sqrt(); return squared_distance(x, y).sqrt();
} }
pub fn squared_distance(x: &Vec<f64>,y: &Vec<f64>) -> f64 { pub fn squared_distance<T: FloatExt>(x: &Vec<T>,y: &Vec<T>) -> T {
if x.len() != y.len() { if x.len() != y.len() {
panic!("Input vector sizes are different."); panic!("Input vector sizes are different.");
} }
let mut sum = 0f64; let mut sum = T::zero();
for i in 0..x.len() { for i in 0..x.len() {
sum += (x[i] - y[i]).powf(2.); sum = sum + (x[i] - y[i]).powf(T::two());
} }
return sum; return sum;
@@ -25,7 +27,7 @@ mod tests {
let a = vec![1., 2., 3.]; let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.]; let b = vec![4., 5., 6.];
let d_arr = distance(&a, &b); let d_arr: f64 = distance(&a, &b);
assert!((d_arr - 5.19615242).abs() < 1e-8); assert!((d_arr - 5.19615242).abs() < 1e-8);
} }
+1 -43
View File
@@ -1,44 +1,2 @@
pub mod distance; pub mod distance;
pub mod num;
pub static EPSILON:f64 = 2.2204460492503131e-16_f64;
pub trait NumericExt {
fn ln_1pe(&self) -> f64;
fn sigmoid(&self) -> f64;
}
impl NumericExt for f64 {
fn ln_1pe(&self) -> f64{
if *self > 15. {
return *self;
} else {
return self.exp().ln_1p();
}
}
fn sigmoid(&self) -> f64 {
if *self < -40. {
return 0.;
} else if *self > 40. {
return 1.;
} else {
return 1. / (1. + f64::exp(-self))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn sigmoid() {
assert_eq!(1.0.sigmoid(), 0.7310585786300049);
assert_eq!(41.0.sigmoid(), 1.);
assert_eq!((-41.0).sigmoid(), 0.);
}
}
+110
View File
@@ -0,0 +1,110 @@
use num_traits::{Float, FromPrimitive};
use rand::prelude::*;
pub trait FloatExt: Float + FromPrimitive {
fn copysign(self, sign: Self) -> Self;
fn ln_1pe(self) -> Self;
fn sigmoid(self) -> Self;
fn rand() -> Self;
fn two() -> Self;
fn half() -> Self;
}
impl FloatExt for f64 {
fn copysign(self, sign: Self) -> Self{
self.copysign(sign)
}
fn ln_1pe(self) -> f64{
if self > 15. {
return self;
} else {
return self.exp().ln_1p();
}
}
fn sigmoid(self) -> f64 {
if self < -40. {
return 0.;
} else if self > 40. {
return 1.;
} else {
return 1. / (1. + f64::exp(-self))
}
}
fn rand() -> f64 {
let mut rng = rand::thread_rng();
rng.gen()
}
fn two() -> Self {
2f64
}
fn half() -> Self {
0.5f64
}
}
impl FloatExt for f32 {
fn copysign(self, sign: Self) -> Self{
self.copysign(sign)
}
fn ln_1pe(self) -> f32{
if self > 15. {
return self;
} else {
return self.exp().ln_1p();
}
}
fn sigmoid(self) -> f32 {
if self < -40. {
return 0.;
} else if self > 40. {
return 1.;
} else {
return 1. / (1. + f32::exp(-self))
}
}
fn rand() -> f32 {
let mut rng = rand::thread_rng();
rng.gen()
}
fn two() -> Self {
2f32
}
fn half() -> Self {
0.5f32
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn sigmoid() {
assert_eq!(1.0.sigmoid(), 0.7310585786300049);
assert_eq!(41.0.sigmoid(), 1.);
assert_eq!((-41.0).sigmoid(), 0.);
}
}
+13 -13
View File
@@ -1,21 +1,21 @@
use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::{Matrix, row_iter}; use crate::linalg::{Matrix, row_iter};
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName}; use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
use crate::algorithm::neighbour::linear_search::LinearKNNSearch; use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
use crate::algorithm::neighbour::cover_tree::CoverTree; use crate::algorithm::neighbour::cover_tree::CoverTree;
pub struct KNNClassifier<'a, T: FloatExt> {
type F = dyn Fn(&Vec<f64>, &Vec<f64>) -> f64; classes: Vec<T>,
pub struct KNNClassifier<'a> {
classes: Vec<f64>,
y: Vec<usize>, y: Vec<usize>,
knn_algorithm: Box<dyn KNNAlgorithm<Vec<f64>> + 'a>, knn_algorithm: Box<dyn KNNAlgorithm<Vec<T>> + 'a>,
k: usize, k: usize,
} }
impl<'a> KNNClassifier<'a> { impl<'a, T: FloatExt + Debug> KNNClassifier<'a, T> {
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, k: usize, distance: &'a F, algorithm: KNNAlgorithmName) -> KNNClassifier<'a> { pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, k: usize, distance: &'a dyn Fn(&Vec<T>, &Vec<T>) -> T, algorithm: KNNAlgorithmName) -> KNNClassifier<'a, T> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
@@ -36,16 +36,16 @@ impl<'a> KNNClassifier<'a> {
assert!(k > 1, format!("k should be > 1, k=[{}]", k)); assert!(k > 1, format!("k should be > 1, k=[{}]", k));
let knn_algorithm: Box<dyn KNNAlgorithm<Vec<f64>> + 'a> = match algorithm { let knn_algorithm: Box<dyn KNNAlgorithm<Vec<T>> + 'a> = match algorithm {
KNNAlgorithmName::CoverTree => Box::new(CoverTree::<Vec<f64>>::new(data, distance)), KNNAlgorithmName::CoverTree => Box::new(CoverTree::<Vec<T>, T>::new(data, distance)),
KNNAlgorithmName::LinearSearch => Box::new(LinearKNNSearch::<Vec<f64>>::new(data, distance)) KNNAlgorithmName::LinearSearch => Box::new(LinearKNNSearch::<Vec<T>, T>::new(data, distance))
}; };
KNNClassifier{classes:classes, y: yi, k: k, knn_algorithm: knn_algorithm} KNNClassifier{classes:classes, y: yi, k: k, knn_algorithm: knn_algorithm}
} }
pub fn predict<M: Matrix>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let mut result = M::zeros(1, x.shape().0); let mut result = M::zeros(1, x.shape().0);
row_iter(x).enumerate().for_each(|(i, x)| result.set(0, i, self.classes[self.predict_for_row(x)])); row_iter(x).enumerate().for_each(|(i, x)| result.set(0, i, self.classes[self.predict_for_row(x)]));
@@ -53,7 +53,7 @@ impl<'a> KNNClassifier<'a> {
result.to_row_vector() result.to_row_vector()
} }
fn predict_for_row(&self, x: Vec<f64>) -> usize { fn predict_for_row(&self, x: Vec<T>) -> usize {
let idxs = self.knn_algorithm.find(&x, self.k); let idxs = self.knn_algorithm.find(&x, self.k);
let mut c = vec![0; self.classes.len()]; let mut c = vec![0; self.classes.len()];
@@ -1,30 +1,32 @@
use std::default::Default; use std::default::Default;
use crate::math::EPSILON; use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::optimization::{F, DF}; use crate::optimization::{F, DF};
use crate::optimization::line_search::LineSearchMethod; use crate::optimization::line_search::LineSearchMethod;
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult}; use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
pub struct GradientDescent { pub struct GradientDescent<T: FloatExt> {
pub max_iter: usize, pub max_iter: usize,
pub g_rtol: f64, pub g_rtol: T,
pub g_atol: f64 pub g_atol: T
} }
impl Default for GradientDescent { impl<T: FloatExt> Default for GradientDescent<T> {
fn default() -> Self { fn default() -> Self {
GradientDescent { GradientDescent {
max_iter: 10000, max_iter: 10000,
g_rtol: EPSILON.sqrt(), g_rtol: T::epsilon().sqrt(),
g_atol: EPSILON g_atol: T::epsilon()
} }
} }
} }
impl FirstOrderOptimizer for GradientDescent impl<T: FloatExt + Debug> FirstOrderOptimizer<T> for GradientDescent<T>
{ {
fn optimize<'a, X: Matrix, LS: LineSearchMethod>(&self, f: &'a F<X>, df: &'a DF<X>, x0: &X, ls: &'a LS) -> OptimizerResult<X> { fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(&self, f: &'a F<T, X>, df: &'a DF<X>, x0: &X, ls: &'a LS) -> OptimizerResult<T, X> {
let mut x = x0.clone(); let mut x = x0.clone();
let mut fx = f(&x); let mut fx = f(&x);
@@ -35,7 +37,7 @@ impl FirstOrderOptimizer for GradientDescent
let gtol = (gvec.norm2() * self.g_rtol).max(self.g_atol); let gtol = (gvec.norm2() * self.g_rtol).max(self.g_atol);
let mut iter = 0; let mut iter = 0;
let mut alpha = 1.0; let mut alpha = T::one();
df(&mut gvec, &x); df(&mut gvec, &x);
while iter < self.max_iter && (iter == 0 || gnorm > gtol) { while iter < self.max_iter && (iter == 0 || gnorm > gtol) {
@@ -43,13 +45,13 @@ impl FirstOrderOptimizer for GradientDescent
let mut step = gvec.negative(); let mut step = gvec.negative();
let f_alpha = |alpha: f64| -> f64 { let f_alpha = |alpha: T| -> T {
let mut dx = step.clone(); let mut dx = step.clone();
dx.mul_scalar_mut(alpha); dx.mul_scalar_mut(alpha);
f(&dx.add_mut(&x)) // f(x) = f(x .+ gvec .* alpha) f(&dx.add_mut(&x)) // f(x) = f(x .+ gvec .* alpha)
}; };
let df_alpha = |alpha: f64| -> f64 { let df_alpha = |alpha: T| -> T {
let mut dx = step.clone(); let mut dx = step.clone();
let mut dg = gvec.clone(); let mut dg = gvec.clone();
dx.mul_scalar_mut(alpha); dx.mul_scalar_mut(alpha);
@@ -88,18 +90,18 @@ mod tests {
fn gradient_descent() { fn gradient_descent() {
let x0 = DenseMatrix::vector_from_array(&[-1., 1.]); let x0 = DenseMatrix::vector_from_array(&[-1., 1.]);
let f = |x: &DenseMatrix| { let f = |x: &DenseMatrix<f64>| {
(1.0 - x.get(0, 0)).powf(2.) + 100.0 * (x.get(0, 1) - x.get(0, 0).powf(2.)).powf(2.) (1.0 - x.get(0, 0)).powf(2.) + 100.0 * (x.get(0, 1) - x.get(0, 0).powf(2.)).powf(2.)
}; };
let df = |g: &mut DenseMatrix, x: &DenseMatrix| { let df = |g: &mut DenseMatrix<f64>, x: &DenseMatrix<f64>| {
g.set(0, 0, -2. * (1. - x.get(0, 0)) - 400. * (x.get(0, 1) - x.get(0, 0).powf(2.)) * x.get(0, 0)); g.set(0, 0, -2. * (1. - x.get(0, 0)) - 400. * (x.get(0, 1) - x.get(0, 0).powf(2.)) * x.get(0, 0));
g.set(0, 1, 200. * (x.get(0, 1) - x.get(0, 0).powf(2.))); g.set(0, 1, 200. * (x.get(0, 1) - x.get(0, 0).powf(2.)));
}; };
let mut ls: Backtracking = Default::default(); let mut ls: Backtracking<f64> = Default::default();
ls.order = FunctionOrder::THIRD; ls.order = FunctionOrder::THIRD;
let optimizer: GradientDescent = Default::default(); let optimizer: GradientDescent<f64> = Default::default();
let result = optimizer.optimize(&f, &df, &x0, &ls); let result = optimizer.optimize(&f, &df, &x0, &ls);
+51 -50
View File
@@ -1,41 +1,43 @@
use std::default::Default; use std::default::Default;
use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::optimization::{F, DF}; use crate::optimization::{F, DF};
use crate::optimization::line_search::LineSearchMethod; use crate::optimization::line_search::LineSearchMethod;
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult}; use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
use std::fmt::Debug;
pub struct LBFGS { pub struct LBFGS<T: FloatExt> {
pub max_iter: usize, pub max_iter: usize,
pub g_rtol: f64, pub g_rtol: T,
pub g_atol: f64, pub g_atol: T,
pub x_atol: f64, pub x_atol: T,
pub x_rtol: f64, pub x_rtol: T,
pub f_abstol: f64, pub f_abstol: T,
pub f_reltol: f64, pub f_reltol: T,
pub successive_f_tol: usize, pub successive_f_tol: usize,
pub m: usize pub m: usize
} }
impl Default for LBFGS { impl<T: FloatExt> Default for LBFGS<T> {
fn default() -> Self { fn default() -> Self {
LBFGS { LBFGS {
max_iter: 1000, max_iter: 1000,
g_rtol: 1e-8, g_rtol: T::from(1e-8).unwrap(),
g_atol: 1e-8, g_atol: T::from(1e-8).unwrap(),
x_atol: 0., x_atol: T::zero(),
x_rtol: 0., x_rtol: T::zero(),
f_abstol: 0., f_abstol: T::zero(),
f_reltol: 0., f_reltol: T::zero(),
successive_f_tol: 1, successive_f_tol: 1,
m: 10 m: 10
} }
} }
} }
impl LBFGS { impl<T: FloatExt + Debug> LBFGS<T> {
fn two_loops<X: Matrix>(&self, state: &mut LBFGSState<X>) { fn two_loops<X: Matrix<T>>(&self, state: &mut LBFGSState<T, X>) {
let lower = state.iteration.max(self.m) - self.m; let lower = state.iteration.max(self.m) - self.m;
let upper = state.iteration; let upper = state.iteration;
@@ -54,7 +56,7 @@ impl LBFGS {
let i = (upper - 1).rem_euclid(self.m); let i = (upper - 1).rem_euclid(self.m);
let dxi = &state.dx_history[i]; let dxi = &state.dx_history[i];
let dgi = &state.dg_history[i]; let dgi = &state.dg_history[i];
let scaling = dxi.vector_dot(dgi) / dgi.abs().pow_mut(2.).sum(); let scaling = dxi.vector_dot(dgi) / dgi.abs().pow_mut(T::two()).sum();
state.s.copy_from(&state.twoloop_q.mul_scalar(scaling)); state.s.copy_from(&state.twoloop_q.mul_scalar(scaling));
} else { } else {
state.s.copy_from(&state.twoloop_q); state.s.copy_from(&state.twoloop_q);
@@ -68,34 +70,34 @@ impl LBFGS {
state.s.add_mut(&dxi.mul_scalar(state.twoloop_alpha[i] - beta)); state.s.add_mut(&dxi.mul_scalar(state.twoloop_alpha[i] - beta));
} }
state.s.mul_scalar_mut(-1.); state.s.mul_scalar_mut(-T::one());
} }
fn init_state<X: Matrix>(&self, x: &X) -> LBFGSState<X> { fn init_state<X: Matrix<T>>(&self, x: &X) -> LBFGSState<T, X> {
LBFGSState { LBFGSState {
x: x.clone(), x: x.clone(),
x_prev: x.clone(), x_prev: x.clone(),
x_f: std::f64::NAN, x_f: T::nan(),
x_f_prev: std::f64::NAN, x_f_prev: T::nan(),
x_df: x.clone(), x_df: x.clone(),
x_df_prev: x.clone(), x_df_prev: x.clone(),
rho: vec![0.; self.m], rho: vec![T::zero(); self.m],
dx_history: vec![x.clone(); self.m], dx_history: vec![x.clone(); self.m],
dg_history: vec![x.clone(); self.m], dg_history: vec![x.clone(); self.m],
dx: x.clone(), dx: x.clone(),
dg: x.clone(), dg: x.clone(),
twoloop_q: x.clone(), twoloop_q: x.clone(),
twoloop_alpha: vec![0.; self.m], twoloop_alpha: vec![T::zero(); self.m],
iteration: 0, iteration: 0,
counter_f_tol: 0, counter_f_tol: 0,
s: x.clone(), s: x.clone(),
alpha: 1.0 alpha: T::one()
} }
} }
fn update_state<'a, X: Matrix, LS: LineSearchMethod>(&self, f: &'a F<X>, df: &'a DF<X>, ls: &'a LS, state: &mut LBFGSState<X>) { fn update_state<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(&self, f: &'a F<T, X>, df: &'a DF<X>, ls: &'a LS, state: &mut LBFGSState<T, X>) {
self.two_loops(state); self.two_loops(state);
df(&mut state.x_df_prev, &state.x); df(&mut state.x_df_prev, &state.x);
@@ -104,13 +106,13 @@ impl LBFGS {
let df0 = state.x_df.vector_dot(&state.s); let df0 = state.x_df.vector_dot(&state.s);
let f_alpha = |alpha: f64| -> f64 { let f_alpha = |alpha: T| -> T {
let mut dx = state.s.clone(); let mut dx = state.s.clone();
dx.mul_scalar_mut(alpha); dx.mul_scalar_mut(alpha);
f(&dx.add_mut(&state.x)) // f(x) = f(x .+ gvec .* alpha) f(&dx.add_mut(&state.x)) // f(x) = f(x .+ gvec .* alpha)
}; };
let df_alpha = |alpha: f64| -> f64 { let df_alpha = |alpha: T| -> T {
let mut dx = state.s.clone(); let mut dx = state.s.clone();
let mut dg = state.x_df.clone(); let mut dg = state.x_df.clone();
dx.mul_scalar_mut(alpha); dx.mul_scalar_mut(alpha);
@@ -118,7 +120,7 @@ impl LBFGS {
state.x_df.vector_dot(&dg) state.x_df.vector_dot(&dg)
}; };
let ls_r = ls.search(&f_alpha, &df_alpha, 1.0, state.x_f_prev, df0); let ls_r = ls.search(&f_alpha, &df_alpha, T::one(), state.x_f_prev, df0);
state.alpha = ls_r.alpha; state.alpha = ls_r.alpha;
state.dx.copy_from(state.s.mul_scalar_mut(state.alpha)); state.dx.copy_from(state.s.mul_scalar_mut(state.alpha));
@@ -128,14 +130,14 @@ impl LBFGS {
} }
fn assess_convergence<X: Matrix>(&self, state: &mut LBFGSState<X>) -> bool { fn assess_convergence<X: Matrix<T>>(&self, state: &mut LBFGSState<T, X>) -> bool {
let (mut x_converged, mut g_converged) = (false, false); let (mut x_converged, mut g_converged) = (false, false);
if state.x.max_diff(&state.x_prev) <= self.x_atol { if state.x.max_diff(&state.x_prev) <= self.x_atol {
x_converged = true; x_converged = true;
} }
if state.x.max_diff(&state.x_prev) <= self.x_rtol * state.x.norm(std::f64::INFINITY) { if state.x.max_diff(&state.x_prev) <= self.x_rtol * state.x.norm(T::infinity()) {
x_converged = true; x_converged = true;
} }
@@ -147,16 +149,16 @@ impl LBFGS {
state.counter_f_tol += 1; state.counter_f_tol += 1;
} }
if state.x_df.norm(std::f64::INFINITY) <= self.g_atol { if state.x_df.norm(T::infinity()) <= self.g_atol {
g_converged = true; g_converged = true;
} }
g_converged || x_converged || state.counter_f_tol > self.successive_f_tol g_converged || x_converged || state.counter_f_tol > self.successive_f_tol
} }
fn update_hessian<'a, X: Matrix>(&self, _: &'a DF<X>, state: &mut LBFGSState<X>) { fn update_hessian<'a, X: Matrix<T>>(&self, _: &'a DF<X>, state: &mut LBFGSState<T, X>) {
state.dg = state.x_df.sub(&state.x_df_prev); state.dg = state.x_df.sub(&state.x_df_prev);
let rho_iteration = 1. / state.dx.vector_dot(&state.dg); let rho_iteration = T::one() / state.dx.vector_dot(&state.dg);
if !rho_iteration.is_infinite() { if !rho_iteration.is_infinite() {
let idx = state.iteration.rem_euclid(self.m); let idx = state.iteration.rem_euclid(self.m);
state.dx_history[idx].copy_from(&state.dx); state.dx_history[idx].copy_from(&state.dx);
@@ -167,35 +169,35 @@ impl LBFGS {
} }
#[derive(Debug)] #[derive(Debug)]
struct LBFGSState<X: Matrix> { struct LBFGSState<T: FloatExt + Debug, X: Matrix<T>> {
x: X, x: X,
x_prev: X, x_prev: X,
x_f: f64, x_f: T,
x_f_prev: f64, x_f_prev: T,
x_df: X, x_df: X,
x_df_prev: X, x_df_prev: X,
rho: Vec<f64>, rho: Vec<T>,
dx_history: Vec<X>, dx_history: Vec<X>,
dg_history: Vec<X>, dg_history: Vec<X>,
dx: X, dx: X,
dg: X, dg: X,
twoloop_q: X, twoloop_q: X,
twoloop_alpha: Vec<f64>, twoloop_alpha: Vec<T>,
iteration: usize, iteration: usize,
counter_f_tol: usize, counter_f_tol: usize,
s: X, s: X,
alpha: f64 alpha: T
} }
impl FirstOrderOptimizer for LBFGS { impl<T: FloatExt + Debug> FirstOrderOptimizer<T> for LBFGS<T> {
fn optimize<'a, X: Matrix, LS: LineSearchMethod>(&self, f: &F<X>, df: &'a DF<X>, x0: &X, ls: &'a LS) -> OptimizerResult<X> { fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(&self, f: &F<T, X>, df: &'a DF<X>, x0: &X, ls: &'a LS) -> OptimizerResult<T, X> {
let mut state = self.init_state(x0); let mut state = self.init_state(x0);
df(&mut state.x_df, &x0); df(&mut state.x_df, &x0);
let g_converged = state.x_df.norm(std::f64::INFINITY) < self.g_atol; let g_converged = state.x_df.norm(T::infinity()) < self.g_atol;
let mut converged = g_converged; let mut converged = g_converged;
let stopped = false; let stopped = false;
@@ -228,27 +230,26 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
use crate::optimization::line_search::Backtracking; use crate::optimization::line_search::Backtracking;
use crate::optimization::FunctionOrder; use crate::optimization::FunctionOrder;
use crate::math::EPSILON;
#[test] #[test]
fn lbfgs() { fn lbfgs() {
let x0 = DenseMatrix::vector_from_array(&[0., 0.]); let x0 = DenseMatrix::vector_from_array(&[0., 0.]);
let f = |x: &DenseMatrix| { let f = |x: &DenseMatrix<f64>| {
(1.0 - x.get(0, 0)).powf(2.) + 100.0 * (x.get(0, 1) - x.get(0, 0).powf(2.)).powf(2.) (1.0 - x.get(0, 0)).powf(2.) + 100.0 * (x.get(0, 1) - x.get(0, 0).powf(2.)).powf(2.)
}; };
let df = |g: &mut DenseMatrix, x: &DenseMatrix| { let df = |g: &mut DenseMatrix<f64>, x: &DenseMatrix<f64>| {
g.set(0, 0, -2. * (1. - x.get(0, 0)) - 400. * (x.get(0, 1) - x.get(0, 0).powf(2.)) * x.get(0, 0)); g.set(0, 0, -2. * (1. - x.get(0, 0)) - 400. * (x.get(0, 1) - x.get(0, 0).powf(2.)) * x.get(0, 0));
g.set(0, 1, 200. * (x.get(0, 1) - x.get(0, 0).powf(2.))); g.set(0, 1, 200. * (x.get(0, 1) - x.get(0, 0).powf(2.)));
}; };
let mut ls: Backtracking = Default::default(); let mut ls: Backtracking<f64> = Default::default();
ls.order = FunctionOrder::THIRD; ls.order = FunctionOrder::THIRD;
let optimizer: LBFGS = Default::default(); let optimizer: LBFGS<f64> = Default::default();
let result = optimizer.optimize(&f, &df, &x0, &ls); let result = optimizer.optimize(&f, &df, &x0, &ls);
assert!((result.f_x - 0.0).abs() < EPSILON); assert!((result.f_x - 0.0).abs() < std::f64::EPSILON);
assert!((result.x.get(0, 0) - 1.0).abs() < 1e-8); assert!((result.x.get(0, 0) - 1.0).abs() < 1e-8);
assert!((result.x.get(0, 1) - 1.0).abs() < 1e-8); assert!((result.x.get(0, 1) - 1.0).abs() < 1e-8);
assert!(result.iterations <= 24); assert!(result.iterations <= 24);
+9 -5
View File
@@ -1,18 +1,22 @@
pub mod lbfgs; pub mod lbfgs;
pub mod gradient_descent; pub mod gradient_descent;
use std::clone::Clone;
use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::optimization::line_search::LineSearchMethod; use crate::optimization::line_search::LineSearchMethod;
use crate::optimization::{F, DF}; use crate::optimization::{F, DF};
pub trait FirstOrderOptimizer { pub trait FirstOrderOptimizer<T: FloatExt + Debug> {
fn optimize<'a, X: Matrix, LS: LineSearchMethod>(&self, f: &F<X>, df: &'a DF<X>, x0: &X, ls: &'a LS) -> OptimizerResult<X>; fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(&self, f: &F<T, X>, df: &'a DF<X>, x0: &X, ls: &'a LS) -> OptimizerResult<T, X>;
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct OptimizerResult<X> pub struct OptimizerResult<T: FloatExt + Debug, X: Matrix<T>>
where X: Matrix
{ {
pub x: X, pub x: X,
pub f_x: f64, pub f_x: T,
pub iterations: usize pub iterations: usize
} }
+31 -28
View File
@@ -1,41 +1,44 @@
use crate::math::EPSILON; use num_traits::Float;
use crate::optimization::FunctionOrder; use crate::optimization::FunctionOrder;
pub trait LineSearchMethod { pub trait LineSearchMethod<T: Float> {
fn search<'a>(&self, f: &(dyn Fn(f64) -> f64), df: &(dyn Fn(f64) -> f64), alpha: f64, f0: f64, df0: f64) -> LineSearchResult; fn search<'a>(&self, f: &(dyn Fn(T) -> T), df: &(dyn Fn(T) -> T), alpha: T, f0: T, df0: T) -> LineSearchResult<T>;
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct LineSearchResult { pub struct LineSearchResult<T: Float> {
pub alpha: f64, pub alpha: T,
pub f_x: f64 pub f_x: T
} }
pub struct Backtracking { pub struct Backtracking<T: Float> {
pub c1: f64, pub c1: T,
pub max_iterations: usize, pub max_iterations: usize,
pub max_infinity_iterations: usize, pub max_infinity_iterations: usize,
pub phi: f64, pub phi: T,
pub plo: f64, pub plo: T,
pub order: FunctionOrder pub order: FunctionOrder
} }
impl Default for Backtracking { impl<T: Float> Default for Backtracking<T> {
fn default() -> Self { fn default() -> Self {
Backtracking { Backtracking {
c1: 1e-4, c1: T::from(1e-4).unwrap(),
max_iterations: 1000, max_iterations: 1000,
max_infinity_iterations: -EPSILON.log2() as usize, max_infinity_iterations: (-T::epsilon().log2()).to_usize().unwrap(),
phi: 0.5, phi: T::from(0.5).unwrap(),
plo: 0.1, plo: T::from(0.1).unwrap(),
order: FunctionOrder::SECOND order: FunctionOrder::SECOND
} }
} }
} }
impl LineSearchMethod for Backtracking { impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
fn search<'a>(&self, f: &(dyn Fn(f64) -> f64), _: &(dyn Fn(f64) -> f64), alpha: f64, f0: f64, df0: f64) -> LineSearchResult { fn search<'a>(&self, f: &(dyn Fn(T) -> T), _: &(dyn Fn(T) -> T), alpha: T, f0: T, df0: T) -> LineSearchResult<T> {
let two = T::from(2.).unwrap();
let three = T::from(3.).unwrap();
let (mut a1, mut a2) = (alpha, alpha); let (mut a1, mut a2) = (alpha, alpha);
let (mut fx0, mut fx1) = (f0, f(a1)); let (mut fx0, mut fx1) = (f0, f(a1));
@@ -44,7 +47,7 @@ impl LineSearchMethod for Backtracking {
while !fx1.is_finite() && iterfinite < self.max_infinity_iterations { while !fx1.is_finite() && iterfinite < self.max_infinity_iterations {
iterfinite += 1; iterfinite += 1;
a1 = a2; a1 = a2;
a2 = a1 / 2.; a2 = a1 / two;
fx1 = f(a2); fx1 = f(a2);
} }
@@ -60,24 +63,24 @@ impl LineSearchMethod for Backtracking {
if self.order == FunctionOrder::SECOND || iteration == 0 { if self.order == FunctionOrder::SECOND || iteration == 0 {
a_tmp = - (df0 * a2.powf(2.)) / (2. * (fx1 - f0 - df0*a2)) a_tmp = - (df0 * a2.powf(two)) / (two * (fx1 - f0 - df0*a2))
} else { } else {
let div = 1. / (a1.powf(2.) * a2.powf(2.) * (a2 - a1)); let div = T::one() / (a1.powf(two) * a2.powf(two) * (a2 - a1));
let a = (a1.powf(2.) * (fx1 - f0 - df0*a2) - a2.powf(2.)*(fx0 - f0 - df0*a1))*div; let a = (a1.powf(two) * (fx1 - f0 - df0*a2) - a2.powf(two)*(fx0 - f0 - df0*a1))*div;
let b = (-a1.powf(3.) * (fx1 - f0 - df0*a2) + a2.powf(3.)*(fx0 - f0 - df0*a1))*div; let b = (-a1.powf(three) * (fx1 - f0 - df0*a2) + a2.powf(three)*(fx0 - f0 - df0*a1))*div;
if (a - 0.).powf(2.).sqrt() <= EPSILON { if (a - T::zero()).powf(two).sqrt() <= T::epsilon() {
a_tmp = df0 / (2. * b); a_tmp = df0 / (two * b);
} else { } else {
let d = f64::max(b.powf(2.) - 3. * a * df0, 0.); let d = T::max(b.powf(two) - three * a * df0, T::zero());
a_tmp = (-b + d.sqrt()) / (3.*a); //root of quadratic equation a_tmp = (-b + d.sqrt()) / (three*a); //root of quadratic equation
} }
} }
a1 = a2; a1 = a2;
a2 = f64::max(f64::min(a_tmp, a2*self.phi), a2*self.plo); a2 = T::max(T::min(a_tmp, a2*self.phi), a2*self.plo);
fx0 = fx1; fx0 = fx1;
fx1 = f(a2); fx1 = f(a2);
@@ -108,7 +111,7 @@ mod tests {
2. * x + 1. 2. * x + 1.
}; };
let ls: Backtracking = Default::default(); let ls: Backtracking<f64> = Default::default();
let mut x = -3.; let mut x = -3.;
let mut alpha = 1.; let mut alpha = 1.;
+1 -1
View File
@@ -1,7 +1,7 @@
pub mod first_order; pub mod first_order;
pub mod line_search; pub mod line_search;
pub type F<'a, X> = dyn for<'b> Fn(&'b X) -> f64 + 'a; pub type F<'a, T, X> = dyn for<'b> Fn(&'b X) -> T + 'a;
pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a; pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a;
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
+49 -43
View File
@@ -1,5 +1,9 @@
use std::default::Default; use std::default::Default;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::collections::LinkedList; use std::collections::LinkedList;
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::algorithm::sort::quick_sort::QuickArgSort; use crate::algorithm::sort::quick_sort::QuickArgSort;
@@ -12,11 +16,11 @@ pub struct DecisionTreeClassifierParameters {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct DecisionTreeClassifier { pub struct DecisionTreeClassifier<T: FloatExt> {
nodes: Vec<Node>, nodes: Vec<Node<T>>,
parameters: DecisionTreeClassifierParameters, parameters: DecisionTreeClassifierParameters,
num_classes: usize, num_classes: usize,
classes: Vec<f64>, classes: Vec<T>,
depth: u16 depth: u16
} }
@@ -28,12 +32,12 @@ pub enum SplitCriterion {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct Node { pub struct Node<T: FloatExt> {
index: usize, index: usize,
output: usize, output: usize,
split_feature: usize, split_feature: usize,
split_value: f64, split_value: T,
split_score: f64, split_score: T,
true_child: Option<usize>, true_child: Option<usize>,
false_child: Option<usize>, false_child: Option<usize>,
} }
@@ -50,21 +54,21 @@ impl Default for DecisionTreeClassifierParameters {
} }
} }
impl Node { impl<T: FloatExt> Node<T> {
fn new(index: usize, output: usize) -> Self { fn new(index: usize, output: usize) -> Self {
Node { Node {
index: index, index: index,
output: output, output: output,
split_feature: 0, split_feature: 0,
split_value: std::f64::NAN, split_value: T::nan(),
split_score: std::f64::NAN, split_score: T::nan(),
true_child: Option::None, true_child: Option::None,
false_child: Option::None false_child: Option::None
} }
} }
} }
struct NodeVisitor<'a, M: Matrix> { struct NodeVisitor<'a, T: FloatExt + Debug, M: Matrix<T>> {
x: &'a M, x: &'a M,
y: &'a Vec<usize>, y: &'a Vec<usize>,
node: usize, node: usize,
@@ -72,19 +76,20 @@ struct NodeVisitor<'a, M: Matrix> {
order: &'a Vec<Vec<usize>>, order: &'a Vec<Vec<usize>>,
true_child_output: usize, true_child_output: usize,
false_child_output: usize, false_child_output: usize,
level: u16 level: u16,
phantom: PhantomData<&'a T>
} }
fn impurity(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> f64 { fn impurity<T: FloatExt>(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> T {
let mut impurity = 0.; let mut impurity = T::zero();
match criterion { match criterion {
SplitCriterion::Gini => { SplitCriterion::Gini => {
impurity = 1.0; impurity = T::one();
for i in 0..count.len() { for i in 0..count.len() {
if count[i] > 0 { if count[i] > 0 {
let p = count[i] as f64 / n as f64; let p = T::from(count[i]).unwrap() / T::from(n).unwrap();
impurity -= p * p; impurity = impurity - p * p;
} }
} }
} }
@@ -92,25 +97,25 @@ fn impurity(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> f64 {
SplitCriterion::Entropy => { SplitCriterion::Entropy => {
for i in 0..count.len() { for i in 0..count.len() {
if count[i] > 0 { if count[i] > 0 {
let p = count[i] as f64 / n as f64; let p = T::from(count[i]).unwrap() / T::from(n).unwrap();
impurity -= p * p.log2(); impurity = impurity - p * p.log2();
} }
} }
} }
SplitCriterion::ClassificationError => { SplitCriterion::ClassificationError => {
for i in 0..count.len() { for i in 0..count.len() {
if count[i] > 0 { if count[i] > 0 {
impurity = impurity.max(count[i] as f64 / n as f64); impurity = impurity.max(T::from(count[i]).unwrap() / T::from(n).unwrap());
} }
} }
impurity = (1. - impurity).abs(); impurity = (T::one() - impurity).abs();
} }
} }
return impurity; return impurity;
} }
impl<'a, M: Matrix> NodeVisitor<'a, M> { impl<'a, T: FloatExt + Debug, M: Matrix<T>> NodeVisitor<'a, T, M> {
fn new(node_id: usize, samples: Vec<usize>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a Vec<usize>, level: u16) -> Self { fn new(node_id: usize, samples: Vec<usize>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a Vec<usize>, level: u16) -> Self {
NodeVisitor { NodeVisitor {
@@ -121,7 +126,8 @@ impl<'a, M: Matrix> NodeVisitor<'a, M> {
order: order, order: order,
true_child_output: 0, true_child_output: 0,
false_child_output: 0, false_child_output: 0,
level: level level: level,
phantom: PhantomData
} }
} }
@@ -141,15 +147,15 @@ pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
return which; return which;
} }
impl DecisionTreeClassifier { impl<T: FloatExt + Debug> DecisionTreeClassifier<T> {
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier { pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier<T> {
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
let samples = vec![1; x_nrows]; let samples = vec![1; x_nrows];
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters) DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
} }
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier { pub fn fit_weak_learner<M: Matrix<T>>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier<T> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape(); let (_, y_ncols) = y_m.shape();
let (_, num_attributes) = x.shape(); let (_, num_attributes) = x.shape();
@@ -166,7 +172,7 @@ impl DecisionTreeClassifier {
yi[i] = classes.iter().position(|c| yc == *c).unwrap(); yi[i] = classes.iter().position(|c| yc == *c).unwrap();
} }
let mut nodes: Vec<Node> = Vec::new(); let mut nodes: Vec<Node<T>> = Vec::new();
let mut count = vec![0; k]; let mut count = vec![0; k];
for i in 0..y_ncols { for i in 0..y_ncols {
@@ -189,9 +195,9 @@ impl DecisionTreeClassifier {
depth: 0 depth: 0
}; };
let mut visitor = NodeVisitor::<M>::new(0, samples, &order, &x, &yi, 1); let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &yi, 1);
let mut visitor_queue: LinkedList<NodeVisitor<M>> = LinkedList::new(); let mut visitor_queue: LinkedList<NodeVisitor<T, M>> = LinkedList::new();
if tree.find_best_cutoff(&mut visitor, mtry) { if tree.find_best_cutoff(&mut visitor, mtry) {
visitor_queue.push_back(visitor); visitor_queue.push_back(visitor);
@@ -207,7 +213,7 @@ impl DecisionTreeClassifier {
tree tree
} }
pub fn predict<M: Matrix>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let mut result = M::zeros(1, x.shape().0); let mut result = M::zeros(1, x.shape().0);
let (n, _) = x.shape(); let (n, _) = x.shape();
@@ -219,7 +225,7 @@ impl DecisionTreeClassifier {
result.to_row_vector() result.to_row_vector()
} }
pub(in crate) fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> usize { pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
let mut result = 0; let mut result = 0;
let mut queue: LinkedList<usize> = LinkedList::new(); let mut queue: LinkedList<usize> = LinkedList::new();
@@ -247,7 +253,7 @@ impl DecisionTreeClassifier {
} }
fn find_best_cutoff<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, mtry: usize) -> bool { fn find_best_cutoff<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, mtry: usize) -> bool {
let (n_rows, n_attr) = visitor.x.shape(); let (n_rows, n_attr) = visitor.x.shape();
@@ -297,10 +303,10 @@ impl DecisionTreeClassifier {
} }
fn find_best_split<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, n: usize, count: &Vec<usize>, false_count: &mut Vec<usize>, parent_impurity: f64, j: usize){ fn find_best_split<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, n: usize, count: &Vec<usize>, false_count: &mut Vec<usize>, parent_impurity: T, j: usize){
let mut true_count = vec![0; self.num_classes]; let mut true_count = vec![0; self.num_classes];
let mut prevx = std::f64::NAN; let mut prevx = T::nan();
let mut prevy = 0; let mut prevy = 0;
for i in visitor.order[j].iter() { for i in visitor.order[j].iter() {
@@ -328,11 +334,11 @@ impl DecisionTreeClassifier {
let true_label = which_max(&true_count); let true_label = which_max(&true_count);
let false_label = which_max(false_count); let false_label = which_max(false_count);
let gain = parent_impurity - tc as f64 / n as f64 * impurity(&self.parameters.criterion, &true_count, tc) - fc as f64 / n as f64 * impurity(&self.parameters.criterion, &false_count, fc); let gain = parent_impurity - T::from(tc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &true_count, tc) - T::from(fc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &false_count, fc);
if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score { if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score {
self.nodes[visitor.node].split_feature = j; self.nodes[visitor.node].split_feature = j;
self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / 2.; self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / T::two();
self.nodes[visitor.node].split_score = gain; self.nodes[visitor.node].split_score = gain;
visitor.true_child_output = true_label; visitor.true_child_output = true_label;
visitor.false_child_output = false_label; visitor.false_child_output = false_label;
@@ -346,7 +352,7 @@ impl DecisionTreeClassifier {
} }
fn split<'a, M: Matrix>(&mut self, mut visitor: NodeVisitor<'a, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, M>>) -> bool { fn split<'a, M: Matrix<T>>(&mut self, mut visitor: NodeVisitor<'a, T, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>) -> bool {
let (n, _) = visitor.x.shape(); let (n, _) = visitor.x.shape();
let mut tc = 0; let mut tc = 0;
let mut fc = 0; let mut fc = 0;
@@ -366,8 +372,8 @@ impl DecisionTreeClassifier {
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf { if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
self.nodes[visitor.node].split_feature = 0; self.nodes[visitor.node].split_feature = 0;
self.nodes[visitor.node].split_value = std::f64::NAN; self.nodes[visitor.node].split_value = T::nan();
self.nodes[visitor.node].split_score = std::f64::NAN; self.nodes[visitor.node].split_score = T::nan();
return false; return false;
} }
@@ -381,13 +387,13 @@ impl DecisionTreeClassifier {
self.depth = u16::max(self.depth, visitor.level + 1); self.depth = u16::max(self.depth, visitor.level + 1);
let mut true_visitor = NodeVisitor::<M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1); let mut true_visitor = NodeVisitor::<T, M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
if self.find_best_cutoff(&mut true_visitor, mtry) { if self.find_best_cutoff(&mut true_visitor, mtry) {
visitor_queue.push_back(true_visitor); visitor_queue.push_back(true_visitor);
} }
let mut false_visitor = NodeVisitor::<M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1); let mut false_visitor = NodeVisitor::<T, M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
if self.find_best_cutoff(&mut false_visitor, mtry) { if self.find_best_cutoff(&mut false_visitor, mtry) {
visitor_queue.push_back(false_visitor); visitor_queue.push_back(false_visitor);
@@ -405,9 +411,9 @@ mod tests {
#[test] #[test]
fn gini_impurity() { fn gini_impurity() {
assert!((impurity(&SplitCriterion::Gini, &vec![7, 3], 10) - 0.42).abs() < std::f64::EPSILON); assert!((impurity::<f64>(&SplitCriterion::Gini, &vec![7, 3], 10) - 0.42).abs() < std::f64::EPSILON);
assert!((impurity(&SplitCriterion::Entropy, &vec![7, 3], 10) - 0.8812908992306927).abs() < std::f64::EPSILON); assert!((impurity::<f64>(&SplitCriterion::Entropy, &vec![7, 3], 10) - 0.8812908992306927).abs() < std::f64::EPSILON);
assert!((impurity(&SplitCriterion::ClassificationError, &vec![7, 3], 10) - 0.3).abs() < std::f64::EPSILON); assert!((impurity::<f64>(&SplitCriterion::ClassificationError, &vec![7, 3], 10) - 0.3).abs() < std::f64::EPSILON);
} }
#[test] #[test]
+50 -47
View File
@@ -1,5 +1,8 @@
use std::default::Default; use std::default::Default;
use std::fmt::Debug;
use std::collections::LinkedList; use std::collections::LinkedList;
use crate::math::num::FloatExt;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::algorithm::sort::quick_sort::QuickArgSort; use crate::algorithm::sort::quick_sort::QuickArgSort;
@@ -11,19 +14,19 @@ pub struct DecisionTreeRegressorParameters {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct DecisionTreeRegressor { pub struct DecisionTreeRegressor<T: FloatExt> {
nodes: Vec<Node>, nodes: Vec<Node<T>>,
parameters: DecisionTreeRegressorParameters, parameters: DecisionTreeRegressorParameters,
depth: u16 depth: u16
} }
#[derive(Debug)] #[derive(Debug)]
pub struct Node { pub struct Node<T: FloatExt> {
index: usize, index: usize,
output: f64, output: T,
split_feature: usize, split_feature: usize,
split_value: f64, split_value: T,
split_score: f64, split_score: T,
true_child: Option<usize>, true_child: Option<usize>,
false_child: Option<usize>, false_child: Option<usize>,
} }
@@ -39,32 +42,32 @@ impl Default for DecisionTreeRegressorParameters {
} }
} }
impl Node { impl<T: FloatExt> Node<T> {
fn new(index: usize, output: f64) -> Self { fn new(index: usize, output: T) -> Self {
Node { Node {
index: index, index: index,
output: output, output: output,
split_feature: 0, split_feature: 0,
split_value: std::f64::NAN, split_value: T::nan(),
split_score: std::f64::NAN, split_score: T::nan(),
true_child: Option::None, true_child: Option::None,
false_child: Option::None false_child: Option::None
} }
} }
} }
struct NodeVisitor<'a, M: Matrix> { struct NodeVisitor<'a, T: FloatExt + Debug, M: Matrix<T>> {
x: &'a M, x: &'a M,
y: &'a M, y: &'a M,
node: usize, node: usize,
samples: Vec<usize>, samples: Vec<usize>,
order: &'a Vec<Vec<usize>>, order: &'a Vec<Vec<usize>>,
true_child_output: f64, true_child_output: T,
false_child_output: f64, false_child_output: T,
level: u16 level: u16
} }
impl<'a, M: Matrix> NodeVisitor<'a, M> { impl<'a, T: FloatExt + Debug, M: Matrix<T>> NodeVisitor<'a, T, M> {
fn new(node_id: usize, samples: Vec<usize>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a M, level: u16) -> Self { fn new(node_id: usize, samples: Vec<usize>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a M, level: u16) -> Self {
NodeVisitor { NodeVisitor {
@@ -73,23 +76,23 @@ impl<'a, M: Matrix> NodeVisitor<'a, M> {
node: node_id, node: node_id,
samples: samples, samples: samples,
order: order, order: order,
true_child_output: 0., true_child_output: T::zero(),
false_child_output: 0., false_child_output: T::zero(),
level: level level: level
} }
} }
} }
impl DecisionTreeRegressor { impl<T: FloatExt + Debug> DecisionTreeRegressor<T> {
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeRegressorParameters) -> DecisionTreeRegressor { pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: DecisionTreeRegressorParameters) -> DecisionTreeRegressor<T> {
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
let samples = vec![1; x_nrows]; let samples = vec![1; x_nrows];
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters) DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
} }
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeRegressorParameters) -> DecisionTreeRegressor { pub fn fit_weak_learner<M: Matrix<T>>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeRegressorParameters) -> DecisionTreeRegressor<T> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape(); let (_, y_ncols) = y_m.shape();
let (_, num_attributes) = x.shape(); let (_, num_attributes) = x.shape();
@@ -99,16 +102,16 @@ impl DecisionTreeRegressor {
panic!("Incorrect number of classes: {}. Should be >= 2.", k); panic!("Incorrect number of classes: {}. Should be >= 2.", k);
} }
let mut nodes: Vec<Node> = Vec::new(); let mut nodes: Vec<Node<T>> = Vec::new();
let mut n = 0; let mut n = 0;
let mut sum = 0f64; let mut sum = T::zero();
for i in 0..y_ncols { for i in 0..y_ncols {
n += samples[i]; n += samples[i];
sum += samples[i] as f64 * y_m.get(i, 0); sum = sum + T::from(samples[i]).unwrap() * y_m.get(i, 0);
} }
let root = Node::new(0, sum / n as f64); let root = Node::new(0, sum / T::from(n).unwrap());
nodes.push(root); nodes.push(root);
let mut order: Vec<Vec<usize>> = Vec::new(); let mut order: Vec<Vec<usize>> = Vec::new();
@@ -122,9 +125,9 @@ impl DecisionTreeRegressor {
depth: 0 depth: 0
}; };
let mut visitor = NodeVisitor::<M>::new(0, samples, &order, &x, &y_m, 1); let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &y_m, 1);
let mut visitor_queue: LinkedList<NodeVisitor<M>> = LinkedList::new(); let mut visitor_queue: LinkedList<NodeVisitor<T, M>> = LinkedList::new();
if tree.find_best_cutoff(&mut visitor, mtry) { if tree.find_best_cutoff(&mut visitor, mtry) {
visitor_queue.push_back(visitor); visitor_queue.push_back(visitor);
@@ -140,7 +143,7 @@ impl DecisionTreeRegressor {
tree tree
} }
pub fn predict<M: Matrix>(&self, x: &M) -> M::RowVector { pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let mut result = M::zeros(1, x.shape().0); let mut result = M::zeros(1, x.shape().0);
let (n, _) = x.shape(); let (n, _) = x.shape();
@@ -152,8 +155,8 @@ impl DecisionTreeRegressor {
result.to_row_vector() result.to_row_vector()
} }
pub(in crate) fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> f64 { pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
let mut result = 0f64; let mut result = T::zero();
let mut queue: LinkedList<usize> = LinkedList::new(); let mut queue: LinkedList<usize> = LinkedList::new();
queue.push_back(0); queue.push_back(0);
@@ -180,7 +183,7 @@ impl DecisionTreeRegressor {
} }
fn find_best_cutoff<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, mtry: usize) -> bool { fn find_best_cutoff<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, mtry: usize) -> bool {
let (_, n_attr) = visitor.x.shape(); let (_, n_attr) = visitor.x.shape();
@@ -190,14 +193,14 @@ impl DecisionTreeRegressor {
return false; return false;
} }
let sum = self.nodes[visitor.node].output * n as f64; let sum = self.nodes[visitor.node].output * T::from(n).unwrap();
let mut variables = vec![0; n_attr]; let mut variables = vec![0; n_attr];
for i in 0..n_attr { for i in 0..n_attr {
variables[i] = i; variables[i] = i;
} }
let parent_gain = n as f64 * self.nodes[visitor.node].output * self.nodes[visitor.node].output; let parent_gain = T::from(n).unwrap() * self.nodes[visitor.node].output * self.nodes[visitor.node].output;
for j in 0..mtry { for j in 0..mtry {
self.find_best_split(visitor, n, sum, parent_gain, variables[j]); self.find_best_split(visitor, n, sum, parent_gain, variables[j]);
@@ -207,18 +210,18 @@ impl DecisionTreeRegressor {
} }
fn find_best_split<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, n: usize, sum: f64, parent_gain: f64, j: usize){ fn find_best_split<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, n: usize, sum: T, parent_gain: T, j: usize){
let mut true_sum = 0f64; let mut true_sum = T::zero();
let mut true_count = 0; let mut true_count = 0;
let mut prevx = std::f64::NAN; let mut prevx = T::nan();
for i in visitor.order[j].iter() { for i in visitor.order[j].iter() {
if visitor.samples[*i] > 0 { if visitor.samples[*i] > 0 {
if prevx.is_nan() || visitor.x.get(*i, j) == prevx { if prevx.is_nan() || visitor.x.get(*i, j) == prevx {
prevx = visitor.x.get(*i, j); prevx = visitor.x.get(*i, j);
true_count += visitor.samples[*i]; true_count += visitor.samples[*i];
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i, 0); true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(*i, 0);
continue; continue;
} }
@@ -227,32 +230,32 @@ impl DecisionTreeRegressor {
if true_count < self.parameters.min_samples_leaf || false_count < self.parameters.min_samples_leaf { if true_count < self.parameters.min_samples_leaf || false_count < self.parameters.min_samples_leaf {
prevx = visitor.x.get(*i, j); prevx = visitor.x.get(*i, j);
true_count += visitor.samples[*i]; true_count += visitor.samples[*i];
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i, 0); true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(*i, 0);
continue; continue;
} }
let true_mean = true_sum / true_count as f64; let true_mean = true_sum / T::from(true_count).unwrap();
let false_mean = (sum - true_sum) / false_count as f64; let false_mean = (sum - true_sum) / T::from(false_count).unwrap();
let gain = (true_count as f64 * true_mean * true_mean + false_count as f64 * false_mean * false_mean) - parent_gain; let gain = (T::from(true_count).unwrap() * true_mean * true_mean + T::from(false_count).unwrap() * false_mean * false_mean) - parent_gain;
if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score { if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score {
self.nodes[visitor.node].split_feature = j; self.nodes[visitor.node].split_feature = j;
self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / 2.; self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / T::two();
self.nodes[visitor.node].split_score = gain; self.nodes[visitor.node].split_score = gain;
visitor.true_child_output = true_mean; visitor.true_child_output = true_mean;
visitor.false_child_output = false_mean; visitor.false_child_output = false_mean;
} }
prevx = visitor.x.get(*i, j); prevx = visitor.x.get(*i, j);
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i, 0); true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(*i, 0);
true_count += visitor.samples[*i]; true_count += visitor.samples[*i];
} }
} }
} }
fn split<'a, M: Matrix>(&mut self, mut visitor: NodeVisitor<'a, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, M>>) -> bool { fn split<'a, M: Matrix<T>>(&mut self, mut visitor: NodeVisitor<'a, T, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>) -> bool {
let (n, _) = visitor.x.shape(); let (n, _) = visitor.x.shape();
let mut tc = 0; let mut tc = 0;
let mut fc = 0; let mut fc = 0;
@@ -272,8 +275,8 @@ impl DecisionTreeRegressor {
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf { if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
self.nodes[visitor.node].split_feature = 0; self.nodes[visitor.node].split_feature = 0;
self.nodes[visitor.node].split_value = std::f64::NAN; self.nodes[visitor.node].split_value = T::nan();
self.nodes[visitor.node].split_score = std::f64::NAN; self.nodes[visitor.node].split_score = T::nan();
return false; return false;
} }
@@ -287,13 +290,13 @@ impl DecisionTreeRegressor {
self.depth = u16::max(self.depth, visitor.level + 1); self.depth = u16::max(self.depth, visitor.level + 1);
let mut true_visitor = NodeVisitor::<M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1); let mut true_visitor = NodeVisitor::<T, M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
if self.find_best_cutoff(&mut true_visitor, mtry) { if self.find_best_cutoff(&mut true_visitor, mtry) {
visitor_queue.push_back(true_visitor); visitor_queue.push_back(true_visitor);
} }
let mut false_visitor = NodeVisitor::<M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1); let mut false_visitor = NodeVisitor::<T, M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
if self.find_best_cutoff(&mut false_visitor, mtry) { if self.find_best_cutoff(&mut false_visitor, mtry) {
visitor_queue.push_back(false_visitor); visitor_queue.push_back(false_visitor);
@@ -329,7 +332,7 @@ mod tests {
&[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564], &[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331], &[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); &[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
let y = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]; let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x); let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x);