feat: consolidates API
This commit is contained in:
+45
-21
@@ -15,8 +15,7 @@
|
||||
//! let blobs = generator::make_blobs(100, 2, 3);
|
||||
//! let x = DenseMatrix::from_vec(blobs.num_samples, blobs.num_features, &blobs.data);
|
||||
//! // Fit the algorithm and predict cluster labels
|
||||
//! let labels = DBSCAN::fit(&x, Distances::euclidian(),
|
||||
//! DBSCANParameters::default().with_eps(3.0)).
|
||||
//! let labels = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(3.0)).
|
||||
//! and_then(|dbscan| dbscan.predict(&x));
|
||||
//!
|
||||
//! println!("{:?}", labels);
|
||||
@@ -33,9 +32,11 @@ use std::iter::Sum;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||
use crate::api::{Predictor, UnsupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::{row_iter, Matrix};
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
use crate::math::distance::{Distance, Distances};
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::tree::decision_tree_classifier::which_max;
|
||||
|
||||
@@ -50,7 +51,11 @@ pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// DBSCAN clustering algorithm parameters
|
||||
pub struct DBSCANParameters<T: RealNumber> {
|
||||
pub struct DBSCANParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
/// A function that defines the distance between each pair of points in the training data.
|
||||
/// This function should implement the [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub distance: D,
|
||||
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
|
||||
pub min_samples: usize,
|
||||
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
|
||||
@@ -59,7 +64,18 @@ pub struct DBSCANParameters<T: RealNumber> {
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> DBSCANParameters<T> {
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
|
||||
/// Replaces the distance function used between each pair of points in the training data.
|
||||
/// The new distance type must implement the [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
///
/// Consumes `self` and returns a new `DBSCANParameters<T, DD>` that keeps the
/// existing `min_samples`, `eps` and `algorithm` values.
pub fn with_distance<DD: Distance<Vec<T>, T>>(self, distance: DD) -> DBSCANParameters<T, DD> {
|
||||
// A new struct is built rather than mutating `self` because the distance
// type parameter changes from `D` to `DD`.
DBSCANParameters {
|
||||
distance,
|
||||
min_samples: self.min_samples,
|
||||
eps: self.eps,
|
||||
algorithm: self.algorithm,
|
||||
}
|
||||
}
|
||||
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
|
||||
pub fn with_min_samples(mut self, min_samples: usize) -> Self {
|
||||
self.min_samples = min_samples;
|
||||
@@ -86,9 +102,10 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for DBSCANParameters<T> {
|
||||
impl<T: RealNumber> Default for DBSCANParameters<T, Euclidian> {
|
||||
fn default() -> Self {
|
||||
DBSCANParameters {
|
||||
distance: Distances::euclidian(),
|
||||
min_samples: 5,
|
||||
eps: T::half(),
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
@@ -96,6 +113,22 @@ impl<T: RealNumber> Default for DBSCANParameters<T> {
|
||||
}
|
||||
}
|
||||
|
||||
// Exposes DBSCAN through the crate's `UnsupervisedEstimator` trait, with
// `DBSCANParameters<T, D>` as the parameter type.
impl<T: RealNumber + Sum, M: Matrix<T>, D: Distance<Vec<T>, T>>
|
||||
UnsupervisedEstimator<M, DBSCANParameters<T, D>> for DBSCAN<T, D>
|
||||
{
|
||||
// Forwards `x` and `parameters` unchanged and returns the result as-is.
// NOTE(review): `DBSCAN::fit` here is expected to resolve to the inherent
// `DBSCAN::fit(x, parameters)` rather than to this trait method — confirm.
fn fit(x: &M, parameters: DBSCANParameters<T, D>) -> Result<Self, Failed> {
|
||||
DBSCAN::fit(x, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
// Exposes DBSCAN through the crate's `Predictor` trait, taking a matrix `M`
// and producing an `M::RowVector`.
impl<T: RealNumber, M: Matrix<T>, D: Distance<Vec<T>, T>> Predictor<M, M::RowVector>
|
||||
for DBSCAN<T, D>
|
||||
{
|
||||
// Forwards `x` unchanged and returns the result as-is.
// NOTE(review): `self.predict(x)` presumably resolves to an inherent
// `DBSCAN::predict` that is not visible in this view — confirm, since
// otherwise this call would recurse into itself.
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
|
||||
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
|
||||
/// * `data` - training instances to cluster
|
||||
@@ -103,8 +136,7 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
|
||||
/// * `parameters` - cluster parameters
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
distance: D,
|
||||
parameters: DBSCANParameters<T>,
|
||||
parameters: DBSCANParameters<T, D>,
|
||||
) -> Result<DBSCAN<T, D>, Failed> {
|
||||
if parameters.min_samples < 1 {
|
||||
return Err(Failed::fit(&"Invalid minPts".to_string()));
|
||||
@@ -121,7 +153,9 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
|
||||
let n = x.shape().0;
|
||||
let mut y = vec![unassigned; n];
|
||||
|
||||
let algo = parameters.algorithm.fit(row_iter(x).collect(), distance)?;
|
||||
let algo = parameters
|
||||
.algorithm
|
||||
.fit(row_iter(x).collect(), parameters.distance)?;
|
||||
|
||||
for (i, e) in row_iter(x).enumerate() {
|
||||
if y[i] == unassigned {
|
||||
@@ -195,7 +229,6 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
use crate::math::distance::Distances;
|
||||
|
||||
#[test]
|
||||
fn fit_predict_dbscan() {
|
||||
@@ -215,16 +248,7 @@ mod tests {
|
||||
|
||||
let expected_labels = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0];
|
||||
|
||||
let dbscan = DBSCAN::fit(
|
||||
&x,
|
||||
Distances::euclidian(),
|
||||
DBSCANParameters {
|
||||
min_samples: 5,
|
||||
eps: 1.0,
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
let dbscan = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(1.0)).unwrap();
|
||||
|
||||
let predicted_labels = dbscan.predict(&x).unwrap();
|
||||
|
||||
@@ -256,7 +280,7 @@ mod tests {
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let dbscan = DBSCAN::fit(&x, Distances::euclidian(), Default::default()).unwrap();
|
||||
let dbscan = DBSCAN::fit(&x, Default::default()).unwrap();
|
||||
|
||||
let deserialized_dbscan: DBSCAN<f64, Euclidian> =
|
||||
serde_json::from_str(&serde_json::to_string(&dbscan).unwrap()).unwrap();
|
||||
|
||||
Reference in New Issue
Block a user