fix: renames FloatExt to RealNumber
This commit is contained in:
@@ -4,21 +4,21 @@ use std::marker::PhantomData;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::optimization::first_order::lbfgs::LBFGS;
|
||||
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
||||
use crate::optimization::line_search::Backtracking;
|
||||
use crate::optimization::FunctionOrder;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct LogisticRegression<T: FloatExt, M: Matrix<T>> {
|
||||
pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> {
|
||||
weights: M,
|
||||
classes: Vec<T>,
|
||||
num_attributes: usize,
|
||||
num_classes: usize,
|
||||
}
|
||||
|
||||
trait ObjectiveFunction<T: FloatExt, M: Matrix<T>> {
|
||||
trait ObjectiveFunction<T: RealNumber, M: Matrix<T>> {
|
||||
fn f(&self, w_bias: &M) -> T;
|
||||
fn df(&self, g: &mut M, w_bias: &M);
|
||||
|
||||
@@ -33,13 +33,13 @@ trait ObjectiveFunction<T: FloatExt, M: Matrix<T>> {
|
||||
}
|
||||
}
|
||||
|
||||
struct BinaryObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
|
||||
struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
|
||||
x: &'a M,
|
||||
y: Vec<usize>,
|
||||
phantom: PhantomData<&'a T>,
|
||||
}
|
||||
|
||||
impl<T: FloatExt, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.num_classes != other.num_classes
|
||||
|| self.num_attributes != other.num_attributes
|
||||
@@ -58,7 +58,7 @@ impl<T: FloatExt, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveFunction<'a, T, M> {
|
||||
impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveFunction<'a, T, M> {
|
||||
fn f(&self, w_bias: &M) -> T {
|
||||
let mut f = T::zero();
|
||||
let (n, _) = self.x.shape();
|
||||
@@ -88,14 +88,14 @@ impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveF
|
||||
}
|
||||
}
|
||||
|
||||
struct MultiClassObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
|
||||
struct MultiClassObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
|
||||
x: &'a M,
|
||||
y: Vec<usize>,
|
||||
k: usize,
|
||||
phantom: PhantomData<&'a T>,
|
||||
}
|
||||
|
||||
impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M>
|
||||
impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
||||
for MultiClassObjectiveFunction<'a, T, M>
|
||||
{
|
||||
fn f(&self, w_bias: &M) -> T {
|
||||
@@ -147,7 +147,7 @@ impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M>
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
|
||||
Reference in New Issue
Block a user