diff --git a/src/dataset/generator.rs b/src/dataset/generator.rs index 28a2224..4d454af 100644 --- a/src/dataset/generator.rs +++ b/src/dataset/generator.rs @@ -88,6 +88,44 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset Dataset { + + let num_samples_out = num_samples / 2; + let num_samples_in = num_samples - num_samples_out; + + let linspace_out = linspace(0.0, std::f32::consts::PI, num_samples_out); + let linspace_in = linspace(0.0, std::f32::consts::PI, num_samples_in); + + let noise = Normal::new(0.0, noise).unwrap(); + let mut rng = rand::thread_rng(); + + let mut x: Vec = Vec::with_capacity(num_samples * 2); + let mut y: Vec = Vec::with_capacity(num_samples); + + for v in linspace_out { + x.push(v.cos() + noise.sample(&mut rng)); + x.push(v.sin() + noise.sample(&mut rng)); + y.push(0.0); + } + + for v in linspace_in { + x.push(1.0 - v.cos() + noise.sample(&mut rng)); + x.push(1.0 - v.sin() + noise.sample(&mut rng) - 0.5); + y.push(1.0); + } + + Dataset { + data: x, + target: y, + num_samples, + num_features: 2, + feature_names: (0..2).map(|n| n.to_string()).collect(), + target_names: vec!["label".to_string()], + description: "Two interleaving half circles in 2d".to_string(), + } +} + fn linspace(start: f32, stop: f32, num: usize) -> Vec { let div = num as f32; let delta = stop - start; @@ -123,4 +161,17 @@ mod tests { assert_eq!(dataset.num_features, 2); assert_eq!(dataset.num_samples, 10); } + + #[test] + fn test_make_moons() { + let dataset = make_moons(100, 0.05); + println!("{:?}", dataset.data); + assert_eq!( + dataset.data.len(), + dataset.num_features * dataset.num_samples + ); + assert_eq!(dataset.target.len(), dataset.num_samples); + assert_eq!(dataset.num_features, 2); + assert_eq!(dataset.num_samples, 10); + } }