Adds SVD solver, code refactoring
This commit is contained in:
@@ -13,7 +13,7 @@ where T: Debug
|
|||||||
base: f64,
|
base: f64,
|
||||||
max_level: i8,
|
max_level: i8,
|
||||||
min_level: i8,
|
min_level: i8,
|
||||||
distance: &'a Fn(&T, &T) -> f64,
|
distance: &'a dyn Fn(&T, &T) -> f64,
|
||||||
nodes: Vec<Node<T>>
|
nodes: Vec<Node<T>>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ impl<'a, T> CoverTree<'a, T>
|
|||||||
where T: Debug
|
where T: Debug
|
||||||
{
|
{
|
||||||
|
|
||||||
pub fn new(mut data: Vec<T>, distance: &'a Fn(&T, &T) -> f64) -> CoverTree<T> {
|
pub fn new(mut data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> f64) -> CoverTree<T> {
|
||||||
let mut tree = CoverTree {
|
let mut tree = CoverTree {
|
||||||
base: 2f64,
|
base: 2f64,
|
||||||
max_level: 100,
|
max_level: 100,
|
||||||
@@ -49,7 +49,7 @@ where T: Debug
|
|||||||
let i_d = self.base.powf(i as f64);
|
let i_d = self.base.powf(i as f64);
|
||||||
let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
||||||
let d_p_q = self.min_by_distance(&q_p_ds);
|
let d_p_q = self.min_by_distance(&q_p_ds);
|
||||||
if d_p_q < math::SMALL_ERROR {
|
if d_p_q < math::EPSILON {
|
||||||
return
|
return
|
||||||
} else if d_p_q > i_d {
|
} else if d_p_q > i_d {
|
||||||
break;
|
break;
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use std::cmp::{Ordering, PartialOrd};
|
|||||||
use num_traits::Float;
|
use num_traits::Float;
|
||||||
|
|
||||||
pub struct LinearKNNSearch<'a, T> {
|
pub struct LinearKNNSearch<'a, T> {
|
||||||
distance: Box<Fn(&T, &T) -> f64 + 'a>,
|
distance: Box<dyn Fn(&T, &T) -> f64 + 'a>,
|
||||||
data: Vec<T>
|
data: Vec<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -42,7 +42,7 @@ impl<'a, T> KNNAlgorithm<T> for LinearKNNSearch<'a, T>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> LinearKNNSearch<'a, T> {
|
impl<'a, T> LinearKNNSearch<'a, T> {
|
||||||
pub fn new(data: Vec<T>, distance: &'a Fn(&T, &T) -> f64) -> LinearKNNSearch<T>{
|
pub fn new(data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> f64) -> LinearKNNSearch<T>{
|
||||||
LinearKNNSearch{
|
LinearKNNSearch{
|
||||||
data: data,
|
data: data,
|
||||||
distance: Box::new(distance)
|
distance: Box::new(distance)
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
use std::cmp::Ordering;
|
use std::cmp::Ordering;
|
||||||
use std::mem;
|
|
||||||
use std::fmt::Display;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct HeapSelect<T: PartialOrd> {
|
pub struct HeapSelect<T: PartialOrd> {
|
||||||
|
|
||||||
k: usize,
|
k: usize,
|
||||||
n: usize,
|
n: usize,
|
||||||
sorted: bool,
|
sorted: bool,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ use ndarray::{ArrayBase, Data, Ix1, Ix2};
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
|
||||||
type F<X> = Fn(&X, &X) -> f64;
|
type F<X> = dyn Fn(&X, &X) -> f64;
|
||||||
|
|
||||||
pub struct KNNClassifier<'a, X, Y>
|
pub struct KNNClassifier<'a, X, Y>
|
||||||
where
|
where
|
||||||
@@ -17,7 +17,7 @@ where
|
|||||||
{
|
{
|
||||||
classes: Vec<Y>,
|
classes: Vec<Y>,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
knn_algorithm: Box<KNNAlgorithm<X> + 'a>,
|
knn_algorithm: Box<dyn KNNAlgorithm<X> + 'a>,
|
||||||
k: usize,
|
k: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ where
|
|||||||
let classes: Vec<Y> = c_hash.into_iter().collect();
|
let classes: Vec<Y> = c_hash.into_iter().collect();
|
||||||
let y_i:Vec<usize> = y.into_iter().map(|y| classes.iter().position(|yy| yy == &y).unwrap()).collect();
|
let y_i:Vec<usize> = y.into_iter().map(|y| classes.iter().position(|yy| yy == &y).unwrap()).collect();
|
||||||
|
|
||||||
let knn_algorithm: Box<KNNAlgorithm<X> + 'a> = match algorithm {
|
let knn_algorithm: Box<dyn KNNAlgorithm<X> + 'a> = match algorithm {
|
||||||
KNNAlgorithmName::CoverTree => Box::new(CoverTree::<X>::new(x, distance)),
|
KNNAlgorithmName::CoverTree => Box::new(CoverTree::<X>::new(x, distance)),
|
||||||
KNNAlgorithmName::LinearSearch => Box::new(LinearKNNSearch::<X>::new(x, distance))
|
KNNAlgorithmName::LinearSearch => Box::new(LinearKNNSearch::<X>::new(x, distance))
|
||||||
};
|
};
|
||||||
|
|||||||
+3
-1
@@ -2,12 +2,14 @@ use std::ops::Range;
|
|||||||
|
|
||||||
pub mod naive;
|
pub mod naive;
|
||||||
|
|
||||||
pub trait Matrix: Into<Vec<f64>>{
|
pub trait Matrix: Into<Vec<f64>> + Clone{
|
||||||
|
|
||||||
fn get(&self, row: usize, col: usize) -> f64;
|
fn get(&self, row: usize, col: usize) -> f64;
|
||||||
|
|
||||||
fn qr_solve_mut(&mut self, b: Self) -> Self;
|
fn qr_solve_mut(&mut self, b: Self) -> Self;
|
||||||
|
|
||||||
|
fn svd_solve_mut(&mut self, b: Self) -> Self;
|
||||||
|
|
||||||
fn zeros(nrows: usize, ncols: usize) -> Self;
|
fn zeros(nrows: usize, ncols: usize) -> Self;
|
||||||
|
|
||||||
fn ones(nrows: usize, ncols: usize) -> Self;
|
fn ones(nrows: usize, ncols: usize) -> Self;
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ use std::ops::Range;
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math;
|
use crate::math;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct DenseMatrix {
|
pub struct DenseMatrix {
|
||||||
|
|
||||||
ncols: usize,
|
ncols: usize,
|
||||||
@@ -63,6 +63,10 @@ impl DenseMatrix {
|
|||||||
self.values[col*self.nrows + row] /= x;
|
self.values[col*self.nrows + row] /= x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn mul_element_mut(&mut self, row: usize, col: usize, x: f64) {
|
||||||
|
self.values[col*self.nrows + row] *= x;
|
||||||
|
}
|
||||||
|
|
||||||
fn add_element_mut(&mut self, row: usize, col: usize, x: f64) {
|
fn add_element_mut(&mut self, row: usize, col: usize, x: f64) {
|
||||||
self.values[col*self.nrows + row] += x
|
self.values[col*self.nrows + row] += x
|
||||||
}
|
}
|
||||||
@@ -87,7 +91,7 @@ impl PartialEq for DenseMatrix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i in 0..len {
|
for i in 0..len {
|
||||||
if (self.values[i] - other.values[i]).abs() > math::SMALL_ERROR {
|
if (self.values[i] - other.values[i]).abs() > math::EPSILON {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -195,6 +199,10 @@ impl Matrix for DenseMatrix {
|
|||||||
let n = self.ncols;
|
let n = self.ncols;
|
||||||
let nrhs = b.ncols;
|
let nrhs = b.ncols;
|
||||||
|
|
||||||
|
if self.nrows != b.nrows {
|
||||||
|
panic!("Dimensions do not agree. Self.nrows should equal b.nrows but is {}, {}", self.nrows, b.nrows);
|
||||||
|
}
|
||||||
|
|
||||||
let mut r_diagonal: Vec<f64> = vec![0f64; n];
|
let mut r_diagonal: Vec<f64> = vec![0f64; n];
|
||||||
|
|
||||||
for k in 0..n {
|
for k in 0..n {
|
||||||
@@ -203,7 +211,7 @@ impl Matrix for DenseMatrix {
|
|||||||
nrm = nrm.hypot(self.get(i, k));
|
nrm = nrm.hypot(self.get(i, k));
|
||||||
}
|
}
|
||||||
|
|
||||||
if nrm > math::SMALL_ERROR {
|
if nrm.abs() > math::EPSILON {
|
||||||
|
|
||||||
if self.get(k, k) < 0f64 {
|
if self.get(k, k) < 0f64 {
|
||||||
nrm = -nrm;
|
nrm = -nrm;
|
||||||
@@ -228,7 +236,7 @@ impl Matrix for DenseMatrix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for j in 0..r_diagonal.len() {
|
for j in 0..r_diagonal.len() {
|
||||||
if r_diagonal[j].abs() < math::SMALL_ERROR {
|
if r_diagonal[j].abs() < math::EPSILON {
|
||||||
panic!("Matrix is rank deficient.");
|
panic!("Matrix is rank deficient.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -262,6 +270,378 @@ impl Matrix for DenseMatrix {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn svd_solve_mut(&mut self, mut b: DenseMatrix) -> DenseMatrix {
|
||||||
|
|
||||||
|
if self.nrows != b.nrows {
|
||||||
|
panic!("Dimensions do not agree. Self.nrows should equal b.nrows but is {}, {}", self.nrows, b.nrows);
|
||||||
|
}
|
||||||
|
|
||||||
|
let m = self.nrows;
|
||||||
|
let n = self.ncols;
|
||||||
|
|
||||||
|
let (mut l, mut nm) = (0usize, 0usize);
|
||||||
|
let (mut anorm, mut g, mut scale) = (0f64, 0f64, 0f64);
|
||||||
|
|
||||||
|
let mut v = DenseMatrix::zeros(n, n);
|
||||||
|
let mut w = vec![0f64; n];
|
||||||
|
let mut rv1 = vec![0f64; n];
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
l = i + 2;
|
||||||
|
rv1[i] = scale * g;
|
||||||
|
g = 0f64;
|
||||||
|
let mut s = 0f64;
|
||||||
|
scale = 0f64;
|
||||||
|
|
||||||
|
if i < m {
|
||||||
|
for k in i..m {
|
||||||
|
scale += self.get(k, i).abs();
|
||||||
|
}
|
||||||
|
|
||||||
|
if scale.abs() > math::EPSILON {
|
||||||
|
|
||||||
|
for k in i..m {
|
||||||
|
self.div_element_mut(k, i, scale);
|
||||||
|
s += self.get(k, i) * self.get(k, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut f = self.get(i, i);
|
||||||
|
g = -s.sqrt().copysign(f);
|
||||||
|
let h = f * g - s;
|
||||||
|
self.set(i, i, f - g);
|
||||||
|
for j in l - 1..n {
|
||||||
|
s = 0f64;
|
||||||
|
for k in i..m {
|
||||||
|
s += self.get(k, i) * self.get(k, j);
|
||||||
|
}
|
||||||
|
f = s / h;
|
||||||
|
for k in i..m {
|
||||||
|
self.add_element_mut(k, j, f * self.get(k, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k in i..m {
|
||||||
|
self.mul_element_mut(k, i, scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w[i] = scale * g;
|
||||||
|
g = 0f64;
|
||||||
|
let mut s = 0f64;
|
||||||
|
scale = 0f64;
|
||||||
|
|
||||||
|
if i + 1 <= m && i + 1 != n {
|
||||||
|
for k in l - 1..n {
|
||||||
|
scale += self.get(i, k).abs();
|
||||||
|
}
|
||||||
|
|
||||||
|
if scale.abs() > math::EPSILON {
|
||||||
|
for k in l - 1..n {
|
||||||
|
self.div_element_mut(i, k, scale);
|
||||||
|
s += self.get(i, k) * self.get(i, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
let f = self.get(i, l - 1);
|
||||||
|
g = -s.sqrt().copysign(f);
|
||||||
|
let h = f * g - s;
|
||||||
|
self.set(i, l - 1, f - g);
|
||||||
|
|
||||||
|
for k in l - 1..n {
|
||||||
|
rv1[k] = self.get(i, k) / h;
|
||||||
|
}
|
||||||
|
|
||||||
|
for j in l - 1..m {
|
||||||
|
s = 0f64;
|
||||||
|
for k in l - 1..n {
|
||||||
|
s += self.get(j, k) * self.get(i, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
for k in l - 1..n {
|
||||||
|
self.add_element_mut(j, k, s * rv1[k]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for k in l - 1..n {
|
||||||
|
self.mul_element_mut(i, k, scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
anorm = f64::max(anorm, w[i].abs() + rv1[i].abs());
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in (0..n).rev() {
|
||||||
|
if i < n - 1 {
|
||||||
|
if g != 0.0 {
|
||||||
|
for j in l..n {
|
||||||
|
v.set(j, i, (self.get(i, j) / self.get(i, l)) / g);
|
||||||
|
}
|
||||||
|
for j in l..n {
|
||||||
|
let mut s = 0f64;
|
||||||
|
for k in l..n {
|
||||||
|
s += self.get(i, k) * v.get(k, j);
|
||||||
|
}
|
||||||
|
for k in l..n {
|
||||||
|
v.add_element_mut(k, j, s * v.get(k, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j in l..n {
|
||||||
|
v.set(i, j, 0f64);
|
||||||
|
v.set(j, i, 0f64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
v.set(i, i, 1.0);
|
||||||
|
g = rv1[i];
|
||||||
|
l = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in (0..usize::min(m, n)).rev() {
|
||||||
|
l = i + 1;
|
||||||
|
g = w[i];
|
||||||
|
for j in l..n {
|
||||||
|
self.set(i, j, 0f64);
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.abs() > math::EPSILON {
|
||||||
|
g = 1f64 / g;
|
||||||
|
for j in l..n {
|
||||||
|
let mut s = 0f64;
|
||||||
|
for k in l..m {
|
||||||
|
s += self.get(k, i) * self.get(k, j);
|
||||||
|
}
|
||||||
|
let f = (s / self.get(i, i)) * g;
|
||||||
|
for k in i..m {
|
||||||
|
self.add_element_mut(k, j, f * self.get(k, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j in i..m {
|
||||||
|
self.mul_element_mut(j, i, g);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for j in i..m {
|
||||||
|
self.set(j, i, 0f64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.add_element_mut(i, i, 1f64);
|
||||||
|
}
|
||||||
|
|
||||||
|
for k in (0..n).rev() {
|
||||||
|
for iteration in 0..30 {
|
||||||
|
let mut flag = true;
|
||||||
|
l = k;
|
||||||
|
while l != 0 {
|
||||||
|
if l == 0 || rv1[l].abs() <= math::EPSILON * anorm {
|
||||||
|
flag = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
nm = l - 1;
|
||||||
|
if w[nm].abs() <= math::EPSILON * anorm {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
l -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if flag {
|
||||||
|
let mut c = 0.0;
|
||||||
|
let mut s = 1.0;
|
||||||
|
for i in l..k+1 {
|
||||||
|
let f = s * rv1[i];
|
||||||
|
rv1[i] = c * rv1[i];
|
||||||
|
if f.abs() <= math::EPSILON * anorm {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
g = w[i];
|
||||||
|
let mut h = f.hypot(g);
|
||||||
|
w[i] = h;
|
||||||
|
h = 1.0 / h;
|
||||||
|
c = g * h;
|
||||||
|
s = -f * h;
|
||||||
|
for j in 0..m {
|
||||||
|
let y = self.get(j, nm);
|
||||||
|
let z = self.get(j, i);
|
||||||
|
self.set(j, nm, y * c + z * s);
|
||||||
|
self.set(j, i, z * c - y * s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let z = w[k];
|
||||||
|
if l == k {
|
||||||
|
if z < 0f64 {
|
||||||
|
w[k] = -z;
|
||||||
|
for j in 0..n {
|
||||||
|
v.set(j, k, -v.get(j, k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if iteration == 29 {
|
||||||
|
panic!("no convergence in 30 iterations");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut x = w[l];
|
||||||
|
nm = k - 1;
|
||||||
|
let mut y = w[nm];
|
||||||
|
g = rv1[nm];
|
||||||
|
let mut h = rv1[k];
|
||||||
|
let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2.0 * h * y);
|
||||||
|
g = f.hypot(1.0);
|
||||||
|
f = ((x - z) * (x + z) + h * ((y / (f + g.copysign(f))) - h)) / x;
|
||||||
|
let mut c = 1f64;
|
||||||
|
let mut s = 1f64;
|
||||||
|
|
||||||
|
for j in l..=nm {
|
||||||
|
let i = j + 1;
|
||||||
|
g = rv1[i];
|
||||||
|
y = w[i];
|
||||||
|
h = s * g;
|
||||||
|
g = c * g;
|
||||||
|
let mut z = f.hypot(h);
|
||||||
|
rv1[j] = z;
|
||||||
|
c = f / z;
|
||||||
|
s = h / z;
|
||||||
|
f = x * c + g * s;
|
||||||
|
g = g * c - x * s;
|
||||||
|
h = y * s;
|
||||||
|
y *= c;
|
||||||
|
|
||||||
|
for jj in 0..n {
|
||||||
|
x = v.get(jj, j);
|
||||||
|
z = v.get(jj, i);
|
||||||
|
v.set(jj, j, x * c + z * s);
|
||||||
|
v.set(jj, i, z * c - x * s);
|
||||||
|
}
|
||||||
|
|
||||||
|
z = f.hypot(h);
|
||||||
|
w[j] = z;
|
||||||
|
if z.abs() > math::EPSILON {
|
||||||
|
z = 1.0 / z;
|
||||||
|
c = f * z;
|
||||||
|
s = h * z;
|
||||||
|
}
|
||||||
|
|
||||||
|
f = c * g + s * y;
|
||||||
|
x = c * y - s * g;
|
||||||
|
for jj in 0..m {
|
||||||
|
y = self.get(jj, j);
|
||||||
|
z = self.get(jj, i);
|
||||||
|
self.set(jj, j, y * c + z * s);
|
||||||
|
self.set(jj, i, z * c - y * s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rv1[l] = 0.0;
|
||||||
|
rv1[k] = f;
|
||||||
|
w[k] = x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut inc = 1usize;
|
||||||
|
let mut su = vec![0f64; m];
|
||||||
|
let mut sv = vec![0f64; n];
|
||||||
|
|
||||||
|
loop {
|
||||||
|
inc *= 3;
|
||||||
|
inc += 1;
|
||||||
|
if inc > n {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
loop {
|
||||||
|
inc /= 3;
|
||||||
|
for i in inc..n {
|
||||||
|
let sw = w[i];
|
||||||
|
for k in 0..m {
|
||||||
|
su[k] = self.get(k, i);
|
||||||
|
}
|
||||||
|
for k in 0..n {
|
||||||
|
sv[k] = v.get(k, i);
|
||||||
|
}
|
||||||
|
let mut j = i;
|
||||||
|
while w[j - inc] < sw {
|
||||||
|
w[j] = w[j - inc];
|
||||||
|
for k in 0..m {
|
||||||
|
self.set(k, j, self.get(k, j - inc));
|
||||||
|
}
|
||||||
|
for k in 0..n {
|
||||||
|
v.set(k, j, v.get(k, j - inc));
|
||||||
|
}
|
||||||
|
j -= inc;
|
||||||
|
if j < inc {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w[j] = sw;
|
||||||
|
for k in 0..m {
|
||||||
|
self.set(k, j, su[k]);
|
||||||
|
}
|
||||||
|
for k in 0..n {
|
||||||
|
v.set(k, j, sv[k]);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
if inc <= 1 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for k in 0..n {
|
||||||
|
let mut s = 0.;
|
||||||
|
for i in 0..m {
|
||||||
|
if self.get(i, k) < 0. {
|
||||||
|
s += 1.;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j in 0..n {
|
||||||
|
if v.get(j, k) < 0. {
|
||||||
|
s += 1.;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s > (m + n) as f64 / 2. {
|
||||||
|
for i in 0..m {
|
||||||
|
self.set(i, k, -self.get(i, k));
|
||||||
|
}
|
||||||
|
for j in 0..n {
|
||||||
|
v.set(j, k, -v.get(j, k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let tol = 0.5 * ((m + n) as f64 + 1.).sqrt() * w[0] * math::EPSILON;
|
||||||
|
|
||||||
|
let p = b.ncols;
|
||||||
|
|
||||||
|
for k in 0..p {
|
||||||
|
let mut tmp = vec![0f64; v.nrows];
|
||||||
|
for j in 0..n {
|
||||||
|
let mut r = 0f64;
|
||||||
|
if w[j] > tol {
|
||||||
|
for i in 0..m {
|
||||||
|
r += self.get(i, j) * b.get(i, k);
|
||||||
|
}
|
||||||
|
r /= w[j];
|
||||||
|
}
|
||||||
|
tmp[j] = r;
|
||||||
|
}
|
||||||
|
|
||||||
|
for j in 0..n {
|
||||||
|
let mut r = 0.0;
|
||||||
|
for jj in 0..n {
|
||||||
|
r += v.get(j, jj) * tmp[jj];
|
||||||
|
}
|
||||||
|
b.set(j, k, r);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
fn approximate_eq(&self, other: &Self, error: f64) -> bool {
|
fn approximate_eq(&self, other: &Self, error: f64) -> bool {
|
||||||
if self.ncols != other.ncols || self.nrows != other.nrows {
|
if self.ncols != other.ncols || self.nrows != other.nrows {
|
||||||
return false
|
return false
|
||||||
@@ -304,9 +684,19 @@ mod tests {
|
|||||||
|
|
||||||
let mut a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let mut a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
|
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
|
||||||
let expected_w = DenseMatrix::from_array(3, 2, &[-0.20270270270270263, 0.8783783783783784, 0.4729729729729729, -1.2837837837837829, 2.2297297297297303, 0.6621621621621613]);
|
let expected_w = DenseMatrix::from_array(3, 2, &[-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]);
|
||||||
let w = a.qr_solve_mut(b);
|
let w = a.qr_solve_mut(b);
|
||||||
assert_eq!(w, expected_w);
|
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn svd_solve_mut() {
|
||||||
|
|
||||||
|
let mut a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
|
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
|
||||||
|
let expected_w = DenseMatrix::from_array(3, 2, &[-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]);
|
||||||
|
let w = a.svd_solve_mut(b);
|
||||||
|
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
+1
-1
@@ -1,3 +1,3 @@
|
|||||||
pub mod distance;
|
pub mod distance;
|
||||||
|
|
||||||
pub static SMALL_ERROR:f64 = 0.0000000000000001f64;
|
pub static EPSILON:f64 = 2.2204460492503131e-16_f64;
|
||||||
@@ -26,10 +26,14 @@ impl<M: Matrix> LinearRegression<M> {
|
|||||||
panic!("Number of rows of X doesn't match number of rows of Y");
|
panic!("Number of rows of X doesn't match number of rows of Y");
|
||||||
}
|
}
|
||||||
|
|
||||||
let b = y.v_stack(&M::ones(1, 1));
|
// let b = y.v_stack(&M::ones(1, 1));
|
||||||
|
let b = y.clone();
|
||||||
let mut a = x.h_stack(&M::ones(x_nrows, 1));
|
let mut a = x.h_stack(&M::ones(x_nrows, 1));
|
||||||
|
|
||||||
let w = a.qr_solve_mut(b);
|
let w = match solver {
|
||||||
|
LinearRegressionSolver::QR => a.qr_solve_mut(b),
|
||||||
|
LinearRegressionSolver::SVD => a.svd_solve_mut(b)
|
||||||
|
};
|
||||||
|
|
||||||
let wights = w.slice(0..num_attributes, 0..1);
|
let wights = w.slice(0..num_attributes, 0..1);
|
||||||
|
|
||||||
@@ -45,7 +49,7 @@ impl<M: Matrix> LinearRegression<M> {
|
|||||||
impl<M: Matrix> Regression<M> for LinearRegression<M> {
|
impl<M: Matrix> Regression<M> for LinearRegression<M> {
|
||||||
|
|
||||||
|
|
||||||
fn predict(&self, x: M) -> M {
|
fn predict(&self, x: &M) -> M {
|
||||||
let (nrows, _) = x.shape();
|
let (nrows, _) = x.shape();
|
||||||
let mut y_hat = x.dot(&self.coefficients);
|
let mut y_hat = x.dot(&self.coefficients);
|
||||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||||
@@ -81,10 +85,13 @@ mod tests {
|
|||||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
||||||
let y = DenseMatrix::from_array(16, 1, &[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]);
|
let y = DenseMatrix::from_array(16, 1, &[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]);
|
||||||
|
|
||||||
let lr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR);
|
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
|
||||||
|
|
||||||
|
let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x);
|
||||||
|
|
||||||
|
assert!(y.approximate_eq(&y_hat_qr, 5.));
|
||||||
|
assert!(y.approximate_eq(&y_hat_svd, 5.));
|
||||||
|
|
||||||
let y_hat = lr.predict(x);
|
|
||||||
|
|
||||||
assert!(y.approximate_eq(&y_hat, 5.));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -4,6 +4,6 @@ use crate::linalg::Matrix;
|
|||||||
|
|
||||||
pub trait Regression<M: Matrix> {
|
pub trait Regression<M: Matrix> {
|
||||||
|
|
||||||
fn predict(&self, x: M) -> M;
|
fn predict(&self, x: &M) -> M;
|
||||||
|
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user