make default params available to serde (#167)

* add seed param to search params

* make default params available to serde

* lints

* create defaults for enums

* lint
This commit is contained in:
Montana Low
2022-09-21 19:48:31 -07:00
committed by morenol
parent 05dfffad5c
commit f4fd4d2239
22 changed files with 175 additions and 18 deletions
+7 -2
View File
@@ -49,16 +49,21 @@ use crate::neighbors::KNNWeightFunction;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
#[cfg_attr(feature = "serde", serde(default))]
/// a function that defines a distance between each pair of point in training data.
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
pub distance: D,
#[cfg_attr(feature = "serde", serde(default))]
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
pub algorithm: KNNAlgorithmName,
#[cfg_attr(feature = "serde", serde(default))]
/// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`.
pub weight: KNNWeightFunction,
#[cfg_attr(feature = "serde", serde(default))]
/// number of training samples to consider when estimating class for new point. Default value is 3.
pub k: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// this parameter is not used
t: PhantomData<T>,
}
@@ -111,8 +116,8 @@ impl<T: RealNumber> Default for KNNClassifierParameters<T, Euclidian> {
fn default() -> Self {
KNNClassifierParameters {
distance: Distances::euclidian(),
algorithm: KNNAlgorithmName::CoverTree,
weight: KNNWeightFunction::Uniform,
algorithm: KNNAlgorithmName::default(),
weight: KNNWeightFunction::default(),
k: 3,
t: PhantomData,
}
+7 -2
View File
@@ -52,16 +52,21 @@ use crate::neighbors::KNNWeightFunction;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
#[cfg_attr(feature = "serde", serde(default))]
/// a function that defines a distance between each pair of point in training data.
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
distance: D,
#[cfg_attr(feature = "serde", serde(default))]
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
pub algorithm: KNNAlgorithmName,
#[cfg_attr(feature = "serde", serde(default))]
/// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`.
pub weight: KNNWeightFunction,
#[cfg_attr(feature = "serde", serde(default))]
/// number of training samples to consider when estimating class for new point. Default value is 3.
pub k: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// this parameter is not used
t: PhantomData<T>,
}
@@ -113,8 +118,8 @@ impl<T: RealNumber> Default for KNNRegressorParameters<T, Euclidian> {
fn default() -> Self {
KNNRegressorParameters {
distance: Distances::euclidian(),
algorithm: KNNAlgorithmName::CoverTree,
weight: KNNWeightFunction::Uniform,
algorithm: KNNAlgorithmName::default(),
weight: KNNWeightFunction::default(),
k: 3,
t: PhantomData,
}
+6
View File
@@ -58,6 +58,12 @@ pub enum KNNWeightFunction {
Distance,
}
impl Default for KNNWeightFunction {
fn default() -> Self {
KNNWeightFunction::Uniform
}
}
impl KNNWeightFunction {
fn calc_weights<T: RealNumber>(&self, distances: Vec<T>) -> std::vec::Vec<T> {
match *self {