fix: fixes a bug in DBSCAN, removes println's

This commit is contained in:
Volodymyr Orlov
2021-01-02 18:08:40 -08:00
parent c5a7beaf0e
commit bb9a05b993
3 changed files with 48 additions and 23 deletions
+46 -19
View File
@@ -161,39 +161,60 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
}
let mut k = 0;
let unassigned = -2;
let queued = -2;
let outlier = -1;
let undefined = -3;
let n = x.shape().0;
let mut y = vec![unassigned; n];
let mut y = vec![undefined; n];
let algo = parameters
.algorithm
.fit(row_iter(x).collect(), parameters.distance)?;
for (i, e) in row_iter(x).enumerate() {
if y[i] == unassigned {
if y[i] == undefined {
let mut neighbors = algo.find_radius(&e, parameters.eps)?;
if neighbors.len() < parameters.min_samples {
y[i] = outlier;
} else {
y[i] = k;
for j in 0..neighbors.len() {
if y[neighbors[j].0] == unassigned {
y[neighbors[j].0] = k;
let mut secondary_neighbors =
algo.find_radius(neighbors[j].2, parameters.eps)?;
if secondary_neighbors.len() >= parameters.min_samples {
neighbors.append(&mut secondary_neighbors);
}
}
if y[neighbors[j].0] == outlier {
y[neighbors[j].0] = k;
if y[neighbors[j].0] == undefined {
y[neighbors[j].0] = queued;
}
}
while !neighbors.is_empty() {
let neighbor = neighbors.pop().unwrap();
let index = neighbor.0;
if y[index] == outlier {
y[index] = k;
}
if y[index] == undefined || y[index] == queued {
y[index] = k;
let secondary_neighbors =
algo.find_radius(neighbor.2, parameters.eps)?;
if secondary_neighbors.len() >= parameters.min_samples {
for j in 0..secondary_neighbors.len() {
let label = y[secondary_neighbors[j].0];
if label == undefined {
y[secondary_neighbors[j].0] = queued;
}
if label == undefined || label == outlier {
neighbors.push(secondary_neighbors[j]);
}
}
}
}
}
k += 1;
}
}
@@ -250,19 +271,25 @@ mod tests {
&[1.0, 2.0],
&[1.1, 2.1],
&[0.9, 1.9],
&[1.2, 1.2],
&[1.2, 2.2],
&[0.8, 1.8],
&[2.0, 1.0],
&[2.1, 1.1],
&[2.2, 1.2],
&[1.9, 0.9],
&[2.2, 1.2],
&[1.8, 0.8],
&[3.0, 5.0],
]);
let expected_labels = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0];
let dbscan = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(1.0)).unwrap();
let dbscan = DBSCAN::fit(
&x,
DBSCANParameters::default()
.with_eps(0.5)
.with_min_samples(2),
)
.unwrap();
let predicted_labels = dbscan.predict(&x).unwrap();