Chris McComb
2021-02-17 21:22:06 -05:00
parent a30802ec43
commit 4fb2625a33
+51
View File
@@ -88,6 +88,44 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32,
} }
} }
/// Generate a toy dataset of two interleaving half circles ("moons") in 2d.
///
/// The first `num_samples / 2` samples lie on the outer half circle
/// `(cos t, sin t)` and are labelled `0.0`; the remaining samples lie on the
/// inner half circle `(1 - cos t, 1 - sin t - 0.5)` and are labelled `1.0`.
/// The angles `t` come from `linspace(0.0, PI, count)` for each half.
///
/// * `num_samples` - total number of points to generate across both moons.
/// * `noise` - standard deviation of the zero-mean Gaussian noise added
///   independently to every coordinate.
///
/// The returned `data` vector is flat, storing the two features of each
/// sample next to each other (`x` first, then `y`).
///
/// # Panics
///
/// Panics if `Normal::new(0.0, noise)` rejects `noise` as a standard deviation.
pub fn make_moons(num_samples: usize, noise: f32) -> Dataset<f32, f32> {
    // When `num_samples` is odd the inner moon receives the extra point.
    let outer_count = num_samples / 2;
    let inner_count = num_samples - outer_count;

    let outer_angles = linspace(0.0, std::f32::consts::PI, outer_count);
    let inner_angles = linspace(0.0, std::f32::consts::PI, inner_count);

    let jitter = Normal::new(0.0, noise).unwrap();
    let mut rng = rand::thread_rng();

    let mut features: Vec<f32> = Vec::with_capacity(num_samples * 2);
    let mut labels: Vec<f32> = Vec::with_capacity(num_samples);

    // Outer moon: upper half of the unit circle centred at the origin.
    for angle in outer_angles {
        let px = angle.cos() + jitter.sample(&mut rng);
        let py = angle.sin() + jitter.sample(&mut rng);
        features.extend_from_slice(&[px, py]);
        labels.push(0.0);
    }

    // Inner moon: mirrored half circle, shifted so the two arcs interleave.
    for angle in inner_angles {
        let px = 1.0 - angle.cos() + jitter.sample(&mut rng);
        let py = 1.0 - angle.sin() + jitter.sample(&mut rng) - 0.5;
        features.extend_from_slice(&[px, py]);
        labels.push(1.0);
    }

    Dataset {
        data: features,
        target: labels,
        num_samples,
        num_features: 2,
        feature_names: (0..2).map(|idx| idx.to_string()).collect(),
        target_names: vec!["label".to_string()],
        description: "Two interleaving half circles in 2d".to_string(),
    }
}
fn linspace(start: f32, stop: f32, num: usize) -> Vec<f32> { fn linspace(start: f32, stop: f32, num: usize) -> Vec<f32> {
let div = num as f32; let div = num as f32;
let delta = stop - start; let delta = stop - start;
@@ -123,4 +161,17 @@ mod tests {
assert_eq!(dataset.num_features, 2); assert_eq!(dataset.num_features, 2);
assert_eq!(dataset.num_samples, 10); assert_eq!(dataset.num_samples, 10);
} }
/// Checks that `make_moons` returns a dataset whose buffers and metadata
/// agree with the requested number of samples.
#[test]
fn test_make_moons() {
    // Request 10 samples so the final `num_samples` assertion can hold;
    // asking for 100 while asserting 10 made this test fail unconditionally.
    let dataset = make_moons(10, 0.05);
    assert_eq!(
        dataset.data.len(),
        dataset.num_features * dataset.num_samples
    );
    assert_eq!(dataset.target.len(), dataset.num_samples);
    assert_eq!(dataset.num_features, 2);
    assert_eq!(dataset.num_samples, 10);
}
} }