make default params available to serde (#167)
* add seed param to search params * make default params available to serde * lints * create defaults for enums * lint
This commit is contained in:
@@ -49,16 +49,21 @@ use crate::neighbors::KNNWeightFunction;
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub distance: D,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`.
|
||||
pub weight: KNNWeightFunction,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// number of training samples to consider when estimating class for new point. Default value is 3.
|
||||
pub k: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// this parameter is not used
|
||||
t: PhantomData<T>,
|
||||
}
|
||||
@@ -111,8 +116,8 @@ impl<T: RealNumber> Default for KNNClassifierParameters<T, Euclidian> {
|
||||
fn default() -> Self {
|
||||
KNNClassifierParameters {
|
||||
distance: Distances::euclidian(),
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
weight: KNNWeightFunction::Uniform,
|
||||
algorithm: KNNAlgorithmName::default(),
|
||||
weight: KNNWeightFunction::default(),
|
||||
k: 3,
|
||||
t: PhantomData,
|
||||
}
|
||||
|
||||
@@ -52,16 +52,21 @@ use crate::neighbors::KNNWeightFunction;
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
distance: D,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`.
|
||||
pub weight: KNNWeightFunction,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// number of training samples to consider when estimating class for new point. Default value is 3.
|
||||
pub k: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// this parameter is not used
|
||||
t: PhantomData<T>,
|
||||
}
|
||||
@@ -113,8 +118,8 @@ impl<T: RealNumber> Default for KNNRegressorParameters<T, Euclidian> {
|
||||
fn default() -> Self {
|
||||
KNNRegressorParameters {
|
||||
distance: Distances::euclidian(),
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
weight: KNNWeightFunction::Uniform,
|
||||
algorithm: KNNAlgorithmName::default(),
|
||||
weight: KNNWeightFunction::default(),
|
||||
k: 3,
|
||||
t: PhantomData,
|
||||
}
|
||||
|
||||
@@ -58,6 +58,12 @@ pub enum KNNWeightFunction {
|
||||
Distance,
|
||||
}
|
||||
|
||||
impl Default for KNNWeightFunction {
|
||||
fn default() -> Self {
|
||||
KNNWeightFunction::Uniform
|
||||
}
|
||||
}
|
||||
|
||||
impl KNNWeightFunction {
|
||||
fn calc_weights<T: RealNumber>(&self, distances: Vec<T>) -> std::vec::Vec<T> {
|
||||
match *self {
|
||||
|
||||
Reference in New Issue
Block a user