feat: extends interface of Matrix to support for broad range of types
This commit is contained in:
@@ -1,4 +1,7 @@
|
||||
use crate::math::NumericExt;
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::optimization::FunctionOrder;
|
||||
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
||||
@@ -6,42 +9,43 @@ use crate::optimization::line_search::Backtracking;
|
||||
use crate::optimization::first_order::lbfgs::LBFGS;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LogisticRegression<M: Matrix> {
|
||||
pub struct LogisticRegression<T: FloatExt + Debug, M: Matrix<T>> {
|
||||
weights: M,
|
||||
classes: Vec<f64>,
|
||||
classes: Vec<T>,
|
||||
num_attributes: usize,
|
||||
num_classes: usize
|
||||
}
|
||||
|
||||
trait ObjectiveFunction<M: Matrix> {
|
||||
fn f(&self, w_bias: &M) -> f64;
|
||||
trait ObjectiveFunction<T: FloatExt + Debug, M: Matrix<T>> {
|
||||
fn f(&self, w_bias: &M) -> T;
|
||||
fn df(&self, g: &mut M, w_bias: &M);
|
||||
|
||||
fn partial_dot(w: &M, x: &M, v_col: usize, m_row: usize) -> f64 {
|
||||
let mut sum = 0f64;
|
||||
fn partial_dot(w: &M, x: &M, v_col: usize, m_row: usize) -> T {
|
||||
let mut sum = T::zero();
|
||||
let p = x.shape().1;
|
||||
for i in 0..p {
|
||||
sum += x.get(m_row, i) * w.get(0, i + v_col);
|
||||
sum = sum + x.get(m_row, i) * w.get(0, i + v_col);
|
||||
}
|
||||
|
||||
sum + w.get(0, p + v_col)
|
||||
}
|
||||
}
|
||||
|
||||
struct BinaryObjectiveFunction<'a, M: Matrix> {
|
||||
struct BinaryObjectiveFunction<'a, T: FloatExt + Debug, M: Matrix<T>> {
|
||||
x: &'a M,
|
||||
y: Vec<usize>
|
||||
y: Vec<usize>,
|
||||
phantom: PhantomData<&'a T>
|
||||
}
|
||||
|
||||
impl<'a, M: Matrix> ObjectiveFunction<M> for BinaryObjectiveFunction<'a, M> {
|
||||
impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveFunction<'a, T, M> {
|
||||
|
||||
fn f(&self, w_bias: &M) -> f64 {
|
||||
let mut f = 0.;
|
||||
fn f(&self, w_bias: &M) -> T {
|
||||
let mut f = T::zero();
|
||||
let (n, _) = self.x.shape();
|
||||
|
||||
for i in 0..n {
|
||||
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
|
||||
f += wx.ln_1pe() - (self.y[i] as f64) * wx;
|
||||
f = f + (wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx);
|
||||
}
|
||||
|
||||
f
|
||||
@@ -57,7 +61,7 @@ impl<'a, M: Matrix> ObjectiveFunction<M> for BinaryObjectiveFunction<'a, M> {
|
||||
|
||||
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
|
||||
|
||||
let dyi = (self.y[i] as f64) - wx.sigmoid();
|
||||
let dyi = (T::from(self.y[i]).unwrap()) - wx.sigmoid();
|
||||
for j in 0..p {
|
||||
g.set(0, j, g.get(0, j) - dyi * self.x.get(i, j));
|
||||
}
|
||||
@@ -68,16 +72,17 @@ impl<'a, M: Matrix> ObjectiveFunction<M> for BinaryObjectiveFunction<'a, M> {
|
||||
|
||||
}
|
||||
|
||||
struct MultiClassObjectiveFunction<'a, M: Matrix> {
|
||||
struct MultiClassObjectiveFunction<'a, T: FloatExt + Debug, M: Matrix<T>> {
|
||||
x: &'a M,
|
||||
y: Vec<usize>,
|
||||
k: usize
|
||||
k: usize,
|
||||
phantom: PhantomData<&'a T>
|
||||
}
|
||||
|
||||
impl<'a, M: Matrix> ObjectiveFunction<M> for MultiClassObjectiveFunction<'a, M> {
|
||||
impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for MultiClassObjectiveFunction<'a, T, M> {
|
||||
|
||||
fn f(&self, w_bias: &M) -> f64 {
|
||||
let mut f = 0.;
|
||||
fn f(&self, w_bias: &M) -> T {
|
||||
let mut f = T::zero();
|
||||
let mut prob = M::zeros(1, self.k);
|
||||
let (n, p) = self.x.shape();
|
||||
for i in 0..n {
|
||||
@@ -85,7 +90,7 @@ impl<'a, M: Matrix> ObjectiveFunction<M> for MultiClassObjectiveFunction<'a, M>
|
||||
prob.set(0, j, MultiClassObjectiveFunction::partial_dot(w_bias, self.x, j * (p + 1), i));
|
||||
}
|
||||
prob.softmax_mut();
|
||||
f -= prob.get(0, self.y[i]).ln();
|
||||
f = f - prob.get(0, self.y[i]).ln();
|
||||
}
|
||||
|
||||
f
|
||||
@@ -106,7 +111,7 @@ impl<'a, M: Matrix> ObjectiveFunction<M> for MultiClassObjectiveFunction<'a, M>
|
||||
prob.softmax_mut();
|
||||
|
||||
for j in 0..self.k {
|
||||
let yi =(if self.y[i] == j { 1.0 } else { 0.0 }) - prob.get(0, j);
|
||||
let yi =(if self.y[i] == j { T::one() } else { T::zero() }) - prob.get(0, j);
|
||||
|
||||
for l in 0..p {
|
||||
let pos = j * (p + 1);
|
||||
@@ -120,9 +125,9 @@ impl<'a, M: Matrix> ObjectiveFunction<M> for MultiClassObjectiveFunction<'a, M>
|
||||
|
||||
}
|
||||
|
||||
impl<M: Matrix> LogisticRegression<M> {
|
||||
impl<T: FloatExt + Debug, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
|
||||
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<M>{
|
||||
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M>{
|
||||
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
@@ -153,7 +158,8 @@ impl<M: Matrix> LogisticRegression<M> {
|
||||
|
||||
let objective = BinaryObjectiveFunction{
|
||||
x: x,
|
||||
y: yi
|
||||
y: yi,
|
||||
phantom: PhantomData
|
||||
};
|
||||
|
||||
let result = LogisticRegression::minimize(x0, objective);
|
||||
@@ -172,7 +178,8 @@ impl<M: Matrix> LogisticRegression<M> {
|
||||
let objective = MultiClassObjectiveFunction{
|
||||
x: x,
|
||||
y: yi,
|
||||
k: k
|
||||
k: k,
|
||||
phantom: PhantomData
|
||||
};
|
||||
|
||||
let result = LogisticRegression::minimize(x0, objective);
|
||||
@@ -196,9 +203,9 @@ impl<M: Matrix> LogisticRegression<M> {
|
||||
if self.num_classes == 2 {
|
||||
let (nrows, _) = x.shape();
|
||||
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
||||
let y_hat: Vec<f64> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
|
||||
let y_hat: Vec<T> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
|
||||
for i in 0..n {
|
||||
result.set(0, i, self.classes[if y_hat[i].sigmoid() > 0.5 { 1 } else { 0 }]);
|
||||
result.set(0, i, self.classes[if y_hat[i].sigmoid() > T::half() { 1 } else { 0 }]);
|
||||
}
|
||||
|
||||
} else {
|
||||
@@ -221,8 +228,8 @@ impl<M: Matrix> LogisticRegression<M> {
|
||||
self.weights.slice(0..self.num_classes, self.num_attributes..self.num_attributes+1)
|
||||
}
|
||||
|
||||
fn minimize(x0: M, objective: impl ObjectiveFunction<M>) -> OptimizerResult<M> {
|
||||
let f = |w: &M| -> f64 {
|
||||
fn minimize(x0: M, objective: impl ObjectiveFunction<T, M>) -> OptimizerResult<T, M> {
|
||||
let f = |w: &M| -> T {
|
||||
objective.f(w)
|
||||
};
|
||||
|
||||
@@ -230,9 +237,9 @@ impl<M: Matrix> LogisticRegression<M> {
|
||||
objective.df(g, w)
|
||||
};
|
||||
|
||||
let mut ls: Backtracking = Default::default();
|
||||
let mut ls: Backtracking<T> = Default::default();
|
||||
ls.order = FunctionOrder::THIRD;
|
||||
let optimizer: LBFGS = Default::default();
|
||||
let optimizer: LBFGS<T> = Default::default();
|
||||
|
||||
optimizer.optimize(&f, &df, &x0, &ls)
|
||||
}
|
||||
@@ -270,10 +277,11 @@ mod tests {
|
||||
let objective = MultiClassObjectiveFunction{
|
||||
x: &x,
|
||||
y: y,
|
||||
k: 3
|
||||
k: 3,
|
||||
phantom: PhantomData
|
||||
};
|
||||
|
||||
let mut g = DenseMatrix::zeros(1, 9);
|
||||
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 9);
|
||||
|
||||
objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]));
|
||||
objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]));
|
||||
@@ -309,10 +317,11 @@ mod tests {
|
||||
|
||||
let objective = BinaryObjectiveFunction{
|
||||
x: &x,
|
||||
y: y
|
||||
y: y,
|
||||
phantom: PhantomData
|
||||
};
|
||||
|
||||
let mut g = DenseMatrix::zeros(1, 3);
|
||||
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 3);
|
||||
|
||||
objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3.]));
|
||||
objective.df(&mut g, &DenseMatrix::vector_from_array(&[1., 2., 3.]));
|
||||
@@ -345,7 +354,7 @@ mod tests {
|
||||
&[10., -2.],
|
||||
&[ 8., 2.],
|
||||
&[ 9., 0.]]);
|
||||
let y = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user