fix: fixes a bug in DBSCAN, removes `println!` calls

This commit is contained in:
Volodymyr Orlov
2021-01-02 18:08:40 -08:00
parent c5a7beaf0e
commit bb9a05b993
3 changed files with 48 additions and 23 deletions
+46 -19
View File
@@ -161,39 +161,60 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
}
let mut k = 0;
let unassigned = -2;
let queued = -2;
let outlier = -1;
let undefined = -3;
let n = x.shape().0;
let mut y = vec![unassigned; n];
let mut y = vec![undefined; n];
let algo = parameters
.algorithm
.fit(row_iter(x).collect(), parameters.distance)?;
for (i, e) in row_iter(x).enumerate() {
if y[i] == unassigned {
if y[i] == undefined {
let mut neighbors = algo.find_radius(&e, parameters.eps)?;
if neighbors.len() < parameters.min_samples {
y[i] = outlier;
} else {
y[i] = k;
for j in 0..neighbors.len() {
if y[neighbors[j].0] == unassigned {
y[neighbors[j].0] = k;
let mut secondary_neighbors =
algo.find_radius(neighbors[j].2, parameters.eps)?;
if secondary_neighbors.len() >= parameters.min_samples {
neighbors.append(&mut secondary_neighbors);
}
}
if y[neighbors[j].0] == outlier {
y[neighbors[j].0] = k;
if y[neighbors[j].0] == undefined {
y[neighbors[j].0] = queued;
}
}
while !neighbors.is_empty() {
let neighbor = neighbors.pop().unwrap();
let index = neighbor.0;
if y[index] == outlier {
y[index] = k;
}
if y[index] == undefined || y[index] == queued {
y[index] = k;
let secondary_neighbors =
algo.find_radius(neighbor.2, parameters.eps)?;
if secondary_neighbors.len() >= parameters.min_samples {
for j in 0..secondary_neighbors.len() {
let label = y[secondary_neighbors[j].0];
if label == undefined {
y[secondary_neighbors[j].0] = queued;
}
if label == undefined || label == outlier {
neighbors.push(secondary_neighbors[j]);
}
}
}
}
}
k += 1;
}
}
@@ -250,19 +271,25 @@ mod tests {
&[1.0, 2.0],
&[1.1, 2.1],
&[0.9, 1.9],
&[1.2, 1.2],
&[1.2, 2.2],
&[0.8, 1.8],
&[2.0, 1.0],
&[2.1, 1.1],
&[2.2, 1.2],
&[1.9, 0.9],
&[2.2, 1.2],
&[1.8, 0.8],
&[3.0, 5.0],
]);
let expected_labels = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0];
let dbscan = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(1.0)).unwrap();
let dbscan = DBSCAN::fit(
&x,
DBSCANParameters::default()
.with_eps(0.5)
.with_min_samples(2),
)
.unwrap();
let predicted_labels = dbscan.predict(&x).unwrap();
-3
View File
@@ -59,8 +59,6 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32,
let linspace_out = linspace(0.0, 2.0 * std::f32::consts::PI, num_samples_out);
let linspace_in = linspace(0.0, 2.0 * std::f32::consts::PI, num_samples_in);
println!("{:?}", linspace_out);
println!("{:?}", linspace_in);
let noise = Normal::new(0.0, noise).unwrap();
let mut rng = rand::thread_rng();
@@ -117,7 +115,6 @@ mod tests {
#[test]
fn test_make_circles() {
let dataset = make_circles(10, 0.5, 0.05);
println!("{:?}", dataset.as_matrix());
assert_eq!(
dataset.data.len(),
dataset.num_features * dataset.num_samples
+2 -1
View File
@@ -34,7 +34,8 @@
//! &[5.2, 2.7, 3.9, 1.4],
//! ]);
//!
//! let svd = SVD::fit(&iris, SVDParameters::default().with_n_components(2)).unwrap(); // Reduce number of features to 2
//! let svd = SVD::fit(&iris, SVDParameters::default().
//! with_n_components(2)).unwrap(); // Reduce number of features to 2
//!
//! let iris_reduced = svd.transform(&iris).unwrap();
//!