feat: adds KMeans clustering algorithm

This commit is contained in:
Volodymyr Orlov
2020-02-20 18:43:24 -08:00
parent 4359d66bfa
commit 0e89113297
13 changed files with 637 additions and 84 deletions
+39 -43
View File
@@ -12,13 +12,21 @@ pub struct DenseMatrix {
}
impl DenseMatrix {
impl DenseMatrix {
fn new(nrows: usize, ncols: usize, values: Vec<f64>) -> DenseMatrix {
DenseMatrix {
ncols: ncols,
nrows: nrows,
values: values
}
}
pub fn from_2d_array(values: &[&[f64]]) -> DenseMatrix {
DenseMatrix::from_2d_vec(&values.into_iter().map(|row| Vec::from(*row)).collect())
pub fn from_array(values: &[&[f64]]) -> DenseMatrix {
DenseMatrix::from_vec(&values.into_iter().map(|row| Vec::from(*row)).collect())
}
pub fn from_2d_vec(values: &Vec<Vec<f64>>) -> DenseMatrix {
pub fn from_vec(values: &Vec<Vec<f64>>) -> DenseMatrix {
let nrows = values.len();
let ncols = values.first().unwrap_or_else(|| panic!("Cannot create 2d matrix from an empty vector")).len();
let mut m = DenseMatrix {
@@ -112,24 +120,12 @@ impl Matrix for DenseMatrix {
type RowVector = Vec<f64>;
fn from_row_vector(vec: Self::RowVector) -> Self{
DenseMatrix::from_vec(1, vec.len(), vec)
DenseMatrix::new(1, vec.len(), vec)
}
fn to_row_vector(self) -> Self::RowVector{
self.to_raw_vector()
}
fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> DenseMatrix {
DenseMatrix::from_vec(nrows, ncols, Vec::from(values))
}
fn from_vec(nrows: usize, ncols: usize, values: Vec<f64>) -> DenseMatrix {
DenseMatrix {
ncols: ncols,
nrows: nrows,
values: values
}
}
}
fn get(&self, row: usize, col: usize) -> f64 {
self.values[col*self.nrows + row]
@@ -255,7 +251,7 @@ impl Matrix for DenseMatrix {
let ncols = cols.len();
let nrows = rows.len();
let mut m = DenseMatrix::from_vec(nrows, ncols, vec![0f64; nrows * ncols]);
let mut m = DenseMatrix::new(nrows, ncols, vec![0f64; nrows * ncols]);
for r in rows.start..rows.end {
for c in cols.start..cols.end {
@@ -731,7 +727,7 @@ impl Matrix for DenseMatrix {
}
fn fill(nrows: usize, ncols: usize, value: f64) -> Self {
DenseMatrix::from_vec(nrows, ncols, vec![value; ncols * nrows])
DenseMatrix::new(nrows, ncols, vec![value; ncols * nrows])
}
fn add_mut(&mut self, other: &Self) -> &Self {
@@ -998,7 +994,7 @@ mod tests {
fn from_to_row_vec() {
let vec = vec![ 1., 2., 3.];
assert_eq!(DenseMatrix::from_row_vector(vec.clone()), DenseMatrix::from_vec(1, 3, vec![1., 2., 3.]));
assert_eq!(DenseMatrix::from_row_vector(vec.clone()), DenseMatrix::new(1, 3, vec![1., 2., 3.]));
assert_eq!(DenseMatrix::from_row_vector(vec.clone()).to_row_vector(), vec![1., 2., 3.]);
}
@@ -1006,9 +1002,9 @@ mod tests {
#[test]
fn qr_solve_mut() {
let mut a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
let expected_w = DenseMatrix::from_array(3, 2, &[-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]);
let mut a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let b = DenseMatrix::from_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
let expected_w = DenseMatrix::new(3, 2, vec![-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]);
let w = a.qr_solve_mut(b);
assert!(w.approximate_eq(&expected_w, 1e-2));
}
@@ -1016,9 +1012,9 @@ mod tests {
#[test]
fn svd_solve_mut() {
let mut a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
let expected_w = DenseMatrix::from_array(3, 2, &[-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]);
let mut a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let b = DenseMatrix::from_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
let expected_w = DenseMatrix::new(3, 2, vec![-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]);
let w = a.svd_solve_mut(b);
assert!(w.approximate_eq(&expected_w, 1e-2));
}
@@ -1026,16 +1022,16 @@ mod tests {
#[test]
fn h_stack() {
let a = DenseMatrix::from_2d_array(
let a = DenseMatrix::from_array(
&[
&[1., 2., 3.],
&[4., 5., 6.],
&[7., 8., 9.]]);
let b = DenseMatrix::from_2d_array(
let b = DenseMatrix::from_array(
&[
&[1., 2., 3.],
&[4., 5., 6.]]);
let expected = DenseMatrix::from_2d_array(
let expected = DenseMatrix::from_array(
&[
&[1., 2., 3.],
&[4., 5., 6.],
@@ -1049,17 +1045,17 @@ mod tests {
#[test]
fn v_stack() {
let a = DenseMatrix::from_2d_array(
let a = DenseMatrix::from_array(
&[
&[1., 2., 3.],
&[4., 5., 6.],
&[7., 8., 9.]]);
let b = DenseMatrix::from_2d_array(
let b = DenseMatrix::from_array(
&[
&[1., 2.],
&[3., 4.],
&[5., 6.]]);
let expected = DenseMatrix::from_2d_array(
let expected = DenseMatrix::from_array(
&[
&[1., 2., 3., 1., 2.],
&[4., 5., 6., 3., 4.],
@@ -1071,16 +1067,16 @@ mod tests {
#[test]
fn dot() {
let a = DenseMatrix::from_2d_array(
let a = DenseMatrix::from_array(
&[
&[1., 2., 3.],
&[4., 5., 6.]]);
let b = DenseMatrix::from_2d_array(
let b = DenseMatrix::from_array(
&[
&[1., 2.],
&[3., 4.],
&[5., 6.]]);
let expected = DenseMatrix::from_2d_array(
let expected = DenseMatrix::from_array(
&[
&[22., 28.],
&[49., 64.]]);
@@ -1091,12 +1087,12 @@ mod tests {
#[test]
fn slice() {
let m = DenseMatrix::from_2d_array(
let m = DenseMatrix::from_array(
&[
&[1., 2., 3., 1., 2.],
&[4., 5., 6., 3., 4.],
&[7., 8., 9., 5., 6.]]);
let expected = DenseMatrix::from_2d_array(
let expected = DenseMatrix::from_array(
&[
&[2., 3.],
&[5., 6.]]);
@@ -1107,15 +1103,15 @@ mod tests {
#[test]
fn approximate_eq() {
let m = DenseMatrix::from_2d_array(
let m = DenseMatrix::from_array(
&[
&[2., 3.],
&[5., 6.]]);
let m_eq = DenseMatrix::from_2d_array(
let m_eq = DenseMatrix::from_array(
&[
&[2.5, 3.0],
&[5., 5.5]]);
let m_neq = DenseMatrix::from_2d_array(
let m_neq = DenseMatrix::from_array(
&[
&[3.0, 3.0],
&[5., 6.5]]);
@@ -1135,8 +1131,8 @@ mod tests {
#[test]
fn transpose() {
let m = DenseMatrix::from_2d_array(&[&[1.0, 3.0], &[2.0, 4.0]]);
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]);
let m = DenseMatrix::from_array(&[&[1.0, 3.0], &[2.0, 4.0]]);
let expected = DenseMatrix::from_array(&[&[1.0, 2.0], &[3.0, 4.0]]);
let m_transposed = m.transpose();
for c in 0..2 {
for r in 0..2 {