fix: renames FloatExt to RealNumber
This commit is contained in:
@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeClassifierParameters {
|
||||
@@ -18,7 +18,7 @@ pub struct DecisionTreeClassifierParameters {
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeClassifier<T: FloatExt> {
|
||||
pub struct DecisionTreeClassifier<T: RealNumber> {
|
||||
nodes: Vec<Node<T>>,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
num_classes: usize,
|
||||
@@ -34,7 +34,7 @@ pub enum SplitCriterion {
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Node<T: FloatExt> {
|
||||
pub struct Node<T: RealNumber> {
|
||||
index: usize,
|
||||
output: usize,
|
||||
split_feature: usize,
|
||||
@@ -44,7 +44,7 @@ pub struct Node<T: FloatExt> {
|
||||
false_child: Option<usize>,
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for DecisionTreeClassifier<T> {
|
||||
impl<T: RealNumber> PartialEq for DecisionTreeClassifier<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.depth != other.depth
|
||||
|| self.num_classes != other.num_classes
|
||||
@@ -67,7 +67,7 @@ impl<T: FloatExt> PartialEq for DecisionTreeClassifier<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for Node<T> {
|
||||
impl<T: RealNumber> PartialEq for Node<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.output == other.output
|
||||
&& self.split_feature == other.split_feature
|
||||
@@ -95,7 +95,7 @@ impl Default for DecisionTreeClassifierParameters {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt> Node<T> {
|
||||
impl<T: RealNumber> Node<T> {
|
||||
fn new(index: usize, output: usize) -> Self {
|
||||
Node {
|
||||
index: index,
|
||||
@@ -109,7 +109,7 @@ impl<T: FloatExt> Node<T> {
|
||||
}
|
||||
}
|
||||
|
||||
struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
|
||||
struct NodeVisitor<'a, T: RealNumber, M: Matrix<T>> {
|
||||
x: &'a M,
|
||||
y: &'a Vec<usize>,
|
||||
node: usize,
|
||||
@@ -121,7 +121,7 @@ struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
|
||||
phantom: PhantomData<&'a T>,
|
||||
}
|
||||
|
||||
fn impurity<T: FloatExt>(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> T {
|
||||
fn impurity<T: RealNumber>(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> T {
|
||||
let mut impurity = T::zero();
|
||||
|
||||
match criterion {
|
||||
@@ -156,7 +156,7 @@ fn impurity<T: FloatExt>(criterion: &SplitCriterion, count: &Vec<usize>, n: usiz
|
||||
return impurity;
|
||||
}
|
||||
|
||||
impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
||||
impl<'a, T: RealNumber, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
||||
fn new(
|
||||
node_id: usize,
|
||||
samples: Vec<usize>,
|
||||
@@ -193,7 +193,7 @@ pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
|
||||
return which;
|
||||
}
|
||||
|
||||
impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
|
||||
@@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeRegressorParameters {
|
||||
@@ -16,14 +16,14 @@ pub struct DecisionTreeRegressorParameters {
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeRegressor<T: FloatExt> {
|
||||
pub struct DecisionTreeRegressor<T: RealNumber> {
|
||||
nodes: Vec<Node<T>>,
|
||||
parameters: DecisionTreeRegressorParameters,
|
||||
depth: u16,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Node<T: FloatExt> {
|
||||
pub struct Node<T: RealNumber> {
|
||||
index: usize,
|
||||
output: T,
|
||||
split_feature: usize,
|
||||
@@ -43,7 +43,7 @@ impl Default for DecisionTreeRegressorParameters {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt> Node<T> {
|
||||
impl<T: RealNumber> Node<T> {
|
||||
fn new(index: usize, output: T) -> Self {
|
||||
Node {
|
||||
index: index,
|
||||
@@ -57,7 +57,7 @@ impl<T: FloatExt> Node<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for Node<T> {
|
||||
impl<T: RealNumber> PartialEq for Node<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
(self.output - other.output).abs() < T::epsilon()
|
||||
&& self.split_feature == other.split_feature
|
||||
@@ -74,7 +74,7 @@ impl<T: FloatExt> PartialEq for Node<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for DecisionTreeRegressor<T> {
|
||||
impl<T: RealNumber> PartialEq for DecisionTreeRegressor<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.depth != other.depth || self.nodes.len() != other.nodes.len() {
|
||||
return false;
|
||||
@@ -89,7 +89,7 @@ impl<T: FloatExt> PartialEq for DecisionTreeRegressor<T> {
|
||||
}
|
||||
}
|
||||
|
||||
struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
|
||||
struct NodeVisitor<'a, T: RealNumber, M: Matrix<T>> {
|
||||
x: &'a M,
|
||||
y: &'a M,
|
||||
node: usize,
|
||||
@@ -100,7 +100,7 @@ struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
|
||||
level: u16,
|
||||
}
|
||||
|
||||
impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
||||
impl<'a, T: RealNumber, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
||||
fn new(
|
||||
node_id: usize,
|
||||
samples: Vec<usize>,
|
||||
@@ -122,7 +122,7 @@ impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt> DecisionTreeRegressor<T> {
|
||||
impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
|
||||
Reference in New Issue
Block a user