diff --git a/Cargo.toml b/Cargo.toml index 3c1b8ab..bd9db32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "smartcore" description = "Machine Learning in Rust." homepage = "https://smartcorelib.org" -version = "0.4.2" +version = "0.4.3" authors = ["smartcore Developers"] edition = "2021" license = "Apache-2.0" diff --git a/src/algorithm/neighbour/cosinepair.rs b/src/algorithm/neighbour/cosinepair.rs new file mode 100644 index 0000000..0a38f99 --- /dev/null +++ b/src/algorithm/neighbour/cosinepair.rs @@ -0,0 +1,777 @@ +/// +/// ### CosinePair: Data-structure for the dynamic closest-pair problem. +/// +/// Reference: +/// Eppstein, David: Fast hierarchical clustering and other applications of +/// dynamic closest pairs. Journal of Experimental Algorithmics 5 (2000) 1. +/// +/// Example: +/// ``` +/// use smartcore::metrics::distance::PairwiseDistance; +/// use smartcore::linalg::basic::matrix::DenseMatrix; +/// use smartcore::algorithm::neighbour::cosinepair::CosinePair; +/// let x = DenseMatrix::::from_2d_array(&[ +/// &[5.1, 3.5, 1.4, 0.2], +/// &[4.9, 3.0, 1.4, 0.2], +/// &[4.7, 3.2, 1.3, 0.2], +/// &[4.6, 3.1, 1.5, 0.2], +/// &[5.0, 3.6, 1.4, 0.2], +/// &[5.4, 3.9, 1.7, 0.4], +/// ]).unwrap(); +/// let cosinepair = CosinePair::new(&x); +/// let closest_pair: PairwiseDistance = cosinepair.unwrap().closest_pair(); +/// ``` +/// +/// +use std::collections::HashMap; + +use num::Bounded; + +use crate::error::{Failed, FailedError}; +use crate::linalg::basic::arrays::{Array1, Array2}; +use crate::metrics::distance::cosine::Cosine; +use crate::metrics::distance::{Distance, PairwiseDistance}; +use crate::numbers::floatnum::FloatNumber; +use crate::numbers::realnum::RealNumber; + +/// +/// Inspired by Python implementation: +/// +/// MIT License (MIT) Copyright (c) 2016 Carson Farmer +/// +/// affinity used is Cosine as it is the most used +/// +#[derive(Debug, Clone)] +pub struct CosinePair<'a, T: RealNumber + FloatNumber, M: Array2> { + /// initial matrix + pub samples: &'a M, + /// closest pair hashmap (connectivity matrix for closest pairs) + pub distances: HashMap>, + /// conga line used to keep track of the closest pair + pub neighbours: Vec, +} + +impl<'a, T: RealNumber + FloatNumber, M: Array2> CosinePair<'a, T, M> { + /// Constructor + /// Instantiate and initialize the algorithm + pub fn new(m: &'a M) -> Result { + if m.shape().0 < 2 { + return Err(Failed::because( + FailedError::FindFailed, + "min number of rows should be 2", + )); + } + + let mut init = Self { + samples: m, + // to be computed in init(..) + distances: HashMap::with_capacity(m.shape().0), + neighbours: Vec::with_capacity(m.shape().0 + 1), + }; + init.init(); + Ok(init) + } + + /// Initialise `CosinePair` by passing a `Array2`. + /// Build a CosinePairs data-structure from a set of (new) points. + fn init(&mut self) { + // basic measures + let len = self.samples.shape().0; + let max_index = self.samples.shape().0 - 1; + + // Store all closest neighbors + let _distances = Box::new(HashMap::with_capacity(len)); + let _neighbours = Box::new(Vec::with_capacity(len)); + + let mut distances = *_distances; + let mut neighbours = *_neighbours; + + // fill neighbours with -1 values + neighbours.extend(0..len); + + // init closest neighbour pairwise data + for index_row_i in 0..(max_index) { + distances.insert( + index_row_i, + PairwiseDistance { + node: index_row_i, + neighbour: Option::None, + distance: Some(::max_value()), + }, + ); + } + + // loop through indeces and neighbours + for index_row_i in 0..(len) { + // start looking for the neighbour in the second element + let mut index_closest = index_row_i + 1; // closest neighbour index + let mut nbd: Option = distances[&index_row_i].distance; // init neighbour distance + for index_row_j in (index_row_i + 1)..len { + distances.insert( + index_row_j, + PairwiseDistance { + node: index_row_j, + neighbour: Some(index_row_i), + distance: nbd, + }, + ); + + let d = Cosine::new().distance( + &Vec::from_iterator( + self.samples.get_row(index_row_i).iterator(0).copied(), + self.samples.shape().1, + ), + &Vec::from_iterator( + self.samples.get_row(index_row_j).iterator(0).copied(), + self.samples.shape().1, + ), + ); + if d < nbd.unwrap().to_f64().unwrap() { + // set this j-value to be the closest neighbour + index_closest = index_row_j; + nbd = Some(T::from(d).unwrap()); + } + } + + // Add that edge + distances.entry(index_row_i).and_modify(|e| { + e.distance = nbd; + e.neighbour = Some(index_closest); + }); + } + // No more neighbors, terminate conga line. + // Last person on the line has no neigbors + distances.get_mut(&max_index).unwrap().neighbour = Some(max_index); + distances.get_mut(&(len - 1)).unwrap().distance = Some(::max_value()); + + // compute sparse matrix (connectivity matrix) + let mut sparse_matrix = M::zeros(len, len); + for (_, p) in distances.iter() { + sparse_matrix.set((p.node, p.neighbour.unwrap()), p.distance.unwrap()); + } + + self.distances = distances; + self.neighbours = neighbours; + } + + /// Query k nearest neighbors for a row that's already in the dataset + pub fn query_row(&self, query_row_index: usize, k: usize) -> Result, Failed> { + if query_row_index >= self.samples.shape().0 { + return Err(Failed::because( + FailedError::FindFailed, + "Query row index out of bounds", + )); + } + + if k == 0 { + return Ok(Vec::new()); + } + + // Get distances to all other points + let mut distances = self.distances_from(query_row_index); + + // Sort by distance (ascending) + distances.sort_by(|a, b| { + a.distance + .unwrap() + .partial_cmp(&b.distance.unwrap()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + // Take top k neighbors and convert to (distance, index) format + let neighbors: Vec<(T, usize)> = distances + .into_iter() + .take(k) + .map(|pd| (pd.distance.unwrap(), pd.neighbour.unwrap())) + .collect(); + + Ok(neighbors) + } + + /// Query k nearest neighbors for an external query vector + pub fn query(&self, query_vector: &Vec, k: usize) -> Result, Failed> { + if query_vector.len() != self.samples.shape().1 { + return Err(Failed::because( + FailedError::FindFailed, + "Query vector dimension mismatch", + )); + } + + if k == 0 { + return Ok(Vec::new()); + } + + // Compute distances from query vector to all points in the dataset + let mut distances = Vec::>::with_capacity(self.samples.shape().0); + + for i in 0..self.samples.shape().0 { + let dataset_point = Vec::from_iterator( + self.samples.get_row(i).iterator(0).copied(), + self.samples.shape().1, + ); + + let distance = T::from(Cosine::new().distance(query_vector, &dataset_point)).unwrap(); + + distances.push(PairwiseDistance { + node: i, // This represents the dataset point index + neighbour: Some(i), + distance: Some(distance), + }); + } + + // Sort by distance (ascending) + distances.sort_by(|a, b| { + a.distance + .unwrap() + .partial_cmp(&b.distance.unwrap()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + // Take top k neighbors and convert to (distance, index) format + let neighbors: Vec<(T, usize)> = distances + .into_iter() + .take(k) + .map(|pd| (pd.distance.unwrap(), pd.node)) + .collect(); + + Ok(neighbors) + } + + /// Optimized version that reuses the existing distances_from method + /// This is more efficient for queries that are points already in the dataset + pub fn query_optimized( + &self, + query_row_index: usize, + k: usize, + ) -> Result, Failed> { + // Reuse existing method and sort the results + self.query_row(query_row_index, k) + } + + /// Find closest pair by scanning list of nearest neighbors. + #[allow(dead_code)] + pub fn closest_pair(&self) -> PairwiseDistance { + let mut a = self.neighbours[0]; // Start with first point + let mut d = self.distances[&a].distance; + for p in self.neighbours.iter() { + if self.distances[p].distance < d { + a = *p; // Update `a` and distance `d` + d = self.distances[p].distance; + } + } + let b = self.distances[&a].neighbour; + PairwiseDistance { + node: a, + neighbour: b, + distance: d, + } + } + + /// + /// Return order dissimilarities from closest to furthest + /// + #[allow(dead_code)] + pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance> { + // improvement: implement this to return `impl Iterator>` + // need to implement trait `Iterator` for `Vec<&PairwiseDistance>` + let mut distances = self + .distances + .values() + .collect::>>(); + distances.sort_by(|a, b| a.partial_cmp(b).unwrap()); + distances.into_iter() + } + + // + // Compute distances from input to all other points in data-structure. + // input is the row index of the sample matrix + // + #[allow(dead_code)] + fn distances_from(&self, index_row: usize) -> Vec> { + let mut distances = Vec::>::with_capacity(self.samples.shape().0); + for other in self.neighbours.iter() { + if index_row != *other { + distances.push(PairwiseDistance { + node: index_row, + neighbour: Some(*other), + distance: Some( + T::from(Cosine::new().distance( + &Vec::from_iterator( + self.samples.get_row(index_row).iterator(0).copied(), + self.samples.shape().1, + ), + &Vec::from_iterator( + self.samples.get_row(*other).iterator(0).copied(), + self.samples.shape().1, + ), + )) + .unwrap(), + ), + }) + } + } + distances + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix}; + use approx::assert_relative_eq; + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_initialization() { + let x = DenseMatrix::::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + ]) + .unwrap(); + + let cosine_pair = CosinePair::new(&x); + + assert!(cosine_pair.is_ok()); + let cp = cosine_pair.unwrap(); + + assert_eq!(cp.samples.shape().0, 6); + assert_eq!(cp.distances.len(), 6); + assert_eq!(cp.neighbours.len(), 6); + assert!(!cp.distances.is_empty()); + assert!(!cp.neighbours.is_empty()); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_minimum_rows_error() { + // Test with only one row - should fail + let x = DenseMatrix::::from_2d_array(&[&[5.1, 3.5, 1.4, 0.2]]).unwrap(); + + let result = CosinePair::new(&x); + assert!(result.is_err()); + + if let Err(e) = result { + let expected_error = + Failed::because(FailedError::FindFailed, "min number of rows should be 2"); + assert_eq!(e, expected_error); + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_closest_pair() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 0.0], + &[0.0, 1.0], + &[1.0, 1.0], + &[2.0, 2.0], // This should be closest to [1.0, 1.0] with cosine distance + ]) + .unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let closest_pair = cosine_pair.closest_pair(); + + // Verify structure + assert!(closest_pair.distance.is_some()); + assert!(closest_pair.neighbour.is_some()); + + // The closest pair should have the smallest cosine distance + let distance = closest_pair.distance.unwrap(); + assert!(distance >= 0.0 && distance <= 2.0); // Cosine distance range + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_identical_vectors() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 2.0, 3.0], + &[1.0, 2.0, 3.0], // Identical vector + &[4.0, 5.0, 6.0], + ]) + .unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let closest_pair = cosine_pair.closest_pair(); + + // Distance between identical vectors should be 0 + let distance = closest_pair.distance.unwrap(); + assert!((distance - 0.0).abs() < 1e-8); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_orthogonal_vectors() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 0.0], + &[0.0, 1.0], // Orthogonal to first + &[2.0, 3.0], + ]) + .unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + + // Check that orthogonal vectors have cosine distance of 1.0 + let distances_from_first = cosine_pair.distances_from(0); + let orthogonal_distance = distances_from_first + .iter() + .find(|pd| pd.neighbour == Some(1)) + .unwrap() + .distance + .unwrap(); + + assert!((orthogonal_distance - 1.0).abs() < 1e-8); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_ordered_pairs() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 2.0], + &[2.0, 1.0], + &[3.0, 4.0], + &[4.0, 3.0], + ]) + .unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let ordered_pairs: Vec<_> = cosine_pair.ordered_pairs().collect(); + + assert_eq!(ordered_pairs.len(), 4); + + // Check that pairs are ordered by distance (ascending) + for i in 1..ordered_pairs.len() { + let prev_distance = ordered_pairs[i - 1].distance.unwrap(); + let curr_distance = ordered_pairs[i].distance.unwrap(); + assert!(prev_distance <= curr_distance); + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_query_row() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 0.0, 0.0], + &[0.0, 1.0, 0.0], + &[0.0, 0.0, 1.0], + &[1.0, 1.0, 0.0], + &[0.0, 1.0, 1.0], + ]) + .unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + + // Query k=2 nearest neighbors for row 0 + let neighbors = cosine_pair.query_row(0, 2).unwrap(); + + assert_eq!(neighbors.len(), 2); + + // Check that distances are in ascending order + assert!(neighbors[0].0 <= neighbors[1].0); + + // All distances should be valid cosine distances (0 to 2) + for (distance, _) in &neighbors { + assert!(*distance >= 0.0 && *distance <= 2.0); + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_query_row_bounds_error() { + let x = DenseMatrix::::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + + // Query with out-of-bounds row index + let result = cosine_pair.query_row(5, 1); + assert!(result.is_err()); + + if let Err(e) = result { + let expected_error = + Failed::because(FailedError::FindFailed, "Query row index out of bounds"); + assert_eq!(e, expected_error); + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_query_row_k_zero() { + let x = + DenseMatrix::::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let neighbors = cosine_pair.query_row(0, 0).unwrap(); + + assert_eq!(neighbors.len(), 0); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_query_external_vector() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 0.0, 0.0], + &[0.0, 1.0, 0.0], + &[0.0, 0.0, 1.0], + &[1.0, 1.0, 0.0], + ]) + .unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + + // Query with external vector + let query_vector = vec![1.0, 0.5, 0.0]; + let neighbors = cosine_pair.query(&query_vector, 2).unwrap(); + + assert_eq!(neighbors.len(), 2); + + // Verify distances are valid and ordered + assert!(neighbors[0].0 <= neighbors[1].0); + for (distance, index) in &neighbors { + assert!(*distance >= 0.0 && *distance <= 2.0); + assert!(*index < x.shape().0); + } + } + + #[test] + fn cosine_pair_query_dimension_mismatch() { + let x = DenseMatrix::::from_2d_array(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + + // Query with mismatched dimensions + let query_vector = vec![1.0, 2.0]; // Only 2 dimensions, but data has 3 + let result = cosine_pair.query(&query_vector, 1); + + assert!(result.is_err()); + if let Err(e) = result { + let expected_error = + Failed::because(FailedError::FindFailed, "Query vector dimension mismatch"); + assert_eq!(e, expected_error); + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_query_k_zero_external() { + let x = DenseMatrix::::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let query_vector = vec![1.0, 1.0]; + let neighbors = cosine_pair.query(&query_vector, 0).unwrap(); + + assert_eq!(neighbors.len(), 0); + } + + #[test] + fn cosine_pair_large_dataset() { + // Test with larger dataset (similar to Iris) + let x = DenseMatrix::::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + ]) + .unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + + assert_eq!(cosine_pair.samples.shape().0, 15); + assert_eq!(cosine_pair.distances.len(), 15); + assert_eq!(cosine_pair.neighbours.len(), 15); + + // Test closest pair computation + let closest_pair = cosine_pair.closest_pair(); + assert!(closest_pair.distance.is_some()); + assert!(closest_pair.neighbour.is_some()); + + let distance = closest_pair.distance.unwrap(); + assert!(distance >= 0.0 && distance <= 2.0); + } + + #[test] + fn cosine_pair_float_precision() { + // Test with f32 precision + let x = DenseMatrix::::from_2d_array(&[ + &[1.0f32, 2.0, 3.0], + &[4.0f32, 5.0, 6.0], + &[7.0f32, 8.0, 9.0], + ]) + .unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let closest_pair = cosine_pair.closest_pair(); + + assert!(closest_pair.distance.is_some()); + let distance = closest_pair.distance.unwrap(); + assert!(distance >= 0.0 && distance <= 2.0); + + // Test querying + let neighbors = cosine_pair.query_row(0, 2).unwrap(); + assert_eq!(neighbors.len(), 2); + assert_eq!(neighbors[0].1, 1); + assert_relative_eq!(neighbors[0].0, 0.025368154); + assert_eq!(neighbors[1].1, 2); + assert_relative_eq!(neighbors[1].0, 0.040588055); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_distances_from() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 0.0], + &[0.0, 1.0], + &[1.0, 1.0], + &[2.0, 0.0], + ]) + .unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let distances = cosine_pair.distances_from(0); + + // Should have 3 distances (excluding self) + assert_eq!(distances.len(), 3); + + // All should be from node 0 + for pd in &distances { + assert_eq!(pd.node, 0); + assert!(pd.neighbour.is_some()); + assert!(pd.distance.is_some()); + assert!(pd.neighbour.unwrap() != 0); // Should not include self + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_consistency_check() { + // Verify that different query methods return consistent results + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 2.0, 3.0], + &[4.0, 5.0, 6.0], + &[7.0, 8.0, 9.0], + &[2.0, 3.0, 4.0], + ]) + .unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + + // Query row 0 using internal method + let neighbors_internal = cosine_pair.query_row(0, 2).unwrap(); + + // Query row 0 using optimized method (should be same) + let neighbors_optimized = cosine_pair.query_optimized(0, 2).unwrap(); + + assert_eq!(neighbors_internal.len(), neighbors_optimized.len()); + for i in 0..neighbors_internal.len() { + let (dist1, idx1) = neighbors_internal[i]; + let (dist2, idx2) = neighbors_optimized[i]; + assert!((dist1 - dist2).abs() < 1e-10); + assert_eq!(idx1, idx2); + } + } + + // Brute force algorithm for testing/comparison + fn closest_pair_brute_force( + cosine_pair: &CosinePair<'_, f64, DenseMatrix>, + ) -> PairwiseDistance { + use itertools::Itertools; + + let m = cosine_pair.samples.shape().0; + let mut closest_pair = PairwiseDistance { + node: 0, + neighbour: None, + distance: Some(f64::MAX), + }; + + for pair in (0..m).combinations(2) { + let d = Cosine::new().distance( + &Vec::from_iterator( + cosine_pair.samples.get_row(pair[0]).iterator(0).copied(), + cosine_pair.samples.shape().1, + ), + &Vec::from_iterator( + cosine_pair.samples.get_row(pair[1]).iterator(0).copied(), + cosine_pair.samples.shape().1, + ), + ); + + if d < closest_pair.distance.unwrap() { + closest_pair.node = pair[0]; + closest_pair.neighbour = Some(pair[1]); + closest_pair.distance = Some(d); + } + } + + closest_pair + } + + #[test] + fn cosine_pair_vs_brute_force() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 2.0, 3.0], + &[4.0, 5.0, 6.0], + &[7.0, 8.0, 9.0], + &[1.1, 2.1, 3.1], // Close to first point + ]) + .unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let cp_result = cosine_pair.closest_pair(); + let brute_result = closest_pair_brute_force(&cosine_pair); + + // Results should be identical or very close + assert!((cp_result.distance.unwrap() - brute_result.distance.unwrap()).abs() < 1e-10); + } +} diff --git a/src/algorithm/neighbour/mod.rs b/src/algorithm/neighbour/mod.rs index 3bee93a..c13e914 100644 --- a/src/algorithm/neighbour/mod.rs +++ b/src/algorithm/neighbour/mod.rs @@ -39,6 +39,8 @@ use crate::numbers::basenum::Number; use serde::{Deserialize, Serialize}; pub(crate) mod bbd_tree; +/// a variant of fastpair using cosine distance +pub mod cosinepair; /// tree data structure for fast nearest neighbor search pub mod cover_tree; /// fastpair closest neighbour algorithm diff --git a/src/algorithm/sort/quick_sort.rs b/src/algorithm/sort/quick_sort.rs index e64c424..56efec9 100644 --- a/src/algorithm/sort/quick_sort.rs +++ b/src/algorithm/sort/quick_sort.rs @@ -1,6 +1,7 @@ use num_traits::Num; pub trait QuickArgSort { + #[allow(dead_code)] fn quick_argsort_mut(&mut self) -> Vec; #[allow(dead_code)] diff --git a/src/metrics/distance/cosine.rs b/src/metrics/distance/cosine.rs new file mode 100644 index 0000000..ea065a0 --- /dev/null +++ b/src/metrics/distance/cosine.rs @@ -0,0 +1,219 @@ +//! # Cosine Distance Metric +//! +//! The cosine distance between two points \\( x \\) and \\( y \\) in n-space is defined as: +//! +//! \\[ d(x, y) = 1 - \frac{x \cdot y}{||x|| ||y||} \\] +//! +//! where \\( x \cdot y \\) is the dot product of the vectors, and \\( ||x|| \\) and \\( ||y|| \\) +//! are their respective magnitudes (Euclidean norms). +//! +//! Cosine distance measures the angular dissimilarity between vectors, ranging from 0 to 2. +//! A value of 0 indicates identical direction (parallel vectors), while larger values indicate +//! greater angular separation. +//! +//! Example: +//! +//! ``` +//! use smartcore::metrics::distance::Distance; +//! use smartcore::metrics::distance::cosine::Cosine; +//! +//! let x = vec![1., 1.]; +//! let y = vec![2., 2.]; +//! +//! let cosine_dist: f64 = Cosine::new().distance(&x, &y); +//! ``` +//! +//! +//! +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +use std::marker::PhantomData; + +use crate::linalg::basic::arrays::ArrayView1; +use crate::numbers::basenum::Number; + +use super::Distance; + +/// Cosine distance is a measure of the angular dissimilarity between two non-zero vectors in n-space. +/// It is defined as 1 minus the cosine similarity of the vectors. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct Cosine { + _t: PhantomData, +} + +impl Default for Cosine { + fn default() -> Self { + Self::new() + } +} + +impl Cosine { + /// Instantiate the initial structure + pub fn new() -> Cosine { + Cosine { _t: PhantomData } + } + + /// Calculate the dot product of two vectors using smartcore's ArrayView1 trait + #[inline] + pub(crate) fn dot_product>(x: &A, y: &A) -> f64 { + if x.shape() != y.shape() { + panic!("Input vector sizes are different."); + } + + // Use the built-in dot product method from ArrayView1 trait + x.dot(y).to_f64().unwrap() + } + + /// Calculate the squared magnitude (norm squared) of a vector + #[inline] + #[allow(dead_code)] + pub(crate) fn squared_magnitude>(x: &A) -> f64 { + x.iterator(0) + .map(|&a| { + let val = a.to_f64().unwrap(); + val * val + }) + .sum() + } + + /// Calculate the magnitude (Euclidean norm) of a vector using smartcore's norm2 method + #[inline] + pub(crate) fn magnitude>(x: &A) -> f64 { + // Use the built-in norm2 method from ArrayView1 trait + x.norm2() + } + + /// Calculate cosine similarity between two vectors + #[inline] + pub(crate) fn cosine_similarity>(x: &A, y: &A) -> f64 { + let dot_product = Self::dot_product(x, y); + let magnitude_x = Self::magnitude(x); + let magnitude_y = Self::magnitude(y); + + if magnitude_x == 0.0 || magnitude_y == 0.0 { + panic!("Cannot compute cosine distance for zero-magnitude vectors."); + } + + dot_product / (magnitude_x * magnitude_y) + } +} + +impl> Distance for Cosine { + fn distance(&self, x: &A, y: &A) -> f64 { + let similarity = Cosine::cosine_similarity(x, y); + 1.0 - similarity + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_distance_identical_vectors() { + let a = vec![1, 2, 3]; + let b = vec![1, 2, 3]; + + let dist: f64 = Cosine::new().distance(&a, &b); + + assert!((dist - 0.0).abs() < 1e-8); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_distance_orthogonal_vectors() { + let a = vec![1, 0]; + let b = vec![0, 1]; + + let dist: f64 = Cosine::new().distance(&a, &b); + + assert!((dist - 1.0).abs() < 1e-8); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_distance_opposite_vectors() { + let a = vec![1, 2, 3]; + let b = vec![-1, -2, -3]; + + let dist: f64 = Cosine::new().distance(&a, &b); + + assert!((dist - 2.0).abs() < 1e-8); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_distance_general_case() { + let a = vec![1.0, 2.0, 3.0]; + let b = vec![2.0, 1.0, 3.0]; + + let dist: f64 = Cosine::new().distance(&a, &b); + + // Expected cosine similarity: (1*2 + 2*1 + 3*3) / (sqrt(1+4+9) * sqrt(4+1+9)) + // = (2 + 2 + 9) / (sqrt(14) * sqrt(14)) = 13/14 ≈ 0.9286 + // So cosine distance = 1 - 13/14 = 1/14 ≈ 0.0714 + let expected_dist = 1.0 - (13.0 / 14.0); + assert!((dist - expected_dist).abs() < 1e-8); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + #[should_panic(expected = "Input vector sizes are different.")] + fn cosine_distance_different_sizes() { + let a = vec![1, 2]; + let b = vec![1, 2, 3]; + + let _dist: f64 = Cosine::new().distance(&a, &b); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + #[should_panic(expected = "Cannot compute cosine distance for zero-magnitude vectors.")] + fn cosine_distance_zero_vector() { + let a = vec![0, 0, 0]; + let b = vec![1, 2, 3]; + + let _dist: f64 = Cosine::new().distance(&a, &b); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_distance_float_precision() { + let a = vec![1.0f32, 2.0, 3.0]; + let b = vec![4.0f32, 5.0, 6.0]; + + let dist: f64 = Cosine::new().distance(&a, &b); + + // Calculate expected value manually + let dot_product = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0; // = 32 + let mag_a = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0_f64).sqrt(); // = sqrt(14) + let mag_b = (4.0 * 4.0 + 5.0 * 5.0 + 6.0 * 6.0_f64).sqrt(); // = sqrt(77) + let expected_similarity = dot_product / (mag_a * mag_b); + let expected_distance = 1.0 - expected_similarity; + + assert!((dist - expected_distance).abs() < 1e-6); + } +} diff --git a/src/metrics/distance/mod.rs b/src/metrics/distance/mod.rs index 193d7a1..6fdbaa4 100644 --- a/src/metrics/distance/mod.rs +++ b/src/metrics/distance/mod.rs @@ -13,6 +13,8 @@ //! //! +/// Cosine distance +pub mod cosine; /// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points. pub mod euclidian; /// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different. diff --git a/src/optimization/line_search.rs b/src/optimization/line_search.rs index 8357d8d..98d2982 100644 --- a/src/optimization/line_search.rs +++ b/src/optimization/line_search.rs @@ -6,8 +6,8 @@ pub trait LineSearchMethod { /// Find alpha that satisfies strong Wolfe conditions. fn search( &self, - f: &(dyn Fn(T) -> T), - df: &(dyn Fn(T) -> T), + f: &dyn Fn(T) -> T, + df: &dyn Fn(T) -> T, alpha: T, f0: T, df0: T, @@ -55,8 +55,8 @@ impl Default for Backtracking { impl LineSearchMethod for Backtracking { fn search( &self, - f: &(dyn Fn(T) -> T), - _: &(dyn Fn(T) -> T), + f: &dyn Fn(T) -> T, + _: &dyn Fn(T) -> T, alpha: T, f0: T, df0: T, diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 5679516..9600767 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -674,15 +674,20 @@ impl, Y: Array1> ) -> bool { let (n_rows, n_attr) = visitor.x.shape(); - let mut label = Option::None; + let mut label = None; let mut is_pure = true; for i in 0..n_rows { if visitor.samples[i] > 0 { - if label.is_none() { - label = Option::Some(visitor.y[i]); - } else if visitor.y[i] != label.unwrap() { - is_pure = false; - break; + match label { + None => { + label = Some(visitor.y[i]); + } + Some(current_label) => { + if visitor.y[i] != current_label { + is_pure = false; + break; + } + } } } } diff --git a/src/xgboost/xgb_regressor.rs b/src/xgboost/xgb_regressor.rs index ac6ec75..75c77a5 100644 --- a/src/xgboost/xgb_regressor.rs +++ b/src/xgboost/xgb_regressor.rs @@ -96,7 +96,7 @@ impl Objective { pub fn gradient>(&self, y_true: &Y, y_pred: &Vec) -> Vec { match self { Objective::MeanSquaredError => zip(y_true.iterator(0), y_pred) - .map(|(true_val, pred_val)| (*pred_val - true_val.to_f64().unwrap())) + .map(|(true_val, pred_val)| *pred_val - true_val.to_f64().unwrap()) .collect(), } }