feat: refactoring, adds Result to most public API

This commit is contained in:
Volodymyr Orlov
2020-09-18 15:20:32 -07:00
parent 4921ae76f5
commit a9db970195
24 changed files with 389 additions and 298 deletions
+16 -12
View File
@@ -16,7 +16,7 @@
//!
//! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points
//!
//! let mut tree = CoverTree::new(data, SimpleDistance {});
//! let mut tree = CoverTree::new(data, SimpleDistance {}).unwrap();
//!
//! tree.find(&5, 3); // find 3 knn points from 5
//!
@@ -26,6 +26,7 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::algorithm::sort::heap_select::HeapSelection;
use crate::error::{Failed, FailedError};
use crate::math::distance::Distance;
use crate::math::num::RealNumber;
@@ -73,7 +74,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
/// Construct a cover tree.
/// * `data` - vector of data points to search for.
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
pub fn new(data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
pub fn new(data: Vec<T>, distance: D) -> Result<CoverTree<T, F, D>, Failed> {
let base = F::from_f64(1.3).unwrap();
let root = Node {
idx: 0,
@@ -93,19 +94,22 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
tree.build_cover_tree();
tree
Ok(tree)
}
/// Find k nearest neighbors of `p`
/// * `p` - look for k nearest points to `p`
/// * `k` - the number of nearest neighbors to return
pub fn find(&self, p: &T, k: usize) -> Vec<(usize, F)> {
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
if k <= 0 {
panic!("k should be > 0");
return Err(Failed::because(FailedError::FindFailed, "k should be > 0"));
}
if k > self.data.len() {
panic!("k is > than the dataset size");
return Err(Failed::because(
FailedError::FindFailed,
"k is > than the dataset size",
));
}
let e = self.get_data_value(self.root.idx);
@@ -171,7 +175,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
}
}
neighbors.into_iter().take(k).collect()
Ok(neighbors.into_iter().take(k).collect())
}
fn new_leaf(&self, idx: usize) -> Node<F> {
@@ -407,9 +411,9 @@ mod tests {
fn cover_tree_test() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
let tree = CoverTree::new(data, SimpleDistance {});
let tree = CoverTree::new(data, SimpleDistance {}).unwrap();
let mut knn = tree.find(&5, 3);
let mut knn = tree.find(&5, 3).unwrap();
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
assert_eq!(vec!(3, 4, 5), knn);
@@ -425,9 +429,9 @@ mod tests {
vec![9., 10.],
];
let tree = CoverTree::new(data, Distances::euclidian());
let tree = CoverTree::new(data, Distances::euclidian()).unwrap();
let mut knn = tree.find(&vec![1., 2.], 3);
let mut knn = tree.find(&vec![1., 2.], 3).unwrap();
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
@@ -438,7 +442,7 @@ mod tests {
fn serde() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
let tree = CoverTree::new(data, SimpleDistance {});
let tree = CoverTree::new(data, SimpleDistance {}).unwrap();
let deserialized_tree: CoverTree<i32, f64, SimpleDistance> =
serde_json::from_str(&serde_json::to_string(&tree).unwrap()).unwrap();
+18 -10
View File
@@ -15,7 +15,7 @@
//!
//! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points
//!
//! let knn = LinearKNNSearch::new(data, SimpleDistance {});
//! let knn = LinearKNNSearch::new(data, SimpleDistance {}).unwrap();
//!
//! knn.find(&5, 3); // find 3 knn points from 5
//!
@@ -26,6 +26,7 @@ use std::cmp::{Ordering, PartialOrd};
use std::marker::PhantomData;
use crate::algorithm::sort::heap_select::HeapSelection;
use crate::error::Failed;
use crate::math::distance::Distance;
use crate::math::num::RealNumber;
@@ -41,18 +42,18 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
/// Initializes algorithm.
/// * `data` - vector of data points to search for.
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
pub fn new(data: Vec<T>, distance: D) -> LinearKNNSearch<T, F, D> {
LinearKNNSearch {
pub fn new(data: Vec<T>, distance: D) -> Result<LinearKNNSearch<T, F, D>, Failed> {
Ok(LinearKNNSearch {
data: data,
distance: distance,
f: PhantomData,
}
})
}
/// Find k nearest neighbors
/// * `from` - look for k nearest points to `from`
/// * `k` - the number of nearest neighbors to return
pub fn find(&self, from: &T, k: usize) -> Vec<(usize, F)> {
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
if k < 1 || k > self.data.len() {
panic!("k should be >= 1 and <= length(data)");
}
@@ -76,10 +77,11 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
}
}
heap.get()
Ok(heap
.get()
.into_iter()
.flat_map(|x| x.index.map(|i| (i, x.distance)))
.collect()
.collect())
}
}
@@ -120,9 +122,14 @@ mod tests {
fn knn_find() {
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {}).unwrap();
let mut found_idxs1: Vec<usize> = algorithm1.find(&2, 3).iter().map(|v| v.0).collect();
let mut found_idxs1: Vec<usize> = algorithm1
.find(&2, 3)
.unwrap()
.iter()
.map(|v| v.0)
.collect();
found_idxs1.sort();
assert_eq!(vec!(0, 1, 2), found_idxs1);
@@ -135,10 +142,11 @@ mod tests {
vec![5., 5.],
];
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian()).unwrap();
let mut found_idxs2: Vec<usize> = algorithm2
.find(&vec![3., 3.], 3)
.unwrap()
.iter()
.map(|v| v.0)
.collect();