fix: more refactoring

This commit is contained in:
Volodymyr Orlov
2020-03-13 11:24:53 -07:00
parent cb4323f26e
commit 4f8318e933
15 changed files with 51 additions and 66 deletions
+2 -3
View File
@@ -139,8 +139,7 @@ impl<M: Matrix> LogisticRegression<M> {
let mut yi: Vec<usize> = vec![0; y_nrows]; let mut yi: Vec<usize> = vec![0; y_nrows];
for i in 0..y_nrows { for i in 0..y_nrows {
let yc = y_m.get(0, i); let yc = y_m.get(0, i);
let j = classes.iter().position(|c| yc == *c).unwrap();
yi[i] = classes.iter().position(|c| yc == *c).unwrap(); yi[i] = classes.iter().position(|c| yc == *c).unwrap();
} }
@@ -244,7 +243,7 @@ impl<M: Matrix> LogisticRegression<M> {
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
use ndarray::{arr1, arr2, Array}; use ndarray::{arr1, arr2};
#[test] #[test]
fn multiclass_objective_f() { fn multiclass_objective_f() {
-12
View File
@@ -1,12 +0,0 @@
use std::fmt;
#[derive(Debug)]
pub struct IllegalArgumentError {
pub message: String,
}
impl fmt::Display for IllegalArgumentError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.message)
}
}
-1
View File
@@ -4,7 +4,6 @@ pub mod cluster;
pub mod decomposition; pub mod decomposition;
pub mod linalg; pub mod linalg;
pub mod math; pub mod math;
pub mod error;
pub mod algorithm; pub mod algorithm;
pub mod common; pub mod common;
pub mod optimization; pub mod optimization;
+2
View File
@@ -1,3 +1,5 @@
#![allow(non_snake_case)]
use num::complex::Complex; use num::complex::Complex;
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
+1 -3
View File
@@ -122,9 +122,7 @@ pub trait BaseMatrix: Clone + Debug {
r r
} }
fn transpose(&self) -> Self; fn transpose(&self) -> Self;
fn generate_positive_definite(nrows: usize, ncols: usize) -> Self;
fn rand(nrows: usize, ncols: usize) -> Self; fn rand(nrows: usize, ncols: usize) -> Self;
+1 -11
View File
@@ -366,12 +366,7 @@ impl BaseMatrix for DenseMatrix {
fn sub_element_mut(&mut self, row: usize, col: usize, x: f64) { fn sub_element_mut(&mut self, row: usize, col: usize, x: f64) {
self.values[col*self.nrows + row] -= x; self.values[col*self.nrows + row] -= x;
} }
fn generate_positive_definite(nrows: usize, ncols: usize) -> Self {
let m = DenseMatrix::rand(nrows, ncols);
m.dot(&m.transpose())
}
fn transpose(&self) -> Self { fn transpose(&self) -> Self {
let mut m = DenseMatrix { let mut m = DenseMatrix {
@@ -723,11 +718,6 @@ mod tests {
} }
} }
#[test]
fn generate_positive_definite() {
let m = DenseMatrix::generate_positive_definite(3, 3);
}
#[test] #[test]
fn reshape() { fn reshape() {
let m_orig = DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6.]); let m_orig = DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6.]);
+29 -6
View File
@@ -5,6 +5,7 @@ use crate::linalg::svd::SVDDecomposableMatrix;
use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix; use crate::linalg::qr::QRDecomposableMatrix;
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s}; use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
use rand::prelude::*;
impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2> impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
{ {
@@ -81,7 +82,7 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
} }
fn approximate_eq(&self, other: &Self, error: f64) -> bool { fn approximate_eq(&self, other: &Self, error: f64) -> bool {
false (self - other).iter().all(|v| v.abs() <= error)
} }
fn add_mut(&mut self, other: &Self) -> &Self { fn add_mut(&mut self, other: &Self) -> &Self {
@@ -128,12 +129,12 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
self.clone().reversed_axes() self.clone().reversed_axes()
} }
fn generate_positive_definite(nrows: usize, ncols: usize) -> Self{
panic!("generate_positive_definite method is not implemented for ndarray");
}
fn rand(nrows: usize, ncols: usize) -> Self{ fn rand(nrows: usize, ncols: usize) -> Self{
panic!("rand method is not implemented for ndarray"); let mut rng = rand::thread_rng();
let values: Vec<f64> = (0..nrows*ncols).map(|_| {
rng.gen()
}).collect();
Array::from_shape_vec((nrows, ncols), values).unwrap()
} }
fn norm2(&self) -> f64{ fn norm2(&self) -> f64{
@@ -600,4 +601,26 @@ mod tests {
let res: Array2<f64> = BaseMatrix::eye(3); let res: Array2<f64> = BaseMatrix::eye(3);
assert_eq!(res, a); assert_eq!(res, a);
} }
#[test]
fn rand() {
let m: Array2<f64> = BaseMatrix::rand(3, 3);
for c in 0..3 {
for r in 0..3 {
assert!(m[[r, c]] != 0f64);
}
}
}
#[test]
fn approximate_eq() {
let a = arr2(&[[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]]);
let noise = arr2(&[[1e-5, 2e-5, 3e-5],
[4e-5, 5e-5, 6e-5],
[7e-5, 8e-5, 9e-5]]);
assert!(a.approximate_eq(&(&noise + &a), 1e-4));
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
}
} }
+2
View File
@@ -1,3 +1,5 @@
#![allow(non_snake_case)]
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
+3 -1
View File
@@ -1,3 +1,5 @@
#![allow(non_snake_case)]
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -504,7 +506,7 @@ mod tests {
#[test] #[test]
fn solve() { fn solve() {
let mut a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let b = DenseMatrix::from_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]); let b = DenseMatrix::from_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
let expected_w = DenseMatrix::from_array(&[ let expected_w = DenseMatrix::from_array(&[
&[-0.20, -1.28], &[-0.20, -1.28],
+1 -1
View File
@@ -21,7 +21,7 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn measure_simple_euclidian_distance() { fn squared_distance() {
let a = vec![1., 2., 3.]; let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.]; let b = vec![4., 5., 6.];
+1 -3
View File
@@ -9,9 +9,7 @@ pub trait NumericExt {
impl NumericExt for f64 { impl NumericExt for f64 {
fn ln_1pe(&self) -> f64{ fn ln_1pe(&self) -> f64{
let y = 0.;
if *self > 15. { if *self > 15. {
return *self; return *self;
} else { } else {
+1 -1
View File
@@ -154,7 +154,7 @@ impl LBFGS {
g_converged || x_converged || state.counter_f_tol > self.successive_f_tol g_converged || x_converged || state.counter_f_tol > self.successive_f_tol
} }
fn update_hessian<'a, X: Matrix>(&self, df: &'a DF<X>, state: &mut LBFGSState<X>) { fn update_hessian<'a, X: Matrix>(&self, _: &'a DF<X>, state: &mut LBFGSState<X>) {
state.dg = state.x_df.sub(&state.x_df_prev); state.dg = state.x_df.sub(&state.x_df_prev);
let rho_iteration = 1. / state.dx.vector_dot(&state.dg); let rho_iteration = 1. / state.dx.vector_dot(&state.dg);
if !rho_iteration.is_infinite() { if !rho_iteration.is_infinite() {
+2 -4
View File
@@ -1,10 +1,8 @@
pub mod first_order; pub mod first_order;
pub mod line_search; pub mod line_search;
use crate::linalg::Matrix; pub type F<'a, X> = dyn for<'b> Fn(&'b X) -> f64 + 'a;
pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a;
pub type F<'a, X: Matrix> = dyn for<'b> Fn(&'b X) -> f64 + 'a;
pub type DF<'a, X: Matrix> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a;
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum FunctionOrder { pub enum FunctionOrder {
+5 -11
View File
@@ -1,5 +1,4 @@
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::regression::Regression;
use std::fmt::Debug; use std::fmt::Debug;
#[derive(Debug)] #[derive(Debug)]
@@ -27,7 +26,7 @@ impl<M: Matrix> LinearRegression<M> {
panic!("Number of rows of X doesn't match number of rows of Y"); panic!("Number of rows of X doesn't match number of rows of Y");
} }
let mut a = x.v_stack(&M::ones(x_nrows, 1)); let a = x.v_stack(&M::ones(x_nrows, 1));
let w = match solver { let w = match solver {
LinearRegressionSolver::QR => a.qr_solve_mut(b), LinearRegressionSolver::QR => a.qr_solve_mut(b),
@@ -43,16 +42,11 @@ impl<M: Matrix> LinearRegression<M> {
} }
} }
} pub fn predict(&self, x: &M) -> M::RowVector {
impl<M: Matrix> Regression<M> for LinearRegression<M> {
fn predict(&self, x: &M) -> M {
let (nrows, _) = x.shape(); let (nrows, _) = x.shape();
let mut y_hat = x.dot(&self.coefficients); let mut y_hat = x.dot(&self.coefficients);
y_hat.add_mut(&M::fill(nrows, 1, self.intercept)); y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
y_hat.transpose() y_hat.transpose().to_row_vector()
} }
} }
@@ -85,9 +79,9 @@ mod tests {
let y = DenseMatrix::from_array(&[&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]]); let y = DenseMatrix::from_array(&[&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]]);
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x); let y_hat_qr = DenseMatrix::from_row_vector(LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x));
let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x); let y_hat_svd = DenseMatrix::from_row_vector(LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x));
assert!(y.approximate_eq(&y_hat_qr, 5.)); assert!(y.approximate_eq(&y_hat_qr, 5.));
assert!(y.approximate_eq(&y_hat_svd, 5.)); assert!(y.approximate_eq(&y_hat_svd, 5.));
+1 -9
View File
@@ -1,9 +1 @@
pub mod linear_regression; pub mod linear_regression;
use crate::linalg::Matrix;
pub trait Regression<M: Matrix> {
fn predict(&self, x: &M) -> M;
}