diff --git a/src/cluster/dbscan.rs b/src/cluster/dbscan.rs index 7d641cd..c793039 100644 --- a/src/cluster/dbscan.rs +++ b/src/cluster/dbscan.rs @@ -161,39 +161,60 @@ impl, T>> DBSCAN { } let mut k = 0; - let unassigned = -2; + let queued = -2; let outlier = -1; + let undefined = -3; let n = x.shape().0; - let mut y = vec![unassigned; n]; + let mut y = vec![undefined; n]; let algo = parameters .algorithm .fit(row_iter(x).collect(), parameters.distance)?; for (i, e) in row_iter(x).enumerate() { - if y[i] == unassigned { + if y[i] == undefined { let mut neighbors = algo.find_radius(&e, parameters.eps)?; if neighbors.len() < parameters.min_samples { y[i] = outlier; } else { y[i] = k; + for j in 0..neighbors.len() { - if y[neighbors[j].0] == unassigned { - y[neighbors[j].0] = k; - - let mut secondary_neighbors = - algo.find_radius(neighbors[j].2, parameters.eps)?; - - if secondary_neighbors.len() >= parameters.min_samples { - neighbors.append(&mut secondary_neighbors); - } - } - - if y[neighbors[j].0] == outlier { - y[neighbors[j].0] = k; + if y[neighbors[j].0] == undefined { + y[neighbors[j].0] = queued; } } + + while !neighbors.is_empty() { + let neighbor = neighbors.pop().unwrap(); + let index = neighbor.0; + + if y[index] == outlier { + y[index] = k; + } + + if y[index] == undefined || y[index] == queued { + y[index] = k; + + let secondary_neighbors = + algo.find_radius(neighbor.2, parameters.eps)?; + + if secondary_neighbors.len() >= parameters.min_samples { + for j in 0..secondary_neighbors.len() { + let label = y[secondary_neighbors[j].0]; + if label == undefined { + y[secondary_neighbors[j].0] = queued; + } + + if label == undefined || label == outlier { + neighbors.push(secondary_neighbors[j]); + } + } + } + } + } + k += 1; } } @@ -250,19 +271,25 @@ mod tests { &[1.0, 2.0], &[1.1, 2.1], &[0.9, 1.9], - &[1.2, 1.2], + &[1.2, 2.2], &[0.8, 1.8], &[2.0, 1.0], &[2.1, 1.1], - &[2.2, 1.2], &[1.9, 0.9], + &[2.2, 1.2], &[1.8, 0.8], &[3.0, 5.0], ]); let expected_labels = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0]; - let dbscan = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(1.0)).unwrap(); + let dbscan = DBSCAN::fit( + &x, + DBSCANParameters::default() + .with_eps(0.5) + .with_min_samples(2), + ) + .unwrap(); let predicted_labels = dbscan.predict(&x).unwrap(); diff --git a/src/dataset/generator.rs b/src/dataset/generator.rs index e0b2939..28a2224 100644 --- a/src/dataset/generator.rs +++ b/src/dataset/generator.rs @@ -59,8 +59,6 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset