feat: adds KMeans clustering algorithm
This commit is contained in:
+8
-2
@@ -1,8 +1,14 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<module type="WEB_MODULE" version="4">
|
<module type="RUST_MODULE" version="4">
|
||||||
<component name="NewModuleRootManager" inherit-compiler-output="true">
|
<component name="NewModuleRootManager" inherit-compiler-output="true">
|
||||||
<exclude-output />
|
<exclude-output />
|
||||||
<content url="file://$MODULE_DIR$" />
|
<content url="file://$MODULE_DIR$">
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/examples" isTestSource="false" />
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/tests" isTestSource="true" />
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/benches" isTestSource="true" />
|
||||||
|
<excludeFolder url="file://$MODULE_DIR$/target" />
|
||||||
|
</content>
|
||||||
<orderEntry type="inheritedJdk" />
|
<orderEntry type="inheritedJdk" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
|
|||||||
@@ -0,0 +1,345 @@
|
|||||||
|
use std::collections::LinkedList;
|
||||||
|
|
||||||
|
use crate::linalg::Matrix;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct BBDTree {
|
||||||
|
nodes: Vec<BBDTreeNode>,
|
||||||
|
index: Vec<usize>,
|
||||||
|
root: usize
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single node of a [`BBDTree`]: a bounding box plus summary statistics of
/// the rows that fall inside it.
#[derive(Debug)]
struct BBDTreeNode {
    /// Number of rows covered by this node.
    count: usize,
    /// Offset of the node's first row inside the tree's `index` permutation.
    index: usize,
    /// Per-dimension midpoint of the bounding box.
    center: Vec<f64>,
    /// Per-dimension half-width of the bounding box.
    radius: Vec<f64>,
    /// Per-dimension sum of all rows covered by this node.
    sum: Vec<f64>,
    /// Accumulated squared-distance cost of the node's rows around their mean (zero for leaves).
    cost: f64,
    /// Arena position of the child holding the first part of the node's range, if any.
    lower: Option<usize>,
    /// Arena position of the child holding the remaining part of the node's range, if any.
    upper: Option<usize>,
}
|
||||||
|
|
||||||
|
impl BBDTreeNode {
|
||||||
|
fn new(d: usize) -> BBDTreeNode {
|
||||||
|
BBDTreeNode {
|
||||||
|
count: 0,
|
||||||
|
index: 0,
|
||||||
|
center: vec![0f64; d],
|
||||||
|
radius: vec![0f64; d],
|
||||||
|
sum: vec![0f64; d],
|
||||||
|
cost: 0f64,
|
||||||
|
lower: Option::None,
|
||||||
|
upper: Option::None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BBDTree {
|
||||||
|
pub fn new<M: Matrix>(data: &M) -> BBDTree {
|
||||||
|
let nodes = Vec::new();
|
||||||
|
|
||||||
|
let (n, _) = data.shape();
|
||||||
|
|
||||||
|
let mut index = vec![0; n];
|
||||||
|
for i in 0..n {
|
||||||
|
index[i] = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut tree = BBDTree{
|
||||||
|
nodes: nodes,
|
||||||
|
index: index,
|
||||||
|
root: 0
|
||||||
|
};
|
||||||
|
|
||||||
|
let root = tree.build_node(data, 0, n);
|
||||||
|
|
||||||
|
tree.root = root;
|
||||||
|
|
||||||
|
tree
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(in crate) fn clustering(&self, centroids: &Vec<Vec<f64>>, sums: &mut Vec<Vec<f64>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> f64 {
|
||||||
|
let k = centroids.len();
|
||||||
|
|
||||||
|
counts.iter_mut().for_each(|x| *x = 0);
|
||||||
|
let mut candidates = vec![0; k];
|
||||||
|
for i in 0..k {
|
||||||
|
candidates[i] = i;
|
||||||
|
sums[i].iter_mut().for_each(|x| *x = 0f64);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.filter(self.root, centroids, &candidates, k, sums, counts, membership)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn filter(&self, node: usize, centroids: &Vec<Vec<f64>>, candidates: &Vec<usize>, k: usize, sums: &mut Vec<Vec<f64>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> f64{
|
||||||
|
let d = centroids[0].len();
|
||||||
|
|
||||||
|
// Determine which mean the node mean is closest to
|
||||||
|
let mut min_dist = BBDTree::squared_distance(&self.nodes[node].center, ¢roids[candidates[0]]);
|
||||||
|
let mut closest = candidates[0];
|
||||||
|
for i in 1..k {
|
||||||
|
let dist = BBDTree::squared_distance(&self.nodes[node].center, ¢roids[candidates[i]]);
|
||||||
|
if dist < min_dist {
|
||||||
|
min_dist = dist;
|
||||||
|
closest = candidates[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is a non-leaf node, recurse if necessary
|
||||||
|
if !self.nodes[node].lower.is_none() {
|
||||||
|
// Build the new list of candidates
|
||||||
|
let mut new_candidates = vec![0;k];
|
||||||
|
let mut newk = 0;
|
||||||
|
|
||||||
|
for i in 0..k {
|
||||||
|
if !BBDTree::prune(&self.nodes[node].center, &self.nodes[node].radius, ¢roids, closest, candidates[i]) {
|
||||||
|
new_candidates[newk] = candidates[i];
|
||||||
|
newk += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recurse if there's at least two
|
||||||
|
if newk > 1 {
|
||||||
|
let result = self.filter(self.nodes[node].lower.unwrap(), centroids, &mut new_candidates, newk, sums, counts, membership) +
|
||||||
|
self.filter(self.nodes[node].upper.unwrap(), centroids, &mut new_candidates, newk, sums, counts, membership);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assigns all data within this node to a single mean
|
||||||
|
for i in 0..d {
|
||||||
|
sums[closest][i] += self.nodes[node].sum[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
counts[closest] += self.nodes[node].count;
|
||||||
|
|
||||||
|
let last = self.nodes[node].index + self.nodes[node].count;
|
||||||
|
for i in self.nodes[node].index..last {
|
||||||
|
membership[self.index[i]] = closest;
|
||||||
|
}
|
||||||
|
|
||||||
|
BBDTree::node_cost(&self.nodes[node], ¢roids[closest])
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decides whether candidate `test_index` can be discarded for a bounding box.
///
/// `center` and `radius` describe the box, `best_index` is the candidate
/// currently closest to the box centre. For every dimension the box corner
/// lying furthest in the direction from the best centroid towards the tested
/// centroid is used; the candidate is pruned (returns `true`) when
/// `|test - best|^2 >= 2 * (corner - best) . (test - best)`.
/// A candidate is never pruned against itself.
fn prune(center: &[f64], radius: &[f64], centroids: &[Vec<f64>], best_index: usize, test_index: usize) -> bool {
    if best_index == test_index {
        return false;
    }

    let d = centroids[0].len();

    let best = &centroids[best_index];
    let test = &centroids[test_index];
    let mut lhs = 0f64;
    let mut rhs = 0f64;
    for i in 0..d {
        let diff = test[i] - best[i];
        lhs += diff * diff;

        // Pick the face of the box that lies in the direction of `test`.
        let extreme = if diff > 0f64 {
            center[i] + radius[i]
        } else {
            center[i] - radius[i]
        };
        rhs += (extreme - best[i]) * diff;
    }

    lhs >= 2f64 * rhs
}
|
||||||
|
|
||||||
|
/// Returns the squared Euclidean distance between `x` and `y`.
///
/// # Panics
///
/// Panics if the two inputs differ in length.
fn squared_distance(x: &[f64], y: &[f64]) -> f64 {
    if x.len() != y.len() {
        panic!("Input vector sizes are different.");
    }

    x.iter().zip(y.iter()).fold(0f64, |acc, (a, b)| {
        let diff = a - b;
        acc + diff * diff
    })
}
|
||||||
|
|
||||||
|
fn build_node<M: Matrix>(&mut self, data: &M, begin: usize, end: usize) -> usize {
|
||||||
|
let (_, d) = data.shape();
|
||||||
|
|
||||||
|
// Allocate the node
|
||||||
|
let mut node = BBDTreeNode::new(d);
|
||||||
|
|
||||||
|
// Fill in basic info
|
||||||
|
node.count = end - begin;
|
||||||
|
node.index = begin;
|
||||||
|
|
||||||
|
// Calculate the bounding box
|
||||||
|
let mut lower_bound = vec![0f64; d];
|
||||||
|
let mut upper_bound = vec![0f64; d];
|
||||||
|
|
||||||
|
for i in 0..d {
|
||||||
|
lower_bound[i] = data.get(self.index[begin],i);
|
||||||
|
upper_bound[i] = data.get(self.index[begin],i);
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in begin..end {
|
||||||
|
for j in 0..d {
|
||||||
|
let c = data.get(self.index[i], j);
|
||||||
|
if lower_bound[j] > c {
|
||||||
|
lower_bound[j] = c;
|
||||||
|
}
|
||||||
|
if upper_bound[j] < c {
|
||||||
|
upper_bound[j] = c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate bounding box stats
|
||||||
|
let mut max_radius = -1.;
|
||||||
|
let mut split_index = 0;
|
||||||
|
for i in 0..d {
|
||||||
|
node.center[i] = (lower_bound[i] + upper_bound[i]) / 2.;
|
||||||
|
node.radius[i] = (upper_bound[i] - lower_bound[i]) / 2.;
|
||||||
|
if node.radius[i] > max_radius {
|
||||||
|
max_radius = node.radius[i];
|
||||||
|
split_index = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the max spread is 0, make this a leaf node
|
||||||
|
if max_radius < 1E-10 {
|
||||||
|
node.lower = Option::None;
|
||||||
|
node.upper = Option::None;
|
||||||
|
for i in 0..d {
|
||||||
|
node.sum[i] = data.get(self.index[begin], i);
|
||||||
|
}
|
||||||
|
|
||||||
|
if end > begin + 1 {
|
||||||
|
let len = end - begin;
|
||||||
|
for i in 0..d {
|
||||||
|
node.sum[i] *= len as f64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
node.cost = 0f64;
|
||||||
|
return self.add_node(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Partition the data around the midpoint in this dimension. The
|
||||||
|
// partitioning is done in-place by iterating from left-to-right and
|
||||||
|
// right-to-left in the same way that partioning is done in quicksort.
|
||||||
|
let split_cutoff = node.center[split_index];
|
||||||
|
let mut i1 = begin;
|
||||||
|
let mut i2 = end - 1;
|
||||||
|
let mut size = 0;
|
||||||
|
while i1 <= i2 {
|
||||||
|
let mut i1_good = data.get(self.index[i1], split_index) < split_cutoff;
|
||||||
|
let mut i2_good = data.get(self.index[i2], split_index) >= split_cutoff;
|
||||||
|
|
||||||
|
if !i1_good && !i2_good {
|
||||||
|
let temp = self.index[i1];
|
||||||
|
self.index[i1] = self.index[i2];
|
||||||
|
self.index[i2] = temp;
|
||||||
|
i1_good = true;
|
||||||
|
i2_good = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if i1_good {
|
||||||
|
i1 += 1;
|
||||||
|
size += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if i2_good {
|
||||||
|
i2 -= 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the child nodes
|
||||||
|
node.lower = Option::Some(self.build_node(data, begin, begin + size));
|
||||||
|
node.upper = Option::Some(self.build_node(data, begin + size, end));
|
||||||
|
|
||||||
|
// Calculate the new sum and opt cost
|
||||||
|
for i in 0..d {
|
||||||
|
node.sum[i] = self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut mean = vec![0f64; d];
|
||||||
|
for i in 0..d {
|
||||||
|
mean[i] = node.sum[i] / node.count as f64;
|
||||||
|
}
|
||||||
|
|
||||||
|
node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean) + BBDTree::node_cost(&self.nodes[node.upper.unwrap()], &mean);
|
||||||
|
|
||||||
|
self.add_node(node)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn node_cost(node: &BBDTreeNode, center: &Vec<f64>) -> f64 {
|
||||||
|
let d = center.len();
|
||||||
|
let mut scatter = 0f64;
|
||||||
|
for i in 0..d {
|
||||||
|
let x = (node.sum[i] / node.count as f64) - center[i];
|
||||||
|
scatter += x * x;
|
||||||
|
}
|
||||||
|
node.cost + node.count as f64 * scatter
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Appends `new_node` to the arena and returns the position it was stored at.
fn add_node(&mut self, new_node: BBDTreeNode) -> usize {
    self.nodes.push(new_node);
    self.nodes.len() - 1
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::naive::dense_matrix::DenseMatrix;

    /// Runs a single assignment pass on 20 iris rows with two fixed centroids
    /// and checks the returned cost, two of the accumulated sums and that
    /// row 17 is reassigned from cluster 0 to cluster 1.
    #[test]
    fn fit_predict_iris() {
        let data = DenseMatrix::from_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[5.7, 2.8, 4.5, 1.3],
            &[6.3, 3.3, 4.7, 1.6],
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4]]);

        let tree = BBDTree::new(&data);

        let centroids = vec![
            vec![4.86, 3.22, 1.61, 0.29],
            vec![6.23, 2.92, 4.48, 1.42]
        ];

        let mut sums = vec![
            vec![0f64; 4],
            vec![0f64; 4]
        ];

        let mut counts = vec![11, 9];

        let mut membership = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1];

        let dist = tree.clustering(&centroids, &mut sums, &mut counts, &mut membership);
        assert!((dist - 10.68).abs() < 1e-2);
        assert!((sums[0][0] - 48.6).abs() < 1e-2);
        assert!((sums[1][3] - 13.8).abs() < 1e-2);
        assert_eq!(membership[17], 1);
    }
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
pub mod cover_tree;
|
pub mod cover_tree;
|
||||||
pub mod linear_search;
|
pub mod linear_search;
|
||||||
|
pub mod bbd_tree;
|
||||||
|
|
||||||
pub enum KNNAlgorithmName {
|
pub enum KNNAlgorithmName {
|
||||||
CoverTree,
|
CoverTree,
|
||||||
|
|||||||
@@ -412,7 +412,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_iris() {
|
fn fit_predict_iris() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
&[4.9, 3.0, 1.4, 0.2],
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
&[4.7, 3.2, 1.3, 0.2],
|
&[4.7, 3.2, 1.3, 0.2],
|
||||||
@@ -444,7 +444,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_baloons() {
|
fn fit_predict_baloons() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[1.,1.,1.,0.],
|
&[1.,1.,1.,0.],
|
||||||
&[1.,1.,1.,0.],
|
&[1.,1.,1.,0.],
|
||||||
&[1.,1.,1.,1.],
|
&[1.,1.,1.,1.],
|
||||||
|
|||||||
@@ -192,19 +192,26 @@ impl<M: Matrix> LogisticRegression<M> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn predict(&self, x: &M) -> M::RowVector {
|
pub fn predict(&self, x: &M) -> M::RowVector {
|
||||||
|
let n = x.shape().0;
|
||||||
|
let mut result = M::zeros(1, n);
|
||||||
if self.num_classes == 2 {
|
if self.num_classes == 2 {
|
||||||
let (nrows, _) = x.shape();
|
let (nrows, _) = x.shape();
|
||||||
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
||||||
let y_hat: Vec<f64> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
|
let y_hat: Vec<f64> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
|
||||||
M::from_vec(1, nrows, y_hat.iter().map(|y_hat| self.classes[if y_hat.sigmoid() > 0.5 { 1 } else { 0 }]).collect()).to_row_vector()
|
for i in 0..n {
|
||||||
|
result.set(0, i, self.classes[if y_hat[i].sigmoid() > 0.5 { 1 } else { 0 }]);
|
||||||
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
let (nrows, _) = x.shape();
|
let (nrows, _) = x.shape();
|
||||||
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
||||||
let y_hat = x_and_bias.dot(&self.weights.transpose());
|
let y_hat = x_and_bias.dot(&self.weights.transpose());
|
||||||
let class_idxs = y_hat.argmax();
|
let class_idxs = y_hat.argmax();
|
||||||
M::from_vec(1, nrows, class_idxs.iter().map(|class_idx| self.classes[*class_idx]).collect()).to_row_vector()
|
for i in 0..n {
|
||||||
|
result.set(0, i, self.classes[class_idxs[i]]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
result.to_row_vector()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn coefficients(&self) -> M {
|
pub fn coefficients(&self) -> M {
|
||||||
@@ -242,7 +249,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn multiclass_objective_f() {
|
fn multiclass_objective_f() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[1., -5.],
|
&[1., -5.],
|
||||||
&[ 2., 5.],
|
&[ 2., 5.],
|
||||||
&[ 3., -2.],
|
&[ 3., -2.],
|
||||||
@@ -282,7 +289,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn binary_objective_f() {
|
fn binary_objective_f() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[1., -5.],
|
&[1., -5.],
|
||||||
&[ 2., 5.],
|
&[ 2., 5.],
|
||||||
&[ 3., -2.],
|
&[ 3., -2.],
|
||||||
@@ -323,7 +330,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn lr_fit_predict() {
|
fn lr_fit_predict() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[1., -5.],
|
&[1., -5.],
|
||||||
&[ 2., 5.],
|
&[ 2., 5.],
|
||||||
&[ 3., -2.],
|
&[ 3., -2.],
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_iris() {
|
fn fit_predict_iris() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
&[4.9, 3.0, 1.4, 0.2],
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
&[4.7, 3.2, 1.3, 0.2],
|
&[4.7, 3.2, 1.3, 0.2],
|
||||||
|
|||||||
@@ -0,0 +1,220 @@
|
|||||||
|
extern crate rand;
|
||||||
|
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
use crate::linalg::Matrix;
|
||||||
|
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||||
|
|
||||||
|
/// Result of a k-means fit: the learned centroids together with the
/// assignment of the training rows.
#[derive(Debug)]
pub struct KMeans {
    /// Number of clusters.
    k: usize,
    /// Cluster label of every training row.
    y: Vec<usize>,
    /// Number of training rows assigned to each cluster.
    size: Vec<usize>,
    /// Lowest clustering cost accepted during fitting.
    distortion: f64,
    /// One centroid per cluster, each with one coordinate per data column.
    centroids: Vec<Vec<f64>>,
}
|
||||||
|
|
||||||
|
/// Tunable settings for [`KMeans::new`].
#[derive(Debug, Clone)]
pub struct KMeansParameters {
    /// Upper bound on the number of assignment/update iterations.
    pub max_iter: usize,
}
|
||||||
|
|
||||||
|
impl Default for KMeansParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
KMeansParameters {
|
||||||
|
max_iter: 100
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl KMeans{
|
||||||
|
pub fn new<M: Matrix>(data: &M, k: usize, parameters: KMeansParameters) -> KMeans {
|
||||||
|
|
||||||
|
let bbd = BBDTree::new(data);
|
||||||
|
|
||||||
|
if k < 2 {
|
||||||
|
panic!("Invalid number of clusters: {}", k);
|
||||||
|
}
|
||||||
|
|
||||||
|
if parameters.max_iter <= 0 {
|
||||||
|
panic!("Invalid maximum number of iterations: {}", parameters.max_iter);
|
||||||
|
}
|
||||||
|
|
||||||
|
let (n, d) = data.shape();
|
||||||
|
|
||||||
|
let mut distortion = std::f64::MAX;
|
||||||
|
let mut y = KMeans::kmeans_plus_plus(data, k);
|
||||||
|
let mut size = vec![0; k];
|
||||||
|
let mut centroids = vec![vec![0f64; d]; k];
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
size[y[i]] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
for j in 0..d {
|
||||||
|
centroids[y[i]][j] += data.get(i, j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in 0..k {
|
||||||
|
for j in 0..d {
|
||||||
|
centroids[i][j] /= size[i] as f64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut sums = vec![vec![0f64; d]; k];
|
||||||
|
for _ in 1..= parameters.max_iter {
|
||||||
|
let dist = bbd.clustering(¢roids, &mut sums, &mut size, &mut y);
|
||||||
|
for i in 0..k {
|
||||||
|
if size[i] > 0 {
|
||||||
|
for j in 0..d {
|
||||||
|
centroids[i][j] = sums[i][j] as f64 / size[i] as f64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if distortion <= dist {
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
distortion = dist;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
KMeans{
|
||||||
|
k: k,
|
||||||
|
y: y,
|
||||||
|
size: size,
|
||||||
|
distortion: distortion,
|
||||||
|
centroids: centroids
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn predict<M: Matrix>(&self, x: &M) -> M::RowVector {
|
||||||
|
let (n, _) = x.shape();
|
||||||
|
let mut result = M::zeros(1, n);
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
|
||||||
|
let mut min_dist = std::f64::MAX;
|
||||||
|
let mut best_cluster = 0;
|
||||||
|
|
||||||
|
for j in 0..self.k {
|
||||||
|
let dist = KMeans::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]);
|
||||||
|
if dist < min_dist {
|
||||||
|
min_dist = dist;
|
||||||
|
best_cluster = j;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.set(0, i, best_cluster as f64);
|
||||||
|
}
|
||||||
|
|
||||||
|
result.to_row_vector()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn kmeans_plus_plus<M: Matrix>(data: &M, k: usize) -> Vec<usize>{
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
let (n, _) = data.shape();
|
||||||
|
let mut y = vec![0; n];
|
||||||
|
let mut centroid = data.get_row_as_vec(rng.gen_range(0, n));
|
||||||
|
|
||||||
|
let mut d = vec![std::f64::MAX; n];
|
||||||
|
|
||||||
|
// pick the next center
|
||||||
|
for j in 1..k {
|
||||||
|
// Loop over the samples and compare them to the most recent center. Store
|
||||||
|
// the distance from each sample to its closest center in scores.
|
||||||
|
for i in 0..n {
|
||||||
|
// compute the distance between this sample and the current center
|
||||||
|
let dist = KMeans::squared_distance(&data.get_row_as_vec(i), ¢roid);
|
||||||
|
|
||||||
|
if dist < d[i] {
|
||||||
|
d[i] = dist;
|
||||||
|
y[i] = j - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let sum: f64 = d.iter().sum();
|
||||||
|
let cutoff = rng.gen::<f64>() * sum;
|
||||||
|
let mut cost = 0f64;
|
||||||
|
let index = 0;
|
||||||
|
for index in 0..n {
|
||||||
|
cost += d[index];
|
||||||
|
if cost >= cutoff {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
centroid = data.get_row_as_vec(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
// compute the distance between this sample and the current center
|
||||||
|
let dist = KMeans::squared_distance(&data.get_row_as_vec(i), ¢roid);
|
||||||
|
|
||||||
|
if dist < d[i] {
|
||||||
|
d[i] = dist;
|
||||||
|
y[i] = k - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
y
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the squared Euclidean distance between `x` and `y`.
///
/// # Panics
///
/// Panics if the two inputs differ in length.
fn squared_distance(x: &[f64], y: &[f64]) -> f64 {
    if x.len() != y.len() {
        panic!("Input vector sizes are different.");
    }

    x.iter().zip(y.iter()).fold(0f64, |acc, (a, b)| {
        let diff = a - b;
        acc + diff * diff
    })
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::naive::dense_matrix::DenseMatrix;

    /// Fits two clusters on 20 iris rows and checks that predicting the
    /// training rows reproduces the labels stored by the fit.
    #[test]
    fn fit_predict_iris() {
        let x = DenseMatrix::from_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[5.7, 2.8, 4.5, 1.3],
            &[6.3, 3.3, 4.7, 1.6],
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4]]);

        let kmeans = KMeans::new(&x, 2, Default::default());

        let y = kmeans.predict(&x);

        for (row, label) in y.iter().enumerate() {
            assert_eq!(*label as usize, kmeans.y[row]);
        }
    }
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
pub mod kmeans;
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
pub mod classification;
|
pub mod classification;
|
||||||
pub mod regression;
|
pub mod regression;
|
||||||
|
pub mod cluster;
|
||||||
pub mod linalg;
|
pub mod linalg;
|
||||||
pub mod math;
|
pub mod math;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
|
|||||||
@@ -12,10 +12,6 @@ pub trait Matrix: Clone + Debug {
|
|||||||
|
|
||||||
fn to_row_vector(self) -> Self::RowVector;
|
fn to_row_vector(self) -> Self::RowVector;
|
||||||
|
|
||||||
fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> Self;
|
|
||||||
|
|
||||||
fn from_vec(nrows: usize, ncols: usize, values: Vec<f64>) -> Self;
|
|
||||||
|
|
||||||
fn get(&self, row: usize, col: usize) -> f64;
|
fn get(&self, row: usize, col: usize) -> f64;
|
||||||
|
|
||||||
fn get_row_as_vec(&self, row: usize) -> Vec<f64>;
|
fn get_row_as_vec(&self, row: usize) -> Vec<f64>;
|
||||||
|
|||||||
@@ -14,11 +14,19 @@ pub struct DenseMatrix {
|
|||||||
|
|
||||||
impl DenseMatrix {
|
impl DenseMatrix {
|
||||||
|
|
||||||
pub fn from_2d_array(values: &[&[f64]]) -> DenseMatrix {
|
fn new(nrows: usize, ncols: usize, values: Vec<f64>) -> DenseMatrix {
|
||||||
DenseMatrix::from_2d_vec(&values.into_iter().map(|row| Vec::from(*row)).collect())
|
DenseMatrix {
|
||||||
|
ncols: ncols,
|
||||||
|
nrows: nrows,
|
||||||
|
values: values
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_2d_vec(values: &Vec<Vec<f64>>) -> DenseMatrix {
|
pub fn from_array(values: &[&[f64]]) -> DenseMatrix {
|
||||||
|
DenseMatrix::from_vec(&values.into_iter().map(|row| Vec::from(*row)).collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_vec(values: &Vec<Vec<f64>>) -> DenseMatrix {
|
||||||
let nrows = values.len();
|
let nrows = values.len();
|
||||||
let ncols = values.first().unwrap_or_else(|| panic!("Cannot create 2d matrix from an empty vector")).len();
|
let ncols = values.first().unwrap_or_else(|| panic!("Cannot create 2d matrix from an empty vector")).len();
|
||||||
let mut m = DenseMatrix {
|
let mut m = DenseMatrix {
|
||||||
@@ -112,25 +120,13 @@ impl Matrix for DenseMatrix {
|
|||||||
type RowVector = Vec<f64>;
|
type RowVector = Vec<f64>;
|
||||||
|
|
||||||
fn from_row_vector(vec: Self::RowVector) -> Self{
|
fn from_row_vector(vec: Self::RowVector) -> Self{
|
||||||
DenseMatrix::from_vec(1, vec.len(), vec)
|
DenseMatrix::new(1, vec.len(), vec)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_row_vector(self) -> Self::RowVector{
|
fn to_row_vector(self) -> Self::RowVector{
|
||||||
self.to_raw_vector()
|
self.to_raw_vector()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> DenseMatrix {
|
|
||||||
DenseMatrix::from_vec(nrows, ncols, Vec::from(values))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn from_vec(nrows: usize, ncols: usize, values: Vec<f64>) -> DenseMatrix {
|
|
||||||
DenseMatrix {
|
|
||||||
ncols: ncols,
|
|
||||||
nrows: nrows,
|
|
||||||
values: values
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get(&self, row: usize, col: usize) -> f64 {
|
fn get(&self, row: usize, col: usize) -> f64 {
|
||||||
self.values[col*self.nrows + row]
|
self.values[col*self.nrows + row]
|
||||||
}
|
}
|
||||||
@@ -255,7 +251,7 @@ impl Matrix for DenseMatrix {
|
|||||||
let ncols = cols.len();
|
let ncols = cols.len();
|
||||||
let nrows = rows.len();
|
let nrows = rows.len();
|
||||||
|
|
||||||
let mut m = DenseMatrix::from_vec(nrows, ncols, vec![0f64; nrows * ncols]);
|
let mut m = DenseMatrix::new(nrows, ncols, vec![0f64; nrows * ncols]);
|
||||||
|
|
||||||
for r in rows.start..rows.end {
|
for r in rows.start..rows.end {
|
||||||
for c in cols.start..cols.end {
|
for c in cols.start..cols.end {
|
||||||
@@ -731,7 +727,7 @@ impl Matrix for DenseMatrix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn fill(nrows: usize, ncols: usize, value: f64) -> Self {
|
fn fill(nrows: usize, ncols: usize, value: f64) -> Self {
|
||||||
DenseMatrix::from_vec(nrows, ncols, vec![value; ncols * nrows])
|
DenseMatrix::new(nrows, ncols, vec![value; ncols * nrows])
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_mut(&mut self, other: &Self) -> &Self {
|
fn add_mut(&mut self, other: &Self) -> &Self {
|
||||||
@@ -998,7 +994,7 @@ mod tests {
|
|||||||
fn from_to_row_vec() {
|
fn from_to_row_vec() {
|
||||||
|
|
||||||
let vec = vec![ 1., 2., 3.];
|
let vec = vec![ 1., 2., 3.];
|
||||||
assert_eq!(DenseMatrix::from_row_vector(vec.clone()), DenseMatrix::from_vec(1, 3, vec![1., 2., 3.]));
|
assert_eq!(DenseMatrix::from_row_vector(vec.clone()), DenseMatrix::new(1, 3, vec![1., 2., 3.]));
|
||||||
assert_eq!(DenseMatrix::from_row_vector(vec.clone()).to_row_vector(), vec![1., 2., 3.]);
|
assert_eq!(DenseMatrix::from_row_vector(vec.clone()).to_row_vector(), vec![1., 2., 3.]);
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1006,9 +1002,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn qr_solve_mut() {
|
fn qr_solve_mut() {
|
||||||
|
|
||||||
let mut a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let mut a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
|
let b = DenseMatrix::from_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
|
||||||
let expected_w = DenseMatrix::from_array(3, 2, &[-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]);
|
let expected_w = DenseMatrix::new(3, 2, vec![-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]);
|
||||||
let w = a.qr_solve_mut(b);
|
let w = a.qr_solve_mut(b);
|
||||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||||
}
|
}
|
||||||
@@ -1016,9 +1012,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn svd_solve_mut() {
|
fn svd_solve_mut() {
|
||||||
|
|
||||||
let mut a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let mut a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
|
let b = DenseMatrix::from_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
|
||||||
let expected_w = DenseMatrix::from_array(3, 2, &[-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]);
|
let expected_w = DenseMatrix::new(3, 2, vec![-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]);
|
||||||
let w = a.svd_solve_mut(b);
|
let w = a.svd_solve_mut(b);
|
||||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||||
}
|
}
|
||||||
@@ -1026,16 +1022,16 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn h_stack() {
|
fn h_stack() {
|
||||||
|
|
||||||
let a = DenseMatrix::from_2d_array(
|
let a = DenseMatrix::from_array(
|
||||||
&[
|
&[
|
||||||
&[1., 2., 3.],
|
&[1., 2., 3.],
|
||||||
&[4., 5., 6.],
|
&[4., 5., 6.],
|
||||||
&[7., 8., 9.]]);
|
&[7., 8., 9.]]);
|
||||||
let b = DenseMatrix::from_2d_array(
|
let b = DenseMatrix::from_array(
|
||||||
&[
|
&[
|
||||||
&[1., 2., 3.],
|
&[1., 2., 3.],
|
||||||
&[4., 5., 6.]]);
|
&[4., 5., 6.]]);
|
||||||
let expected = DenseMatrix::from_2d_array(
|
let expected = DenseMatrix::from_array(
|
||||||
&[
|
&[
|
||||||
&[1., 2., 3.],
|
&[1., 2., 3.],
|
||||||
&[4., 5., 6.],
|
&[4., 5., 6.],
|
||||||
@@ -1049,17 +1045,17 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn v_stack() {
|
fn v_stack() {
|
||||||
|
|
||||||
let a = DenseMatrix::from_2d_array(
|
let a = DenseMatrix::from_array(
|
||||||
&[
|
&[
|
||||||
&[1., 2., 3.],
|
&[1., 2., 3.],
|
||||||
&[4., 5., 6.],
|
&[4., 5., 6.],
|
||||||
&[7., 8., 9.]]);
|
&[7., 8., 9.]]);
|
||||||
let b = DenseMatrix::from_2d_array(
|
let b = DenseMatrix::from_array(
|
||||||
&[
|
&[
|
||||||
&[1., 2.],
|
&[1., 2.],
|
||||||
&[3., 4.],
|
&[3., 4.],
|
||||||
&[5., 6.]]);
|
&[5., 6.]]);
|
||||||
let expected = DenseMatrix::from_2d_array(
|
let expected = DenseMatrix::from_array(
|
||||||
&[
|
&[
|
||||||
&[1., 2., 3., 1., 2.],
|
&[1., 2., 3., 1., 2.],
|
||||||
&[4., 5., 6., 3., 4.],
|
&[4., 5., 6., 3., 4.],
|
||||||
@@ -1071,16 +1067,16 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn dot() {
|
fn dot() {
|
||||||
|
|
||||||
let a = DenseMatrix::from_2d_array(
|
let a = DenseMatrix::from_array(
|
||||||
&[
|
&[
|
||||||
&[1., 2., 3.],
|
&[1., 2., 3.],
|
||||||
&[4., 5., 6.]]);
|
&[4., 5., 6.]]);
|
||||||
let b = DenseMatrix::from_2d_array(
|
let b = DenseMatrix::from_array(
|
||||||
&[
|
&[
|
||||||
&[1., 2.],
|
&[1., 2.],
|
||||||
&[3., 4.],
|
&[3., 4.],
|
||||||
&[5., 6.]]);
|
&[5., 6.]]);
|
||||||
let expected = DenseMatrix::from_2d_array(
|
let expected = DenseMatrix::from_array(
|
||||||
&[
|
&[
|
||||||
&[22., 28.],
|
&[22., 28.],
|
||||||
&[49., 64.]]);
|
&[49., 64.]]);
|
||||||
@@ -1091,12 +1087,12 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn slice() {
|
fn slice() {
|
||||||
|
|
||||||
let m = DenseMatrix::from_2d_array(
|
let m = DenseMatrix::from_array(
|
||||||
&[
|
&[
|
||||||
&[1., 2., 3., 1., 2.],
|
&[1., 2., 3., 1., 2.],
|
||||||
&[4., 5., 6., 3., 4.],
|
&[4., 5., 6., 3., 4.],
|
||||||
&[7., 8., 9., 5., 6.]]);
|
&[7., 8., 9., 5., 6.]]);
|
||||||
let expected = DenseMatrix::from_2d_array(
|
let expected = DenseMatrix::from_array(
|
||||||
&[
|
&[
|
||||||
&[2., 3.],
|
&[2., 3.],
|
||||||
&[5., 6.]]);
|
&[5., 6.]]);
|
||||||
@@ -1107,15 +1103,15 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn approximate_eq() {
|
fn approximate_eq() {
|
||||||
let m = DenseMatrix::from_2d_array(
|
let m = DenseMatrix::from_array(
|
||||||
&[
|
&[
|
||||||
&[2., 3.],
|
&[2., 3.],
|
||||||
&[5., 6.]]);
|
&[5., 6.]]);
|
||||||
let m_eq = DenseMatrix::from_2d_array(
|
let m_eq = DenseMatrix::from_array(
|
||||||
&[
|
&[
|
||||||
&[2.5, 3.0],
|
&[2.5, 3.0],
|
||||||
&[5., 5.5]]);
|
&[5., 5.5]]);
|
||||||
let m_neq = DenseMatrix::from_2d_array(
|
let m_neq = DenseMatrix::from_array(
|
||||||
&[
|
&[
|
||||||
&[3.0, 3.0],
|
&[3.0, 3.0],
|
||||||
&[5., 6.5]]);
|
&[5., 6.5]]);
|
||||||
@@ -1135,8 +1131,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn transpose() {
|
fn transpose() {
|
||||||
let m = DenseMatrix::from_2d_array(&[&[1.0, 3.0], &[2.0, 4.0]]);
|
let m = DenseMatrix::from_array(&[&[1.0, 3.0], &[2.0, 4.0]]);
|
||||||
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]);
|
let expected = DenseMatrix::from_array(&[&[1.0, 2.0], &[3.0, 4.0]]);
|
||||||
let m_transposed = m.transpose();
|
let m_transposed = m.transpose();
|
||||||
for c in 0..2 {
|
for c in 0..2 {
|
||||||
for r in 0..2 {
|
for r in 0..2 {
|
||||||
|
|||||||
@@ -16,14 +16,6 @@ impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
|||||||
self.into_shape(vec_size).unwrap()
|
self.into_shape(vec_size).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> Self {
|
|
||||||
Array::from_shape_vec((nrows, ncols), values.to_vec()).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn from_vec(nrows: usize, ncols: usize, values: Vec<f64>) -> Self {
|
|
||||||
Array::from_shape_vec((nrows, ncols), values).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get(&self, row: usize, col: usize) -> f64 {
|
fn get(&self, row: usize, col: usize) -> f64 {
|
||||||
self[[row, col]]
|
self[[row, col]]
|
||||||
}
|
}
|
||||||
@@ -330,19 +322,6 @@ mod tests {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn from_array_from_vec() {
|
|
||||||
|
|
||||||
let a1 = arr2(&[[ 1., 2., 3.],
|
|
||||||
[4., 5., 6.]]);
|
|
||||||
let a2 = Array2::from_array(2, 3, &[1., 2., 3., 4., 5., 6.]);
|
|
||||||
let a3 = Array2::from_vec(2, 3, vec![1., 2., 3., 4., 5., 6.]);
|
|
||||||
|
|
||||||
assert_eq!(a1, a2);
|
|
||||||
assert_eq!(a1, a3);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn vstack_hstack() {
|
fn vstack_hstack() {
|
||||||
|
|
||||||
|
|||||||
@@ -19,14 +19,14 @@ impl<M: Matrix> LinearRegression<M> {
|
|||||||
|
|
||||||
pub fn fit(x: &M, y: &M, solver: LinearRegressionSolver) -> LinearRegression<M>{
|
pub fn fit(x: &M, y: &M, solver: LinearRegressionSolver) -> LinearRegression<M>{
|
||||||
|
|
||||||
|
let b = y.transpose();
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let (y_nrows, _) = y.shape();
|
let (y_nrows, _) = b.shape();
|
||||||
|
|
||||||
if x_nrows != y_nrows {
|
if x_nrows != y_nrows {
|
||||||
panic!("Number of rows of X doesn't match number of rows of Y");
|
panic!("Number of rows of X doesn't match number of rows of Y");
|
||||||
}
|
}
|
||||||
|
|
||||||
let b = y.clone();
|
|
||||||
let mut a = x.v_stack(&M::ones(x_nrows, 1));
|
let mut a = x.v_stack(&M::ones(x_nrows, 1));
|
||||||
|
|
||||||
let w = match solver {
|
let w = match solver {
|
||||||
@@ -52,7 +52,7 @@ impl<M: Matrix> Regression<M> for LinearRegression<M> {
|
|||||||
let (nrows, _) = x.shape();
|
let (nrows, _) = x.shape();
|
||||||
let mut y_hat = x.dot(&self.coefficients);
|
let mut y_hat = x.dot(&self.coefficients);
|
||||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||||
y_hat
|
y_hat.transpose()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -65,7 +65,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn ols_fit_predict() {
|
fn ols_fit_predict() {
|
||||||
|
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||||
@@ -82,7 +82,8 @@ mod tests {
|
|||||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
||||||
let y = DenseMatrix::from_array(16, 1, &[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]);
|
|
||||||
|
let y = DenseMatrix::from_array(&[&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]]);
|
||||||
|
|
||||||
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
|
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user