From 0e89113297e440b621870830a440d38aa6a7ff34 Mon Sep 17 00:00:00 2001 From: Volodymyr Orlov Date: Thu, 20 Feb 2020 18:43:24 -0800 Subject: [PATCH] feat: adds KMeans clustering algorithm --- smartcore.iml | 10 +- src/algorithm/neighbour/bbd_tree.rs | 345 ++++++++++++++++++++++ src/algorithm/neighbour/mod.rs | 1 + src/classification/decision_tree.rs | 4 +- src/classification/logistic_regression.rs | 17 +- src/classification/random_forest.rs | 2 +- src/cluster/kmeans.rs | 220 ++++++++++++++ src/cluster/mod.rs | 1 + src/lib.rs | 1 + src/linalg/mod.rs | 4 - src/linalg/naive/dense_matrix.rs | 82 +++-- src/linalg/ndarray_bindings.rs | 21 -- src/regression/linear_regression.rs | 13 +- 13 files changed, 637 insertions(+), 84 deletions(-) create mode 100644 src/algorithm/neighbour/bbd_tree.rs create mode 100644 src/cluster/kmeans.rs create mode 100644 src/cluster/mod.rs diff --git a/smartcore.iml b/smartcore.iml index 8021953..7fe828a 100644 --- a/smartcore.iml +++ b/smartcore.iml @@ -1,8 +1,14 @@ - + - + + + + + + + diff --git a/src/algorithm/neighbour/bbd_tree.rs b/src/algorithm/neighbour/bbd_tree.rs new file mode 100644 index 0000000..c7a7e86 --- /dev/null +++ b/src/algorithm/neighbour/bbd_tree.rs @@ -0,0 +1,345 @@ +use std::collections::LinkedList; + +use crate::linalg::Matrix; + +#[derive(Debug)] +pub struct BBDTree { + nodes: Vec, + index: Vec, + root: usize +} + +#[derive(Debug)] +struct BBDTreeNode { + count: usize, + index: usize, + center: Vec, + radius: Vec, + sum: Vec, + cost: f64, + lower: Option, + upper: Option +} + +impl BBDTreeNode { + fn new(d: usize) -> BBDTreeNode { + BBDTreeNode { + count: 0, + index: 0, + center: vec![0f64; d], + radius: vec![0f64; d], + sum: vec![0f64; d], + cost: 0f64, + lower: Option::None, + upper: Option::None + } + } +} + +impl BBDTree { + pub fn new(data: &M) -> BBDTree { + let nodes = Vec::new(); + + let (n, _) = data.shape(); + + let mut index = vec![0; n]; + for i in 0..n { + index[i] = i; + } + + let mut tree = BBDTree{ + nodes: nodes, + index: index, + root: 0 + }; + + let root = tree.build_node(data, 0, n); + + tree.root = root; + + tree + } + + pub(in crate) fn clustering(&self, centroids: &Vec>, sums: &mut Vec>, counts: &mut Vec, membership: &mut Vec) -> f64 { + let k = centroids.len(); + + counts.iter_mut().for_each(|x| *x = 0); + let mut candidates = vec![0; k]; + for i in 0..k { + candidates[i] = i; + sums[i].iter_mut().for_each(|x| *x = 0f64); + } + + self.filter(self.root, centroids, &candidates, k, sums, counts, membership) + } + + fn filter(&self, node: usize, centroids: &Vec>, candidates: &Vec, k: usize, sums: &mut Vec>, counts: &mut Vec, membership: &mut Vec) -> f64{ + let d = centroids[0].len(); + + // Determine which mean the node mean is closest to + let mut min_dist = BBDTree::squared_distance(&self.nodes[node].center, ¢roids[candidates[0]]); + let mut closest = candidates[0]; + for i in 1..k { + let dist = BBDTree::squared_distance(&self.nodes[node].center, ¢roids[candidates[i]]); + if dist < min_dist { + min_dist = dist; + closest = candidates[i]; + } + } + + // If this is a non-leaf node, recurse if necessary + if !self.nodes[node].lower.is_none() { + // Build the new list of candidates + let mut new_candidates = vec![0;k]; + let mut newk = 0; + + for i in 0..k { + if !BBDTree::prune(&self.nodes[node].center, &self.nodes[node].radius, ¢roids, closest, candidates[i]) { + new_candidates[newk] = candidates[i]; + newk += 1; + } + } + + // Recurse if there's at least two + if newk > 1 { + let result = self.filter(self.nodes[node].lower.unwrap(), centroids, &mut new_candidates, newk, sums, counts, membership) + + self.filter(self.nodes[node].upper.unwrap(), centroids, &mut new_candidates, newk, sums, counts, membership); + return result; + } + } + + // Assigns all data within this node to a single mean + for i in 0..d { + sums[closest][i] += self.nodes[node].sum[i]; + } + + counts[closest] += self.nodes[node].count; + + let last = self.nodes[node].index + self.nodes[node].count; + for i in self.nodes[node].index..last { + membership[self.index[i]] = closest; + } + + BBDTree::node_cost(&self.nodes[node], ¢roids[closest]) + + } + + fn prune(center: &Vec, radius: &Vec, centroids: &Vec>, best_index: usize, test_index: usize) -> bool { + if best_index == test_index { + return false; + } + + let d = centroids[0].len(); + + let best = ¢roids[best_index]; + let test = ¢roids[test_index]; + let mut lhs = 0f64; + let mut rhs = 0f64; + for i in 0..d { + let diff = test[i] - best[i]; + lhs += diff * diff; + if diff > 0f64 { + rhs += (center[i] + radius[i] - best[i]) * diff; + } else { + rhs += (center[i] - radius[i] - best[i]) * diff; + } + } + + return lhs >= 2f64 * rhs; + } + + fn squared_distance(x: &Vec,y: &Vec) -> f64 { + if x.len() != y.len() { + panic!("Input vector sizes are different."); + } + + let mut sum = 0f64; + for i in 0..x.len() { + sum += (x[i] - y[i]).powf(2.); + } + + return sum; + } + + fn build_node(&mut self, data: &M, begin: usize, end: usize) -> usize { + let (_, d) = data.shape(); + + // Allocate the node + let mut node = BBDTreeNode::new(d); + + // Fill in basic info + node.count = end - begin; + node.index = begin; + + // Calculate the bounding box + let mut lower_bound = vec![0f64; d]; + let mut upper_bound = vec![0f64; d]; + + for i in 0..d { + lower_bound[i] = data.get(self.index[begin],i); + upper_bound[i] = data.get(self.index[begin],i); + } + + for i in begin..end { + for j in 0..d { + let c = data.get(self.index[i], j); + if lower_bound[j] > c { + lower_bound[j] = c; + } + if upper_bound[j] < c { + upper_bound[j] = c; + } + } + } + + // Calculate bounding box stats + let mut max_radius = -1.; + let mut split_index = 0; + for i in 0..d { + node.center[i] = (lower_bound[i] + upper_bound[i]) / 2.; + node.radius[i] = (upper_bound[i] - lower_bound[i]) / 2.; + if node.radius[i] > max_radius { + max_radius = node.radius[i]; + split_index = i; + } + } + + // If the max spread is 0, make this a leaf node + if max_radius < 1E-10 { + node.lower = Option::None; + node.upper = Option::None; + for i in 0..d { + node.sum[i] = data.get(self.index[begin], i); + } + + if end > begin + 1 { + let len = end - begin; + for i in 0..d { + node.sum[i] *= len as f64; + } + } + + node.cost = 0f64; + return self.add_node(node); + } + + // Partition the data around the midpoint in this dimension. The + // partitioning is done in-place by iterating from left-to-right and + // right-to-left in the same way that partioning is done in quicksort. + let split_cutoff = node.center[split_index]; + let mut i1 = begin; + let mut i2 = end - 1; + let mut size = 0; + while i1 <= i2 { + let mut i1_good = data.get(self.index[i1], split_index) < split_cutoff; + let mut i2_good = data.get(self.index[i2], split_index) >= split_cutoff; + + if !i1_good && !i2_good { + let temp = self.index[i1]; + self.index[i1] = self.index[i2]; + self.index[i2] = temp; + i1_good = true; + i2_good = true; + } + + if i1_good { + i1 += 1; + size += 1; + } + + if i2_good { + i2 -= 1; + } + } + + // Create the child nodes + node.lower = Option::Some(self.build_node(data, begin, begin + size)); + node.upper = Option::Some(self.build_node(data, begin + size, end)); + + // Calculate the new sum and opt cost + for i in 0..d { + node.sum[i] = self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i]; + } + + let mut mean = vec![0f64; d]; + for i in 0..d { + mean[i] = node.sum[i] / node.count as f64; + } + + node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean) + BBDTree::node_cost(&self.nodes[node.upper.unwrap()], &mean); + + self.add_node(node) + } + + fn node_cost(node: &BBDTreeNode, center: &Vec) -> f64 { + let d = center.len(); + let mut scatter = 0f64; + for i in 0..d { + let x = (node.sum[i] / node.count as f64) - center[i]; + scatter += x * x; + } + node.cost + node.count as f64 * scatter + } + + fn add_node(&mut self, new_node: BBDTreeNode) -> usize{ + let idx = self.nodes.len(); + self.nodes.push(new_node); + idx + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::linalg::naive::dense_matrix::DenseMatrix; + + #[test] + fn fit_predict_iris() { + + let data = DenseMatrix::from_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4]]); + + let tree = BBDTree::new(&data); + + let centroids = vec![ + vec![4.86, 3.22, 1.61, 0.29], + vec![6.23, 2.92, 4.48, 1.42] + ]; + + let mut sums = vec![ + vec![0f64; 4], + vec![0f64; 4] + ]; + + let mut counts = vec![11, 9]; + + let mut membership = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1]; + + let dist = tree.clustering(¢roids, &mut sums, &mut counts, &mut membership); + assert!((dist - 10.68).abs() < 1e-2); + assert!((sums[0][0] - 48.6).abs() < 1e-2); + assert!((sums[1][3] - 13.8).abs() < 1e-2); + assert_eq!(membership[17], 1); + + } + +} \ No newline at end of file diff --git a/src/algorithm/neighbour/mod.rs b/src/algorithm/neighbour/mod.rs index ff6a070..0ab7e19 100644 --- a/src/algorithm/neighbour/mod.rs +++ b/src/algorithm/neighbour/mod.rs @@ -1,5 +1,6 @@ pub mod cover_tree; pub mod linear_search; +pub mod bbd_tree; pub enum KNNAlgorithmName { CoverTree, diff --git a/src/classification/decision_tree.rs b/src/classification/decision_tree.rs index 70240a1..80f9a4a 100644 --- a/src/classification/decision_tree.rs +++ b/src/classification/decision_tree.rs @@ -412,7 +412,7 @@ mod tests { #[test] fn fit_predict_iris() { - let x = DenseMatrix::from_2d_array(&[ + let x = DenseMatrix::from_array(&[ &[5.1, 3.5, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2], &[4.7, 3.2, 1.3, 0.2], @@ -444,7 +444,7 @@ mod tests { #[test] fn fit_predict_baloons() { - let x = DenseMatrix::from_2d_array(&[ + let x = DenseMatrix::from_array(&[ &[1.,1.,1.,0.], &[1.,1.,1.,0.], &[1.,1.,1.,1.], diff --git a/src/classification/logistic_regression.rs b/src/classification/logistic_regression.rs index b979499..eece7d1 100644 --- a/src/classification/logistic_regression.rs +++ b/src/classification/logistic_regression.rs @@ -192,19 +192,26 @@ impl LogisticRegression { } pub fn predict(&self, x: &M) -> M::RowVector { + let n = x.shape().0; + let mut result = M::zeros(1, n); if self.num_classes == 2 { let (nrows, _) = x.shape(); let x_and_bias = x.v_stack(&M::ones(nrows, 1)); let y_hat: Vec = x_and_bias.dot(&self.weights.transpose()).to_raw_vector(); - M::from_vec(1, nrows, y_hat.iter().map(|y_hat| self.classes[if y_hat.sigmoid() > 0.5 { 1 } else { 0 }]).collect()).to_row_vector() + for i in 0..n { + result.set(0, i, self.classes[if y_hat[i].sigmoid() > 0.5 { 1 } else { 0 }]); + } } else { let (nrows, _) = x.shape(); let x_and_bias = x.v_stack(&M::ones(nrows, 1)); let y_hat = x_and_bias.dot(&self.weights.transpose()); let class_idxs = y_hat.argmax(); - M::from_vec(1, nrows, class_idxs.iter().map(|class_idx| self.classes[*class_idx]).collect()).to_row_vector() + for i in 0..n { + result.set(0, i, self.classes[class_idxs[i]]); + } } + result.to_row_vector() } pub fn coefficients(&self) -> M { @@ -242,7 +249,7 @@ mod tests { #[test] fn multiclass_objective_f() { - let x = DenseMatrix::from_2d_array(&[ + let x = DenseMatrix::from_array(&[ &[1., -5.], &[ 2., 5.], &[ 3., -2.], @@ -282,7 +289,7 @@ mod tests { #[test] fn binary_objective_f() { - let x = DenseMatrix::from_2d_array(&[ + let x = DenseMatrix::from_array(&[ &[1., -5.], &[ 2., 5.], &[ 3., -2.], @@ -323,7 +330,7 @@ mod tests { #[test] fn lr_fit_predict() { - let x = DenseMatrix::from_2d_array(&[ + let x = DenseMatrix::from_array(&[ &[1., -5.], &[ 2., 5.], &[ 3., -2.], diff --git a/src/classification/random_forest.rs b/src/classification/random_forest.rs index 5a46e59..78d31ec 100644 --- a/src/classification/random_forest.rs +++ b/src/classification/random_forest.rs @@ -128,7 +128,7 @@ mod tests { #[test] fn fit_predict_iris() { - let x = DenseMatrix::from_2d_array(&[ + let x = DenseMatrix::from_array(&[ &[5.1, 3.5, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2], &[4.7, 3.2, 1.3, 0.2], diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs new file mode 100644 index 0000000..f59efbd --- /dev/null +++ b/src/cluster/kmeans.rs @@ -0,0 +1,220 @@ +extern crate rand; + +use rand::Rng; + +use crate::linalg::Matrix; +use crate::algorithm::neighbour::bbd_tree::BBDTree; + +#[derive(Debug)] +pub struct KMeans { + k: usize, + y: Vec, + size: Vec, + distortion: f64, + centroids: Vec> +} + +#[derive(Debug, Clone)] +pub struct KMeansParameters { + pub max_iter: usize +} + +impl Default for KMeansParameters { + fn default() -> Self { + KMeansParameters { + max_iter: 100 + } + } +} + +impl KMeans{ + pub fn new(data: &M, k: usize, parameters: KMeansParameters) -> KMeans { + + let bbd = BBDTree::new(data); + + if k < 2 { + panic!("Invalid number of clusters: {}", k); + } + + if parameters.max_iter <= 0 { + panic!("Invalid maximum number of iterations: {}", parameters.max_iter); + } + + let (n, d) = data.shape(); + + let mut distortion = std::f64::MAX; + let mut y = KMeans::kmeans_plus_plus(data, k); + let mut size = vec![0; k]; + let mut centroids = vec![vec![0f64; d]; k]; + + for i in 0..n { + size[y[i]] += 1; + } + + for i in 0..n { + for j in 0..d { + centroids[y[i]][j] += data.get(i, j); + } + } + + for i in 0..k { + for j in 0..d { + centroids[i][j] /= size[i] as f64; + } + } + + let mut sums = vec![vec![0f64; d]; k]; + for _ in 1..= parameters.max_iter { + let dist = bbd.clustering(¢roids, &mut sums, &mut size, &mut y); + for i in 0..k { + if size[i] > 0 { + for j in 0..d { + centroids[i][j] = sums[i][j] as f64 / size[i] as f64; + } + } + } + + if distortion <= dist { + break; + } else { + distortion = dist; + } + + } + + KMeans{ + k: k, + y: y, + size: size, + distortion: distortion, + centroids: centroids + } + } + + pub fn predict(&self, x: &M) -> M::RowVector { + let (n, _) = x.shape(); + let mut result = M::zeros(1, n); + + for i in 0..n { + + let mut min_dist = std::f64::MAX; + let mut best_cluster = 0; + + for j in 0..self.k { + let dist = KMeans::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]); + if dist < min_dist { + min_dist = dist; + best_cluster = j; + } + } + result.set(0, i, best_cluster as f64); + } + + result.to_row_vector() + } + + fn kmeans_plus_plus(data: &M, k: usize) -> Vec{ + let mut rng = rand::thread_rng(); + let (n, _) = data.shape(); + let mut y = vec![0; n]; + let mut centroid = data.get_row_as_vec(rng.gen_range(0, n)); + + let mut d = vec![std::f64::MAX; n]; + + // pick the next center + for j in 1..k { + // Loop over the samples and compare them to the most recent center. Store + // the distance from each sample to its closest center in scores. + for i in 0..n { + // compute the distance between this sample and the current center + let dist = KMeans::squared_distance(&data.get_row_as_vec(i), ¢roid); + + if dist < d[i] { + d[i] = dist; + y[i] = j - 1; + } + } + + let sum: f64 = d.iter().sum(); + let cutoff = rng.gen::() * sum; + let mut cost = 0f64; + let index = 0; + for index in 0..n { + cost += d[index]; + if cost >= cutoff { + break; + } + } + + centroid = data.get_row_as_vec(index); + } + + for i in 0..n { + // compute the distance between this sample and the current center + let dist = KMeans::squared_distance(&data.get_row_as_vec(i), ¢roid); + + if dist < d[i] { + d[i] = dist; + y[i] = k - 1; + } + } + + y + } + + fn squared_distance(x: &Vec,y: &Vec) -> f64 { + if x.len() != y.len() { + panic!("Input vector sizes are different."); + } + + let mut sum = 0f64; + for i in 0..x.len() { + sum += (x[i] - y[i]).powf(2.); + } + + return sum; + } + +} + + +#[cfg(test)] +mod tests { + use super::*; + use crate::linalg::naive::dense_matrix::DenseMatrix; + + #[test] + fn fit_predict_iris() { + let x = DenseMatrix::from_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4]]); + + let kmeans = KMeans::new(&x, 2, Default::default()); + + let y = kmeans.predict(&x); + + for i in 0..y.len() { + assert_eq!(y[i] as usize, kmeans.y[i]); + } + + } + +} \ No newline at end of file diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs new file mode 100644 index 0000000..6e28466 --- /dev/null +++ b/src/cluster/mod.rs @@ -0,0 +1 @@ +pub mod kmeans; \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index cc66764..6dac826 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod classification; pub mod regression; +pub mod cluster; pub mod linalg; pub mod math; pub mod error; diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 636e1ac..798cc90 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -12,10 +12,6 @@ pub trait Matrix: Clone + Debug { fn to_row_vector(self) -> Self::RowVector; - fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> Self; - - fn from_vec(nrows: usize, ncols: usize, values: Vec) -> Self; - fn get(&self, row: usize, col: usize) -> f64; fn get_row_as_vec(&self, row: usize) -> Vec; diff --git a/src/linalg/naive/dense_matrix.rs b/src/linalg/naive/dense_matrix.rs index 431a1d2..772afcd 100644 --- a/src/linalg/naive/dense_matrix.rs +++ b/src/linalg/naive/dense_matrix.rs @@ -12,13 +12,21 @@ pub struct DenseMatrix { } -impl DenseMatrix { +impl DenseMatrix { + + fn new(nrows: usize, ncols: usize, values: Vec) -> DenseMatrix { + DenseMatrix { + ncols: ncols, + nrows: nrows, + values: values + } + } - pub fn from_2d_array(values: &[&[f64]]) -> DenseMatrix { - DenseMatrix::from_2d_vec(&values.into_iter().map(|row| Vec::from(*row)).collect()) + pub fn from_array(values: &[&[f64]]) -> DenseMatrix { + DenseMatrix::from_vec(&values.into_iter().map(|row| Vec::from(*row)).collect()) } - pub fn from_2d_vec(values: &Vec>) -> DenseMatrix { + pub fn from_vec(values: &Vec>) -> DenseMatrix { let nrows = values.len(); let ncols = values.first().unwrap_or_else(|| panic!("Cannot create 2d matrix from an empty vector")).len(); let mut m = DenseMatrix { @@ -112,24 +120,12 @@ impl Matrix for DenseMatrix { type RowVector = Vec; fn from_row_vector(vec: Self::RowVector) -> Self{ - DenseMatrix::from_vec(1, vec.len(), vec) + DenseMatrix::new(1, vec.len(), vec) } fn to_row_vector(self) -> Self::RowVector{ self.to_raw_vector() - } - - fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> DenseMatrix { - DenseMatrix::from_vec(nrows, ncols, Vec::from(values)) - } - - fn from_vec(nrows: usize, ncols: usize, values: Vec) -> DenseMatrix { - DenseMatrix { - ncols: ncols, - nrows: nrows, - values: values - } - } + } fn get(&self, row: usize, col: usize) -> f64 { self.values[col*self.nrows + row] @@ -255,7 +251,7 @@ impl Matrix for DenseMatrix { let ncols = cols.len(); let nrows = rows.len(); - let mut m = DenseMatrix::from_vec(nrows, ncols, vec![0f64; nrows * ncols]); + let mut m = DenseMatrix::new(nrows, ncols, vec![0f64; nrows * ncols]); for r in rows.start..rows.end { for c in cols.start..cols.end { @@ -731,7 +727,7 @@ impl Matrix for DenseMatrix { } fn fill(nrows: usize, ncols: usize, value: f64) -> Self { - DenseMatrix::from_vec(nrows, ncols, vec![value; ncols * nrows]) + DenseMatrix::new(nrows, ncols, vec![value; ncols * nrows]) } fn add_mut(&mut self, other: &Self) -> &Self { @@ -998,7 +994,7 @@ mod tests { fn from_to_row_vec() { let vec = vec![ 1., 2., 3.]; - assert_eq!(DenseMatrix::from_row_vector(vec.clone()), DenseMatrix::from_vec(1, 3, vec![1., 2., 3.])); + assert_eq!(DenseMatrix::from_row_vector(vec.clone()), DenseMatrix::new(1, 3, vec![1., 2., 3.])); assert_eq!(DenseMatrix::from_row_vector(vec.clone()).to_row_vector(), vec![1., 2., 3.]); } @@ -1006,9 +1002,9 @@ mod tests { #[test] fn qr_solve_mut() { - let mut a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); - let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]); - let expected_w = DenseMatrix::from_array(3, 2, &[-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]); + let mut a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); + let b = DenseMatrix::from_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]); + let expected_w = DenseMatrix::new(3, 2, vec![-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]); let w = a.qr_solve_mut(b); assert!(w.approximate_eq(&expected_w, 1e-2)); } @@ -1016,9 +1012,9 @@ mod tests { #[test] fn svd_solve_mut() { - let mut a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); - let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]); - let expected_w = DenseMatrix::from_array(3, 2, &[-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]); + let mut a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); + let b = DenseMatrix::from_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]); + let expected_w = DenseMatrix::new(3, 2, vec![-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]); let w = a.svd_solve_mut(b); assert!(w.approximate_eq(&expected_w, 1e-2)); } @@ -1026,16 +1022,16 @@ mod tests { #[test] fn h_stack() { - let a = DenseMatrix::from_2d_array( + let a = DenseMatrix::from_array( &[ &[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]); - let b = DenseMatrix::from_2d_array( + let b = DenseMatrix::from_array( &[ &[1., 2., 3.], &[4., 5., 6.]]); - let expected = DenseMatrix::from_2d_array( + let expected = DenseMatrix::from_array( &[ &[1., 2., 3.], &[4., 5., 6.], @@ -1049,17 +1045,17 @@ mod tests { #[test] fn v_stack() { - let a = DenseMatrix::from_2d_array( + let a = DenseMatrix::from_array( &[ &[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]); - let b = DenseMatrix::from_2d_array( + let b = DenseMatrix::from_array( &[ &[1., 2.], &[3., 4.], &[5., 6.]]); - let expected = DenseMatrix::from_2d_array( + let expected = DenseMatrix::from_array( &[ &[1., 2., 3., 1., 2.], &[4., 5., 6., 3., 4.], @@ -1071,16 +1067,16 @@ mod tests { #[test] fn dot() { - let a = DenseMatrix::from_2d_array( + let a = DenseMatrix::from_array( &[ &[1., 2., 3.], &[4., 5., 6.]]); - let b = DenseMatrix::from_2d_array( + let b = DenseMatrix::from_array( &[ &[1., 2.], &[3., 4.], &[5., 6.]]); - let expected = DenseMatrix::from_2d_array( + let expected = DenseMatrix::from_array( &[ &[22., 28.], &[49., 64.]]); @@ -1091,12 +1087,12 @@ mod tests { #[test] fn slice() { - let m = DenseMatrix::from_2d_array( + let m = DenseMatrix::from_array( &[ &[1., 2., 3., 1., 2.], &[4., 5., 6., 3., 4.], &[7., 8., 9., 5., 6.]]); - let expected = DenseMatrix::from_2d_array( + let expected = DenseMatrix::from_array( &[ &[2., 3.], &[5., 6.]]); @@ -1107,15 +1103,15 @@ mod tests { #[test] fn approximate_eq() { - let m = DenseMatrix::from_2d_array( + let m = DenseMatrix::from_array( &[ &[2., 3.], &[5., 6.]]); - let m_eq = DenseMatrix::from_2d_array( + let m_eq = DenseMatrix::from_array( &[ &[2.5, 3.0], &[5., 5.5]]); - let m_neq = DenseMatrix::from_2d_array( + let m_neq = DenseMatrix::from_array( &[ &[3.0, 3.0], &[5., 6.5]]); @@ -1135,8 +1131,8 @@ mod tests { #[test] fn transpose() { - let m = DenseMatrix::from_2d_array(&[&[1.0, 3.0], &[2.0, 4.0]]); - let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]); + let m = DenseMatrix::from_array(&[&[1.0, 3.0], &[2.0, 4.0]]); + let expected = DenseMatrix::from_array(&[&[1.0, 2.0], &[3.0, 4.0]]); let m_transposed = m.transpose(); for c in 0..2 { for r in 0..2 { diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs index 6f36552..13947ad 100644 --- a/src/linalg/ndarray_bindings.rs +++ b/src/linalg/ndarray_bindings.rs @@ -16,14 +16,6 @@ impl Matrix for ArrayBase, Ix2> self.into_shape(vec_size).unwrap() } - fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> Self { - Array::from_shape_vec((nrows, ncols), values.to_vec()).unwrap() - } - - fn from_vec(nrows: usize, ncols: usize, values: Vec) -> Self { - Array::from_shape_vec((nrows, ncols), values).unwrap() - } - fn get(&self, row: usize, col: usize) -> f64 { self[[row, col]] } @@ -330,19 +322,6 @@ mod tests { } - #[test] - fn from_array_from_vec() { - - let a1 = arr2(&[[ 1., 2., 3.], - [4., 5., 6.]]); - let a2 = Array2::from_array(2, 3, &[1., 2., 3., 4., 5., 6.]); - let a3 = Array2::from_vec(2, 3, vec![1., 2., 3., 4., 5., 6.]); - - assert_eq!(a1, a2); - assert_eq!(a1, a3); - - } - #[test] fn vstack_hstack() { diff --git a/src/regression/linear_regression.rs b/src/regression/linear_regression.rs index e042dd4..ae3f742 100644 --- a/src/regression/linear_regression.rs +++ b/src/regression/linear_regression.rs @@ -19,14 +19,14 @@ impl LinearRegression { pub fn fit(x: &M, y: &M, solver: LinearRegressionSolver) -> LinearRegression{ + let b = y.transpose(); let (x_nrows, num_attributes) = x.shape(); - let (y_nrows, _) = y.shape(); + let (y_nrows, _) = b.shape(); if x_nrows != y_nrows { panic!("Number of rows of X doesn't match number of rows of Y"); } - - let b = y.clone(); + let mut a = x.v_stack(&M::ones(x_nrows, 1)); let w = match solver { @@ -52,7 +52,7 @@ impl Regression for LinearRegression { let (nrows, _) = x.shape(); let mut y_hat = x.dot(&self.coefficients); y_hat.add_mut(&M::fill(nrows, 1, self.intercept)); - y_hat + y_hat.transpose() } } @@ -65,7 +65,7 @@ mod tests { #[test] fn ols_fit_predict() { - let x = DenseMatrix::from_2d_array(&[ + let x = DenseMatrix::from_array(&[ &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], @@ -82,7 +82,8 @@ mod tests { &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], &[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); - let y = DenseMatrix::from_array(16, 1, &[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]); + + let y = DenseMatrix::from_array(&[&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]]); let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);