Implemented make_moons generator per https://github.com/scikit-learn/scikit-learn/blob/95119c13a/sklearn/datasets/_samples_generator.py#L683
This commit is contained in:
@@ -88,6 +88,44 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32,
|
||||
}
|
||||
}
|
||||
|
||||
/// Make two interleaving half circles in 2d
|
||||
pub fn make_moons(num_samples: usize, noise: f32) -> Dataset<f32, f32> {
|
||||
|
||||
let num_samples_out = num_samples / 2;
|
||||
let num_samples_in = num_samples - num_samples_out;
|
||||
|
||||
let linspace_out = linspace(0.0, std::f32::consts::PI, num_samples_out);
|
||||
let linspace_in = linspace(0.0, std::f32::consts::PI, num_samples_in);
|
||||
|
||||
let noise = Normal::new(0.0, noise).unwrap();
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let mut x: Vec<f32> = Vec::with_capacity(num_samples * 2);
|
||||
let mut y: Vec<f32> = Vec::with_capacity(num_samples);
|
||||
|
||||
for v in linspace_out {
|
||||
x.push(v.cos() + noise.sample(&mut rng));
|
||||
x.push(v.sin() + noise.sample(&mut rng));
|
||||
y.push(0.0);
|
||||
}
|
||||
|
||||
for v in linspace_in {
|
||||
x.push(1.0 - v.cos() + noise.sample(&mut rng));
|
||||
x.push(1.0 - v.sin() + noise.sample(&mut rng) - 0.5);
|
||||
y.push(1.0);
|
||||
}
|
||||
|
||||
Dataset {
|
||||
data: x,
|
||||
target: y,
|
||||
num_samples,
|
||||
num_features: 2,
|
||||
feature_names: (0..2).map(|n| n.to_string()).collect(),
|
||||
target_names: vec!["label".to_string()],
|
||||
description: "Two interleaving half circles in 2d".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn linspace(start: f32, stop: f32, num: usize) -> Vec<f32> {
|
||||
let div = num as f32;
|
||||
let delta = stop - start;
|
||||
@@ -123,4 +161,17 @@ mod tests {
|
||||
assert_eq!(dataset.num_features, 2);
|
||||
assert_eq!(dataset.num_samples, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_make_moons() {
|
||||
let dataset = make_moons(100, 0.05);
|
||||
println!("{:?}", dataset.data);
|
||||
assert_eq!(
|
||||
dataset.data.len(),
|
||||
dataset.num_features * dataset.num_samples
|
||||
);
|
||||
assert_eq!(dataset.target.len(), dataset.num_samples);
|
||||
assert_eq!(dataset.num_features, 2);
|
||||
assert_eq!(dataset.num_samples, 10);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user