fix: fixes a bug in DBSCAN, removes println's
This commit is contained in:
+46
-19
@@ -161,39 +161,60 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
|
||||
}
|
||||
|
||||
let mut k = 0;
|
||||
let unassigned = -2;
|
||||
let queued = -2;
|
||||
let outlier = -1;
|
||||
let undefined = -3;
|
||||
|
||||
let n = x.shape().0;
|
||||
let mut y = vec![unassigned; n];
|
||||
let mut y = vec![undefined; n];
|
||||
|
||||
let algo = parameters
|
||||
.algorithm
|
||||
.fit(row_iter(x).collect(), parameters.distance)?;
|
||||
|
||||
for (i, e) in row_iter(x).enumerate() {
|
||||
if y[i] == unassigned {
|
||||
if y[i] == undefined {
|
||||
let mut neighbors = algo.find_radius(&e, parameters.eps)?;
|
||||
if neighbors.len() < parameters.min_samples {
|
||||
y[i] = outlier;
|
||||
} else {
|
||||
y[i] = k;
|
||||
|
||||
for j in 0..neighbors.len() {
|
||||
if y[neighbors[j].0] == unassigned {
|
||||
y[neighbors[j].0] = k;
|
||||
|
||||
let mut secondary_neighbors =
|
||||
algo.find_radius(neighbors[j].2, parameters.eps)?;
|
||||
|
||||
if secondary_neighbors.len() >= parameters.min_samples {
|
||||
neighbors.append(&mut secondary_neighbors);
|
||||
}
|
||||
}
|
||||
|
||||
if y[neighbors[j].0] == outlier {
|
||||
y[neighbors[j].0] = k;
|
||||
if y[neighbors[j].0] == undefined {
|
||||
y[neighbors[j].0] = queued;
|
||||
}
|
||||
}
|
||||
|
||||
while !neighbors.is_empty() {
|
||||
let neighbor = neighbors.pop().unwrap();
|
||||
let index = neighbor.0;
|
||||
|
||||
if y[index] == outlier {
|
||||
y[index] = k;
|
||||
}
|
||||
|
||||
if y[index] == undefined || y[index] == queued {
|
||||
y[index] = k;
|
||||
|
||||
let secondary_neighbors =
|
||||
algo.find_radius(neighbor.2, parameters.eps)?;
|
||||
|
||||
if secondary_neighbors.len() >= parameters.min_samples {
|
||||
for j in 0..secondary_neighbors.len() {
|
||||
let label = y[secondary_neighbors[j].0];
|
||||
if label == undefined {
|
||||
y[secondary_neighbors[j].0] = queued;
|
||||
}
|
||||
|
||||
if label == undefined || label == outlier {
|
||||
neighbors.push(secondary_neighbors[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
k += 1;
|
||||
}
|
||||
}
|
||||
@@ -250,19 +271,25 @@ mod tests {
|
||||
&[1.0, 2.0],
|
||||
&[1.1, 2.1],
|
||||
&[0.9, 1.9],
|
||||
&[1.2, 1.2],
|
||||
&[1.2, 2.2],
|
||||
&[0.8, 1.8],
|
||||
&[2.0, 1.0],
|
||||
&[2.1, 1.1],
|
||||
&[2.2, 1.2],
|
||||
&[1.9, 0.9],
|
||||
&[2.2, 1.2],
|
||||
&[1.8, 0.8],
|
||||
&[3.0, 5.0],
|
||||
]);
|
||||
|
||||
let expected_labels = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0];
|
||||
|
||||
let dbscan = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(1.0)).unwrap();
|
||||
let dbscan = DBSCAN::fit(
|
||||
&x,
|
||||
DBSCANParameters::default()
|
||||
.with_eps(0.5)
|
||||
.with_min_samples(2),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let predicted_labels = dbscan.predict(&x).unwrap();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user