fix: cargo fmt
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
pub mod neighbour;
|
||||
pub mod sort;
|
||||
pub mod neighbour;
|
||||
@@ -1,18 +1,18 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian::*;
|
||||
use crate::math::num::FloatExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BBDTree<T: FloatExt> {
|
||||
pub struct BBDTree<T: FloatExt> {
|
||||
nodes: Vec<BBDTreeNode<T>>,
|
||||
index: Vec<usize>,
|
||||
root: usize
|
||||
root: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct BBDTreeNode<T: FloatExt> {
|
||||
struct BBDTreeNode<T: FloatExt> {
|
||||
count: usize,
|
||||
index: usize,
|
||||
center: Vec<T>,
|
||||
@@ -20,7 +20,7 @@ struct BBDTreeNode<T: FloatExt> {
|
||||
sum: Vec<T>,
|
||||
cost: T,
|
||||
lower: Option<usize>,
|
||||
upper: Option<usize>
|
||||
upper: Option<usize>,
|
||||
}
|
||||
|
||||
impl<T: FloatExt> BBDTreeNode<T> {
|
||||
@@ -33,7 +33,7 @@ impl<T: FloatExt> BBDTreeNode<T> {
|
||||
sum: vec![T::zero(); d],
|
||||
cost: T::zero(),
|
||||
lower: Option::None,
|
||||
upper: Option::None
|
||||
upper: Option::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -49,10 +49,10 @@ impl<T: FloatExt> BBDTree<T> {
|
||||
index[i] = i;
|
||||
}
|
||||
|
||||
let mut tree = BBDTree{
|
||||
let mut tree = BBDTree {
|
||||
nodes: nodes,
|
||||
index: index,
|
||||
root: 0
|
||||
root: 0,
|
||||
};
|
||||
|
||||
let root = tree.build_node(data, 0, n);
|
||||
@@ -60,29 +60,54 @@ impl<T: FloatExt> BBDTree<T> {
|
||||
tree.root = root;
|
||||
|
||||
tree
|
||||
}
|
||||
}
|
||||
|
||||
pub(in crate) fn clustering(&self, centroids: &Vec<Vec<T>>, sums: &mut Vec<Vec<T>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> T {
|
||||
pub(in crate) fn clustering(
|
||||
&self,
|
||||
centroids: &Vec<Vec<T>>,
|
||||
sums: &mut Vec<Vec<T>>,
|
||||
counts: &mut Vec<usize>,
|
||||
membership: &mut Vec<usize>,
|
||||
) -> T {
|
||||
let k = centroids.len();
|
||||
|
||||
counts.iter_mut().for_each(|x| *x = 0);
|
||||
let mut candidates = vec![0; k];
|
||||
for i in 0..k {
|
||||
candidates[i] = i;
|
||||
sums[i].iter_mut().for_each(|x| *x = T::zero());
|
||||
sums[i].iter_mut().for_each(|x| *x = T::zero());
|
||||
}
|
||||
|
||||
self.filter(self.root, centroids, &candidates, k, sums, counts, membership)
|
||||
|
||||
self.filter(
|
||||
self.root,
|
||||
centroids,
|
||||
&candidates,
|
||||
k,
|
||||
sums,
|
||||
counts,
|
||||
membership,
|
||||
)
|
||||
}
|
||||
|
||||
fn filter(&self, node: usize, centroids: &Vec<Vec<T>>, candidates: &Vec<usize>, k: usize, sums: &mut Vec<Vec<T>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> T{
|
||||
fn filter(
|
||||
&self,
|
||||
node: usize,
|
||||
centroids: &Vec<Vec<T>>,
|
||||
candidates: &Vec<usize>,
|
||||
k: usize,
|
||||
sums: &mut Vec<Vec<T>>,
|
||||
counts: &mut Vec<usize>,
|
||||
membership: &mut Vec<usize>,
|
||||
) -> T {
|
||||
let d = centroids[0].len();
|
||||
|
||||
// Determine which mean the node mean is closest to
|
||||
let mut min_dist = Euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[0]]);
|
||||
let mut min_dist =
|
||||
Euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[0]]);
|
||||
let mut closest = candidates[0];
|
||||
for i in 1..k {
|
||||
let dist = Euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[i]]);
|
||||
let dist =
|
||||
Euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[i]]);
|
||||
if dist < min_dist {
|
||||
min_dist = dist;
|
||||
closest = candidates[i];
|
||||
@@ -92,11 +117,17 @@ impl<T: FloatExt> BBDTree<T> {
|
||||
// If this is a non-leaf node, recurse if necessary
|
||||
if !self.nodes[node].lower.is_none() {
|
||||
// Build the new list of candidates
|
||||
let mut new_candidates = vec![0;k];
|
||||
let mut new_candidates = vec![0; k];
|
||||
let mut newk = 0;
|
||||
|
||||
for i in 0..k {
|
||||
if !BBDTree::prune(&self.nodes[node].center, &self.nodes[node].radius, &centroids, closest, candidates[i]) {
|
||||
if !BBDTree::prune(
|
||||
&self.nodes[node].center,
|
||||
&self.nodes[node].radius,
|
||||
&centroids,
|
||||
closest,
|
||||
candidates[i],
|
||||
) {
|
||||
new_candidates[newk] = candidates[i];
|
||||
newk += 1;
|
||||
}
|
||||
@@ -104,8 +135,23 @@ impl<T: FloatExt> BBDTree<T> {
|
||||
|
||||
// Recurse if there's at least two
|
||||
if newk > 1 {
|
||||
let result = self.filter(self.nodes[node].lower.unwrap(), centroids, &mut new_candidates, newk, sums, counts, membership) +
|
||||
self.filter(self.nodes[node].upper.unwrap(), centroids, &mut new_candidates, newk, sums, counts, membership);
|
||||
let result = self.filter(
|
||||
self.nodes[node].lower.unwrap(),
|
||||
centroids,
|
||||
&mut new_candidates,
|
||||
newk,
|
||||
sums,
|
||||
counts,
|
||||
membership,
|
||||
) + self.filter(
|
||||
self.nodes[node].upper.unwrap(),
|
||||
centroids,
|
||||
&mut new_candidates,
|
||||
newk,
|
||||
sums,
|
||||
counts,
|
||||
membership,
|
||||
);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -116,17 +162,22 @@ impl<T: FloatExt> BBDTree<T> {
|
||||
}
|
||||
|
||||
counts[closest] += self.nodes[node].count;
|
||||
|
||||
|
||||
let last = self.nodes[node].index + self.nodes[node].count;
|
||||
for i in self.nodes[node].index..last {
|
||||
membership[self.index[i]] = closest;
|
||||
}
|
||||
}
|
||||
|
||||
BBDTree::node_cost(&self.nodes[node], &centroids[closest])
|
||||
|
||||
}
|
||||
|
||||
fn prune(center: &Vec<T>, radius: &Vec<T>, centroids: &Vec<Vec<T>>, best_index: usize, test_index: usize) -> bool {
|
||||
fn prune(
|
||||
center: &Vec<T>,
|
||||
radius: &Vec<T>,
|
||||
centroids: &Vec<Vec<T>>,
|
||||
best_index: usize,
|
||||
test_index: usize,
|
||||
) -> bool {
|
||||
if best_index == test_index {
|
||||
return false;
|
||||
}
|
||||
@@ -148,7 +199,7 @@ impl<T: FloatExt> BBDTree<T> {
|
||||
}
|
||||
|
||||
return lhs >= T::two() * rhs;
|
||||
}
|
||||
}
|
||||
|
||||
fn build_node<M: Matrix<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize {
|
||||
let (_, d) = data.shape();
|
||||
@@ -165,8 +216,8 @@ impl<T: FloatExt> BBDTree<T> {
|
||||
let mut upper_bound = vec![T::zero(); d];
|
||||
|
||||
for i in 0..d {
|
||||
lower_bound[i] = data.get(self.index[begin],i);
|
||||
upper_bound[i] = data.get(self.index[begin],i);
|
||||
lower_bound[i] = data.get(self.index[begin], i);
|
||||
upper_bound[i] = data.get(self.index[begin], i);
|
||||
}
|
||||
|
||||
for i in begin..end {
|
||||
@@ -200,7 +251,7 @@ impl<T: FloatExt> BBDTree<T> {
|
||||
for i in 0..d {
|
||||
node.sum[i] = data.get(self.index[begin], i);
|
||||
}
|
||||
|
||||
|
||||
if end > begin + 1 {
|
||||
let len = end - begin;
|
||||
for i in 0..d {
|
||||
@@ -247,7 +298,8 @@ impl<T: FloatExt> BBDTree<T> {
|
||||
|
||||
// Calculate the new sum and opt cost
|
||||
for i in 0..d {
|
||||
node.sum[i] = self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i];
|
||||
node.sum[i] =
|
||||
self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i];
|
||||
}
|
||||
|
||||
let mut mean = vec![T::zero(); d];
|
||||
@@ -255,7 +307,8 @@ impl<T: FloatExt> BBDTree<T> {
|
||||
mean[i] = node.sum[i] / T::from(node.count).unwrap();
|
||||
}
|
||||
|
||||
node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean) + BBDTree::node_cost(&self.nodes[node.upper.unwrap()], &mean);
|
||||
node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean)
|
||||
+ BBDTree::node_cost(&self.nodes[node.upper.unwrap()], &mean);
|
||||
|
||||
self.add_node(node)
|
||||
}
|
||||
@@ -270,7 +323,7 @@ impl<T: FloatExt> BBDTree<T> {
|
||||
node.cost + T::from(node.count).unwrap() * scatter
|
||||
}
|
||||
|
||||
fn add_node(&mut self, new_node: BBDTreeNode<T>) -> usize{
|
||||
fn add_node(&mut self, new_node: BBDTreeNode<T>) -> usize {
|
||||
let idx = self.nodes.len();
|
||||
self.nodes.push(new_node);
|
||||
idx
|
||||
@@ -279,12 +332,11 @@ impl<T: FloatExt> BBDTree<T> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn fit_predict_iris() {
|
||||
|
||||
fn fit_predict_iris() {
|
||||
let data = DenseMatrix::from_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
@@ -305,30 +357,23 @@ mod tests {
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4]]);
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let tree = BBDTree::new(&data);
|
||||
let tree = BBDTree::new(&data);
|
||||
|
||||
let centroids = vec![
|
||||
vec![4.86, 3.22, 1.61, 0.29],
|
||||
vec![6.23, 2.92, 4.48, 1.42]
|
||||
];
|
||||
let centroids = vec![vec![4.86, 3.22, 1.61, 0.29], vec![6.23, 2.92, 4.48, 1.42]];
|
||||
|
||||
let mut sums = vec![
|
||||
vec![0f64; 4],
|
||||
vec![0f64; 4]
|
||||
];
|
||||
let mut sums = vec![vec![0f64; 4], vec![0f64; 4]];
|
||||
|
||||
let mut counts = vec![11, 9];
|
||||
|
||||
let mut membership = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1];
|
||||
|
||||
let dist = tree.clustering(&centroids, &mut sums, &mut counts, &mut membership);
|
||||
let dist = tree.clustering(&centroids, &mut sums, &mut counts, &mut membership);
|
||||
assert!((dist - 10.68).abs() < 1e-2);
|
||||
assert!((sums[0][0] - 48.6).abs() < 1e-2);
|
||||
assert!((sums[1][3] - 13.8).abs() < 1e-2);
|
||||
assert_eq!(membership[17], 1);
|
||||
|
||||
assert_eq!(membership[17], 1);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,116 +1,126 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::iter::FromIterator;
|
||||
use std::fmt::Debug;
|
||||
use core::hash::{Hash, Hasher};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fmt::Debug;
|
||||
use std::iter::FromIterator;
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::FloatExt;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct CoverTree<T, F: FloatExt, D: Distance<T, F>>
|
||||
{
|
||||
pub struct CoverTree<T, F: FloatExt, D: Distance<T, F>> {
|
||||
base: F,
|
||||
max_level: i8,
|
||||
min_level: i8,
|
||||
distance: D,
|
||||
nodes: Vec<Node<T>>
|
||||
nodes: Vec<Node<T>>,
|
||||
}
|
||||
|
||||
impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
{
|
||||
|
||||
impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D> {
|
||||
pub fn new(mut data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
|
||||
let mut tree = CoverTree {
|
||||
base: F::two(),
|
||||
max_level: 100,
|
||||
min_level: 100,
|
||||
distance: distance,
|
||||
nodes: Vec::new()
|
||||
nodes: Vec::new(),
|
||||
};
|
||||
|
||||
let p = tree.new_node(None, data.remove(0));
|
||||
let p = tree.new_node(None, data.remove(0));
|
||||
tree.construct(p, data, Vec::new(), 10);
|
||||
|
||||
tree
|
||||
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, p: T) {
|
||||
if self.nodes.is_empty(){
|
||||
if self.nodes.is_empty() {
|
||||
self.new_node(None, p);
|
||||
} else {
|
||||
} else {
|
||||
let mut parent: Option<NodeId> = Option::None;
|
||||
let mut p_i = 0;
|
||||
let mut qi_p_ds = vec!((self.root(), self.distance.distance(&p, &self.root().data)));
|
||||
let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))];
|
||||
let mut i = self.max_level;
|
||||
loop {
|
||||
loop {
|
||||
let i_d = self.base.powf(F::from(i).unwrap());
|
||||
let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
||||
let d_p_q = self.min_by_distance(&q_p_ds);
|
||||
let d_p_q = self.min_by_distance(&q_p_ds);
|
||||
if d_p_q < F::epsilon() {
|
||||
return
|
||||
return;
|
||||
} else if d_p_q > i_d {
|
||||
break;
|
||||
}
|
||||
if self.min_by_distance(&qi_p_ds) <= self.base.powf(F::from(i).unwrap()){
|
||||
}
|
||||
if self.min_by_distance(&qi_p_ds) <= self.base.powf(F::from(i).unwrap()) {
|
||||
parent = q_p_ds.iter().find(|(_, d)| d <= &i_d).map(|(n, _)| n.index);
|
||||
p_i = i;
|
||||
}
|
||||
|
||||
|
||||
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &i_d).collect();
|
||||
i -= 1;
|
||||
i -= 1;
|
||||
}
|
||||
|
||||
let new_node = self.new_node(parent, p);
|
||||
|
||||
let new_node = self.new_node(parent, p);
|
||||
self.add_child(parent.unwrap(), new_node, p_i);
|
||||
self.min_level = i8::min(self.min_level, p_i-1);
|
||||
self.min_level = i8::min(self.min_level, p_i - 1);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_node(&mut self, parent: Option<NodeId>, data: T) -> NodeId {
|
||||
let next_index = self.nodes.len();
|
||||
let node_id = NodeId { index: next_index };
|
||||
self.nodes.push(
|
||||
Node {
|
||||
index: node_id,
|
||||
data: data,
|
||||
parent: parent,
|
||||
children: HashMap::new()
|
||||
});
|
||||
self.nodes.push(Node {
|
||||
index: node_id,
|
||||
data: data,
|
||||
parent: parent,
|
||||
children: HashMap::new(),
|
||||
});
|
||||
node_id
|
||||
}
|
||||
|
||||
pub fn find(&self, p: &T, k: usize) -> Vec<usize>{
|
||||
let mut qi_p_ds = vec!((self.root(), self.distance.distance(&p, &self.root().data)));
|
||||
for i in (self.min_level..self.max_level+1).rev() {
|
||||
let i_d = self.base.powf(F::from(i).unwrap());
|
||||
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
||||
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
|
||||
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &(d_p_q + i_d)).collect();
|
||||
}
|
||||
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
||||
qi_p_ds[..usize::min(qi_p_ds.len(), k)].iter().map(|(n, _)| n.index.index).collect()
|
||||
}
|
||||
|
||||
fn split(&self, p_id: NodeId, r: F, s1: &mut Vec<T>, s2: Option<&mut Vec<T>>) -> (Vec<T>, Vec<T>){
|
||||
|
||||
let mut my_near = (Vec::new(), Vec::new());
|
||||
pub fn find(&self, p: &T, k: usize) -> Vec<usize> {
|
||||
let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))];
|
||||
for i in (self.min_level..self.max_level + 1).rev() {
|
||||
let i_d = self.base.powf(F::from(i).unwrap());
|
||||
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
||||
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
|
||||
qi_p_ds = q_p_ds
|
||||
.into_iter()
|
||||
.filter(|(_, d)| d <= &(d_p_q + i_d))
|
||||
.collect();
|
||||
}
|
||||
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
||||
qi_p_ds[..usize::min(qi_p_ds.len(), k)]
|
||||
.iter()
|
||||
.map(|(n, _)| n.index.index)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn split(
|
||||
&self,
|
||||
p_id: NodeId,
|
||||
r: F,
|
||||
s1: &mut Vec<T>,
|
||||
s2: Option<&mut Vec<T>>,
|
||||
) -> (Vec<T>, Vec<T>) {
|
||||
let mut my_near = (Vec::new(), Vec::new());
|
||||
|
||||
my_near = self.split_remove_s(p_id, r, s1, my_near);
|
||||
|
||||
for s in s2 {
|
||||
my_near = self.split_remove_s(p_id, r, s, my_near);
|
||||
my_near = self.split_remove_s(p_id, r, s, my_near);
|
||||
}
|
||||
|
||||
return my_near
|
||||
|
||||
return my_near;
|
||||
}
|
||||
|
||||
fn split_remove_s(&self, p_id: NodeId, r: F, s: &mut Vec<T>, mut my_near: (Vec<T>, Vec<T>)) -> (Vec<T>, Vec<T>){
|
||||
|
||||
fn split_remove_s(
|
||||
&self,
|
||||
p_id: NodeId,
|
||||
r: F,
|
||||
s: &mut Vec<T>,
|
||||
mut my_near: (Vec<T>, Vec<T>),
|
||||
) -> (Vec<T>, Vec<T>) {
|
||||
if s.len() > 0 {
|
||||
let p = &self.nodes.get(p_id.index).unwrap().data;
|
||||
let mut i = 0;
|
||||
@@ -118,61 +128,84 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
let d = self.distance.distance(p, &s[i]);
|
||||
if d <= r {
|
||||
my_near.0.push(s.remove(i));
|
||||
} else if d > r && d <= F::two() * r{
|
||||
my_near.1.push(s.remove(i));
|
||||
} else if d > r && d <= F::two() * r {
|
||||
my_near.1.push(s.remove(i));
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return my_near
|
||||
}
|
||||
return my_near;
|
||||
}
|
||||
|
||||
fn construct<'b>(&mut self, p: NodeId, mut near: Vec<T>, mut far: Vec<T>, i: i8) -> (NodeId, Vec<T>) {
|
||||
|
||||
if near.len() < 1{
|
||||
self.min_level = std::cmp::min(self.min_level, i);
|
||||
return (p, far);
|
||||
fn construct<'b>(
|
||||
&mut self,
|
||||
p: NodeId,
|
||||
mut near: Vec<T>,
|
||||
mut far: Vec<T>,
|
||||
i: i8,
|
||||
) -> (NodeId, Vec<T>) {
|
||||
if near.len() < 1 {
|
||||
self.min_level = std::cmp::min(self.min_level, i);
|
||||
return (p, far);
|
||||
} else {
|
||||
let (my, n) = self.split(p, self.base.powf(F::from(i-1).unwrap()), &mut near, None);
|
||||
let (pi, mut near) = self.construct(p, my, n, i-1);
|
||||
let (my, n) = self.split(p, self.base.powf(F::from(i - 1).unwrap()), &mut near, None);
|
||||
let (pi, mut near) = self.construct(p, my, n, i - 1);
|
||||
while near.len() > 0 {
|
||||
let q_data = near.remove(0);
|
||||
let nn = self.new_node(Some(p), q_data);
|
||||
let (my, n) = self.split(nn, self.base.powf(F::from(i-1).unwrap()), &mut near, Some(&mut far));
|
||||
let (child, mut unused) = self.construct(nn, my, n, i-1);
|
||||
let q_data = near.remove(0);
|
||||
let nn = self.new_node(Some(p), q_data);
|
||||
let (my, n) = self.split(
|
||||
nn,
|
||||
self.base.powf(F::from(i - 1).unwrap()),
|
||||
&mut near,
|
||||
Some(&mut far),
|
||||
);
|
||||
let (child, mut unused) = self.construct(nn, my, n, i - 1);
|
||||
self.add_child(pi, child, i);
|
||||
let new_near_far = self.split(p, self.base.powf(F::from(i).unwrap()), &mut unused, None);
|
||||
let new_near_far =
|
||||
self.split(p, self.base.powf(F::from(i).unwrap()), &mut unused, None);
|
||||
near.extend(new_near_far.0);
|
||||
far.extend(new_near_far.1);
|
||||
}
|
||||
self.min_level = std::cmp::min(self.min_level, i);
|
||||
return (pi, far);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
fn add_child(&mut self, parent: NodeId, node: NodeId, i: i8){
|
||||
self.nodes.get_mut(parent.index).unwrap().children.insert(i, node);
|
||||
fn add_child(&mut self, parent: NodeId, node: NodeId, i: i8) {
|
||||
self.nodes
|
||||
.get_mut(parent.index)
|
||||
.unwrap()
|
||||
.children
|
||||
.insert(i, node);
|
||||
}
|
||||
|
||||
fn root(&self) -> &Node<T> {
|
||||
self.nodes.first().unwrap()
|
||||
}
|
||||
|
||||
fn get_children_dist<'b>(&'b self, p: &T, qi_p_ds: &Vec<(&'b Node<T>, F)>, i: i8) -> Vec<(&'b Node<T>, F)> {
|
||||
|
||||
fn get_children_dist<'b>(
|
||||
&'b self,
|
||||
p: &T,
|
||||
qi_p_ds: &Vec<(&'b Node<T>, F)>,
|
||||
i: i8,
|
||||
) -> Vec<(&'b Node<T>, F)> {
|
||||
let mut children = Vec::<(&'b Node<T>, F)>::new();
|
||||
|
||||
children.extend(qi_p_ds.iter().cloned());
|
||||
|
||||
let q: Vec<&Node<T>> = qi_p_ds.iter().flat_map(|(n, _)| self.get_child(n, i)).collect();
|
||||
let q: Vec<&Node<T>> = qi_p_ds
|
||||
.iter()
|
||||
.flat_map(|(n, _)| self.get_child(n, i))
|
||||
.collect();
|
||||
|
||||
children.extend(q.into_iter().map(|n| (n, self.distance.distance(&n.data, &p))));
|
||||
children.extend(
|
||||
q.into_iter()
|
||||
.map(|n| (n, self.distance.distance(&n.data, &p))),
|
||||
);
|
||||
|
||||
children
|
||||
|
||||
}
|
||||
|
||||
fn min_k_by_distance(&self, q_p_ds: &mut Vec<(&Node<T>, F)>, k: usize) -> F {
|
||||
@@ -185,18 +218,27 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
}
|
||||
|
||||
fn min_by_distance(&self, q_p_ds: &Vec<(&Node<T>, F)>) -> F {
|
||||
q_p_ds.into_iter().min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap()).unwrap().1
|
||||
q_p_ds
|
||||
.into_iter()
|
||||
.min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap())
|
||||
.unwrap()
|
||||
.1
|
||||
}
|
||||
|
||||
fn get_child(&self, node: &Node<T>, i: i8) -> Option<&Node<T>> {
|
||||
node.children.get(&i).and_then(|n_id| self.nodes.get(n_id.index))
|
||||
}
|
||||
node.children
|
||||
.get(&i)
|
||||
.and_then(|n_id| self.nodes.get(n_id.index))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn check_invariant(&self, invariant: fn(&CoverTree<T, F, D>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) {
|
||||
fn check_invariant(
|
||||
&self,
|
||||
invariant: fn(&CoverTree<T, F, D>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> (),
|
||||
) {
|
||||
let mut current_nodes: Vec<&Node<T>> = Vec::new();
|
||||
current_nodes.push(self.root());
|
||||
for i in (self.min_level..self.max_level+1).rev() {
|
||||
for i in (self.min_level..self.max_level + 1).rev() {
|
||||
let mut next_nodes: Vec<&Node<T>> = Vec::new();
|
||||
next_nodes.extend(current_nodes.iter());
|
||||
next_nodes.extend(current_nodes.iter().flat_map(|n| self.get_child(n, i)));
|
||||
@@ -206,39 +248,55 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn nesting_invariant(_: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) {
|
||||
fn nesting_invariant(
|
||||
_: &CoverTree<T, F, D>,
|
||||
nodes: &Vec<&Node<T>>,
|
||||
next_nodes: &Vec<&Node<T>>,
|
||||
_: i8,
|
||||
) {
|
||||
let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n));
|
||||
let next_nodes_set: HashSet<&Node<T>> = HashSet::from_iter(next_nodes.into_iter().map(|n| *n));
|
||||
let next_nodes_set: HashSet<&Node<T>> =
|
||||
HashSet::from_iter(next_nodes.into_iter().map(|n| *n));
|
||||
for n in nodes_set.iter() {
|
||||
assert!(next_nodes_set.contains(n), "Nesting invariant of the cover tree is not satisfied. Set of nodes [{:?}] is not a subset of [{:?}]", nodes_set, next_nodes_set);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn covering_tree(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) {
|
||||
fn covering_tree(
|
||||
tree: &CoverTree<T, F, D>,
|
||||
nodes: &Vec<&Node<T>>,
|
||||
next_nodes: &Vec<&Node<T>>,
|
||||
i: i8,
|
||||
) {
|
||||
let mut p_selected: Vec<&Node<T>> = Vec::new();
|
||||
for p in next_nodes {
|
||||
for q in nodes {
|
||||
if tree.distance.distance(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) {
|
||||
p_selected.push(*p);
|
||||
}
|
||||
}
|
||||
let c = p_selected.iter().filter(|q| p.parent.map(|p| q.index == p).unwrap_or(false)).count();
|
||||
}
|
||||
let c = p_selected
|
||||
.iter()
|
||||
.filter(|q| p.parent.map(|p| q.index == p).unwrap_or(false))
|
||||
.count();
|
||||
assert!(c <= 1);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn separation(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
|
||||
fn separation(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
|
||||
for p in nodes {
|
||||
for q in nodes {
|
||||
if p != q {
|
||||
assert!(tree.distance.distance(&p.data, &q.data) > tree.base.powf(F::from(i).unwrap()));
|
||||
}
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
tree.distance.distance(&p.data, &q.data)
|
||||
> tree.base.powf(F::from(i).unwrap())
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||
@@ -251,7 +309,7 @@ struct Node<T> {
|
||||
index: NodeId,
|
||||
data: T,
|
||||
children: HashMap<i8, NodeId>,
|
||||
parent: Option<NodeId>
|
||||
parent: Option<NodeId>,
|
||||
}
|
||||
|
||||
impl<T> PartialEq for Node<T> {
|
||||
@@ -277,22 +335,22 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
struct SimpleDistance{}
|
||||
struct SimpleDistance {}
|
||||
|
||||
impl Distance<i32, f64> for SimpleDistance {
|
||||
fn distance(&self, a: &i32, b: &i32) -> f64 {
|
||||
(a - b).abs() as f64
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cover_tree_test() {
|
||||
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
|
||||
|
||||
let mut tree = CoverTree::new(data, SimpleDistance{});
|
||||
for d in vec!(10, 11, 12, 13, 14, 15, 16, 17, 18, 19) {
|
||||
fn cover_tree_test() {
|
||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||
|
||||
let mut tree = CoverTree::new(data, SimpleDistance {});
|
||||
for d in vec![10, 11, 12, 13, 14, 15, 16, 17, 18, 19] {
|
||||
tree.insert(d);
|
||||
}
|
||||
}
|
||||
|
||||
let mut nearest_3_to_5 = tree.find(&5, 3);
|
||||
nearest_3_to_5.sort();
|
||||
@@ -307,13 +365,12 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invariants(){
|
||||
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
|
||||
|
||||
let tree = CoverTree::new(data, SimpleDistance{});
|
||||
fn test_invariants() {
|
||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||
|
||||
let tree = CoverTree::new(data, SimpleDistance {});
|
||||
tree.check_invariant(CoverTree::nesting_invariant);
|
||||
tree.check_invariant(CoverTree::covering_tree);
|
||||
tree.check_invariant(CoverTree::separation);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,53 +1,52 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::{Ordering, PartialOrd};
|
||||
use std::marker::PhantomData;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::FloatExt;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct LinearKNNSearch<T, F: FloatExt, D: Distance<T, F>> {
|
||||
distance: D,
|
||||
data: Vec<T>,
|
||||
f: PhantomData<F>
|
||||
f: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
pub fn new(data: Vec<T>, distance: D) -> LinearKNNSearch<T, F, D>{
|
||||
LinearKNNSearch{
|
||||
pub fn new(data: Vec<T>, distance: D) -> LinearKNNSearch<T, F, D> {
|
||||
LinearKNNSearch {
|
||||
data: data,
|
||||
distance: distance,
|
||||
f: PhantomData
|
||||
f: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn find(&self, from: &T, k: usize) -> Vec<usize> {
|
||||
if k < 1 || k > self.data.len() {
|
||||
panic!("k should be >= 1 and <= length(data)");
|
||||
}
|
||||
|
||||
let mut heap = HeapSelect::<KNNPoint<F>>::with_capacity(k);
|
||||
}
|
||||
|
||||
let mut heap = HeapSelect::<KNNPoint<F>>::with_capacity(k);
|
||||
|
||||
for _ in 0..k {
|
||||
heap.add(KNNPoint{
|
||||
heap.add(KNNPoint {
|
||||
distance: F::infinity(),
|
||||
index: None
|
||||
index: None,
|
||||
});
|
||||
}
|
||||
|
||||
for i in 0..self.data.len() {
|
||||
|
||||
let d = self.distance.distance(&from, &self.data[i]);
|
||||
let datum = heap.peek_mut();
|
||||
let datum = heap.peek_mut();
|
||||
if d < datum.distance {
|
||||
datum.distance = d;
|
||||
datum.index = Some(i);
|
||||
heap.heapify();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
heap.sort();
|
||||
heap.sort();
|
||||
|
||||
heap.get().into_iter().flat_map(|x| x.index).collect()
|
||||
}
|
||||
@@ -56,7 +55,7 @@ impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
#[derive(Debug)]
|
||||
struct KNNPoint<F: FloatExt> {
|
||||
distance: F,
|
||||
index: Option<usize>
|
||||
index: Option<usize>,
|
||||
}
|
||||
|
||||
impl<F: FloatExt> PartialOrd for KNNPoint<F> {
|
||||
@@ -74,27 +73,33 @@ impl<F: FloatExt> PartialEq for KNNPoint<F> {
|
||||
impl<F: FloatExt> Eq for KNNPoint<F> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::math::distance::Distances;
|
||||
|
||||
struct SimpleDistance{}
|
||||
struct SimpleDistance {}
|
||||
|
||||
impl Distance<i32, f64> for SimpleDistance {
|
||||
fn distance(&self, a: &i32, b: &i32) -> f64 {
|
||||
(a - b).abs() as f64
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn knn_find() {
|
||||
let data1 = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
||||
fn knn_find() {
|
||||
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
|
||||
|
||||
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance{});
|
||||
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
|
||||
|
||||
assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));
|
||||
|
||||
let data2 = vec!(vec![1., 1.], vec![2., 2.], vec![3., 3.], vec![4., 4.], vec![5., 5.]);
|
||||
let data2 = vec![
|
||||
vec![1., 1.],
|
||||
vec![2., 2.],
|
||||
vec![3., 3.],
|
||||
vec![4., 4.],
|
||||
vec![5., 5.],
|
||||
];
|
||||
|
||||
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
|
||||
|
||||
@@ -103,29 +108,29 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn knn_point_eq() {
|
||||
let point1 = KNNPoint{
|
||||
let point1 = KNNPoint {
|
||||
distance: 10.,
|
||||
index: Some(0)
|
||||
index: Some(0),
|
||||
};
|
||||
|
||||
let point2 = KNNPoint{
|
||||
let point2 = KNNPoint {
|
||||
distance: 100.,
|
||||
index: Some(1)
|
||||
index: Some(1),
|
||||
};
|
||||
|
||||
let point3 = KNNPoint{
|
||||
let point3 = KNNPoint {
|
||||
distance: 10.,
|
||||
index: Some(2)
|
||||
index: Some(2),
|
||||
};
|
||||
|
||||
let point_inf = KNNPoint{
|
||||
let point_inf = KNNPoint {
|
||||
distance: std::f64::INFINITY,
|
||||
index: Some(3)
|
||||
index: Some(3),
|
||||
};
|
||||
|
||||
assert!(point2 > point1);
|
||||
assert_eq!(point3, point1);
|
||||
assert_ne!(point3, point2);
|
||||
assert!(point_inf > point3 && point_inf > point2 && point_inf > point1);
|
||||
assert!(point_inf > point3 && point_inf > point2 && point_inf > point1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
pub mod bbd_tree;
|
||||
pub mod cover_tree;
|
||||
pub mod linear_search;
|
||||
pub mod bbd_tree;
|
||||
@@ -5,21 +5,20 @@ pub struct HeapSelect<T: PartialOrd> {
|
||||
k: usize,
|
||||
n: usize,
|
||||
sorted: bool,
|
||||
heap: Vec<T>
|
||||
heap: Vec<T>,
|
||||
}
|
||||
|
||||
impl<'a, T: PartialOrd> HeapSelect<T> {
|
||||
|
||||
pub fn with_capacity(k: usize) -> HeapSelect<T> {
|
||||
HeapSelect{
|
||||
HeapSelect {
|
||||
k: k,
|
||||
n: 0,
|
||||
sorted: false,
|
||||
heap: Vec::<T>::new()
|
||||
heap: Vec::<T>::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add(&mut self, element: T) {
|
||||
pub fn add(&mut self, element: T) {
|
||||
self.sorted = false;
|
||||
if self.n < self.k {
|
||||
self.heap.push(element);
|
||||
@@ -30,23 +29,23 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
|
||||
} else {
|
||||
self.n += 1;
|
||||
if element.partial_cmp(&self.heap[0]) == Some(Ordering::Less) {
|
||||
self.heap[0] = element;
|
||||
self.heap[0] = element;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn heapify(&mut self) {
|
||||
pub fn heapify(&mut self) {
|
||||
let n = self.heap.len();
|
||||
for i in (0..=(n / 2 - 1)).rev() {
|
||||
self.sift_down(i, n-1);
|
||||
self.sift_down(i, n - 1);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn peek(&self) -> &T {
|
||||
pub fn peek(&self) -> &T {
|
||||
return &self.heap[0];
|
||||
}
|
||||
|
||||
pub fn peek_mut(&mut self) -> &mut T {
|
||||
pub fn peek_mut(&mut self) -> &mut T {
|
||||
return &mut self.heap[0];
|
||||
}
|
||||
|
||||
@@ -59,11 +58,10 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
|
||||
}
|
||||
if self.heap[k] >= self.heap[j] {
|
||||
break;
|
||||
}
|
||||
}
|
||||
self.heap.swap(k, j);
|
||||
k = j;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
pub fn get(self) -> Vec<T> {
|
||||
@@ -71,7 +69,7 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
|
||||
}
|
||||
|
||||
pub fn sort(&mut self) {
|
||||
HeapSelect::shuffle_sort(&mut self.heap, std::cmp::min(self.k,self.n));
|
||||
HeapSelect::shuffle_sort(&mut self.heap, std::cmp::min(self.k, self.n));
|
||||
}
|
||||
|
||||
pub fn shuffle_sort(vec: &mut Vec<T>, n: usize) {
|
||||
@@ -80,10 +78,10 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
|
||||
inc *= 3;
|
||||
inc += 1
|
||||
}
|
||||
|
||||
|
||||
let len = n;
|
||||
while inc >= 1 {
|
||||
let mut i = inc;
|
||||
let mut i = inc;
|
||||
while i < len {
|
||||
let mut j = i;
|
||||
while j >= inc && vec[j - inc] > vec[j] {
|
||||
@@ -95,60 +93,58 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
|
||||
inc /= 3
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn with_capacity() {
|
||||
let heap = HeapSelect::<i32>::with_capacity(3);
|
||||
assert_eq!(3, heap.k);
|
||||
fn with_capacity() {
|
||||
let heap = HeapSelect::<i32>::with_capacity(3);
|
||||
assert_eq!(3, heap.k);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add() {
|
||||
let mut heap = HeapSelect::with_capacity(3);
|
||||
fn test_add() {
|
||||
let mut heap = HeapSelect::with_capacity(3);
|
||||
heap.add(333);
|
||||
heap.add(2);
|
||||
heap.add(13);
|
||||
heap.add(10);
|
||||
heap.add(40);
|
||||
heap.add(30);
|
||||
heap.add(30);
|
||||
assert_eq!(6, heap.n);
|
||||
assert_eq!(&10, heap.peek());
|
||||
assert_eq!(&10, heap.peek_mut());
|
||||
assert_eq!(&10, heap.peek());
|
||||
assert_eq!(&10, heap.peek_mut());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_ordered() {
|
||||
let mut heap = HeapSelect::with_capacity(3);
|
||||
fn test_add_ordered() {
|
||||
let mut heap = HeapSelect::with_capacity(3);
|
||||
heap.add(1.);
|
||||
heap.add(2.);
|
||||
heap.add(3.);
|
||||
heap.add(4.);
|
||||
heap.add(5.);
|
||||
heap.add(6.);
|
||||
heap.add(6.);
|
||||
let result = heap.get();
|
||||
assert_eq!(vec![2., 3., 1.], result);
|
||||
assert_eq!(vec![2., 3., 1.], result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shuffle_sort() {
|
||||
fn test_shuffle_sort() {
|
||||
let mut v1 = vec![10, 33, 22, 105, 12];
|
||||
let n = v1.len();
|
||||
HeapSelect::shuffle_sort(&mut v1, n);
|
||||
assert_eq!(vec![10, 12, 22, 33, 105], v1);
|
||||
|
||||
let mut v2 = vec![10, 33, 22, 105, 12];
|
||||
let mut v2 = vec![10, 33, 22, 105, 12];
|
||||
HeapSelect::shuffle_sort(&mut v2, 3);
|
||||
assert_eq!(vec![10, 22, 33, 105, 12], v2);
|
||||
|
||||
let mut v3 = vec![4, 5, 3, 2, 1];
|
||||
let mut v3 = vec![4, 5, 3, 2, 1];
|
||||
HeapSelect::shuffle_sort(&mut v3, 3);
|
||||
assert_eq!(vec![3, 4, 5, 2, 1], v3);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
pub mod heap_select;
|
||||
pub mod quick_sort;
|
||||
pub mod quick_sort;
|
||||
|
||||
@@ -5,13 +5,12 @@ pub trait QuickArgSort {
|
||||
}
|
||||
|
||||
impl<T: Float> QuickArgSort for Vec<T> {
|
||||
|
||||
fn quick_argsort(&mut self) -> Vec<usize> {
|
||||
let stack_size = 64;
|
||||
let mut jstack = -1;
|
||||
let mut l = 0;
|
||||
let mut istack = vec![0; stack_size];
|
||||
let mut ir = self.len() - 1;
|
||||
let mut ir = self.len() - 1;
|
||||
let mut index: Vec<usize> = (0..self.len()).collect();
|
||||
|
||||
loop {
|
||||
@@ -19,21 +18,21 @@ impl<T: Float> QuickArgSort for Vec<T> {
|
||||
for j in l + 1..=ir {
|
||||
let a = self[j];
|
||||
let b = index[j];
|
||||
let mut i: i32 = (j - 1) as i32;
|
||||
while i >= l as i32 {
|
||||
if self[i as usize] <= a {
|
||||
let mut i: i32 = (j - 1) as i32;
|
||||
while i >= l as i32 {
|
||||
if self[i as usize] <= a {
|
||||
break;
|
||||
}
|
||||
self[(i + 1) as usize] = self[i as usize];
|
||||
index[(i + 1) as usize] = index[i as usize];
|
||||
i -= 1;
|
||||
}
|
||||
i -= 1;
|
||||
}
|
||||
self[(i + 1) as usize] = a;
|
||||
index[(i + 1) as usize] = b;
|
||||
}
|
||||
if jstack < 0 {
|
||||
index[(i + 1) as usize] = b;
|
||||
}
|
||||
if jstack < 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
ir = istack[jstack as usize];
|
||||
jstack -= 1;
|
||||
l = istack[jstack as usize];
|
||||
@@ -66,7 +65,7 @@ impl<T: Float> QuickArgSort for Vec<T> {
|
||||
}
|
||||
}
|
||||
loop {
|
||||
j -=1;
|
||||
j -= 1;
|
||||
if self[j] <= a {
|
||||
break;
|
||||
}
|
||||
@@ -81,7 +80,7 @@ impl<T: Float> QuickArgSort for Vec<T> {
|
||||
self[j] = a;
|
||||
index[l + 1] = index[j];
|
||||
index[j] = b;
|
||||
jstack += 2;
|
||||
jstack += 2;
|
||||
|
||||
if jstack >= 64 {
|
||||
panic!("stack size is too small.");
|
||||
@@ -95,7 +94,7 @@ impl<T: Float> QuickArgSort for Vec<T> {
|
||||
istack[jstack as usize] = j - 1;
|
||||
istack[jstack as usize - 1] = l;
|
||||
l = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,15 +103,21 @@ impl<T: Float> QuickArgSort for Vec<T> {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn with_capacity() {
|
||||
let mut arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
|
||||
assert_eq!(vec![1, 2, 0, 3, 5, 7, 6, 8, 4], arr1.quick_argsort());
|
||||
fn with_capacity() {
|
||||
let mut arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
|
||||
assert_eq!(vec![1, 2, 0, 3, 5, 7, 6, 8, 4], arr1.quick_argsort());
|
||||
|
||||
let mut arr2 = vec![0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6, 1.0, 1.3, 1.4];
|
||||
assert_eq!(vec![9, 7, 1, 8, 0, 2, 4, 3, 6, 5, 17, 18, 15, 13, 19, 10, 14, 11, 12, 16], arr2.quick_argsort());
|
||||
let mut arr2 = vec![
|
||||
0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6,
|
||||
1.0, 1.3, 1.4,
|
||||
];
|
||||
assert_eq!(
|
||||
vec![9, 7, 1, 8, 0, 2, 4, 3, 6, 5, 17, 18, 15, 13, 19, 10, 14, 11, 12, 16],
|
||||
arr2.quick_argsort()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user