feat: refactoring, adds Result to most public API
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
//!
|
||||
//! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points
|
||||
//!
|
||||
//! let mut tree = CoverTree::new(data, SimpleDistance {});
|
||||
//! let mut tree = CoverTree::new(data, SimpleDistance {}).unwrap();
|
||||
//!
|
||||
//! tree.find(&5, 3); // find 3 knn points from 5
|
||||
//!
|
||||
@@ -26,6 +26,7 @@ use std::fmt::Debug;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::heap_select::HeapSelection;
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
@@ -73,7 +74,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
/// Construct a cover tree.
|
||||
/// * `data` - vector of data points to search for.
|
||||
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
|
||||
pub fn new(data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
|
||||
pub fn new(data: Vec<T>, distance: D) -> Result<CoverTree<T, F, D>, Failed> {
|
||||
let base = F::from_f64(1.3).unwrap();
|
||||
let root = Node {
|
||||
idx: 0,
|
||||
@@ -93,19 +94,22 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
|
||||
tree.build_cover_tree();
|
||||
|
||||
tree
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
/// Find k nearest neighbors of `p`
|
||||
/// * `p` - look for k nearest points to `p`
|
||||
/// * `k` - the number of nearest neighbors to return
|
||||
pub fn find(&self, p: &T, k: usize) -> Vec<(usize, F)> {
|
||||
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
|
||||
if k <= 0 {
|
||||
panic!("k should be > 0");
|
||||
return Err(Failed::because(FailedError::FindFailed, "k should be > 0"));
|
||||
}
|
||||
|
||||
if k > self.data.len() {
|
||||
panic!("k is > than the dataset size");
|
||||
return Err(Failed::because(
|
||||
FailedError::FindFailed,
|
||||
"k is > than the dataset size",
|
||||
));
|
||||
}
|
||||
|
||||
let e = self.get_data_value(self.root.idx);
|
||||
@@ -171,7 +175,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
}
|
||||
}
|
||||
|
||||
neighbors.into_iter().take(k).collect()
|
||||
Ok(neighbors.into_iter().take(k).collect())
|
||||
}
|
||||
|
||||
fn new_leaf(&self, idx: usize) -> Node<F> {
|
||||
@@ -407,9 +411,9 @@ mod tests {
|
||||
fn cover_tree_test() {
|
||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||
|
||||
let tree = CoverTree::new(data, SimpleDistance {});
|
||||
let tree = CoverTree::new(data, SimpleDistance {}).unwrap();
|
||||
|
||||
let mut knn = tree.find(&5, 3);
|
||||
let mut knn = tree.find(&5, 3).unwrap();
|
||||
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
|
||||
assert_eq!(vec!(3, 4, 5), knn);
|
||||
@@ -425,9 +429,9 @@ mod tests {
|
||||
vec![9., 10.],
|
||||
];
|
||||
|
||||
let tree = CoverTree::new(data, Distances::euclidian());
|
||||
let tree = CoverTree::new(data, Distances::euclidian()).unwrap();
|
||||
|
||||
let mut knn = tree.find(&vec![1., 2.], 3);
|
||||
let mut knn = tree.find(&vec![1., 2.], 3).unwrap();
|
||||
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
|
||||
|
||||
@@ -438,7 +442,7 @@ mod tests {
|
||||
fn serde() {
|
||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||
|
||||
let tree = CoverTree::new(data, SimpleDistance {});
|
||||
let tree = CoverTree::new(data, SimpleDistance {}).unwrap();
|
||||
|
||||
let deserialized_tree: CoverTree<i32, f64, SimpleDistance> =
|
||||
serde_json::from_str(&serde_json::to_string(&tree).unwrap()).unwrap();
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
//!
|
||||
//! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points
|
||||
//!
|
||||
//! let knn = LinearKNNSearch::new(data, SimpleDistance {});
|
||||
//! let knn = LinearKNNSearch::new(data, SimpleDistance {}).unwrap();
|
||||
//!
|
||||
//! knn.find(&5, 3); // find 3 knn points from 5
|
||||
//!
|
||||
@@ -26,6 +26,7 @@ use std::cmp::{Ordering, PartialOrd};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::algorithm::sort::heap_select::HeapSelection;
|
||||
use crate::error::Failed;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
@@ -41,18 +42,18 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
/// Initializes algorithm.
|
||||
/// * `data` - vector of data points to search for.
|
||||
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
|
||||
pub fn new(data: Vec<T>, distance: D) -> LinearKNNSearch<T, F, D> {
|
||||
LinearKNNSearch {
|
||||
pub fn new(data: Vec<T>, distance: D) -> Result<LinearKNNSearch<T, F, D>, Failed> {
|
||||
Ok(LinearKNNSearch {
|
||||
data: data,
|
||||
distance: distance,
|
||||
f: PhantomData,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Find k nearest neighbors
|
||||
/// * `from` - look for k nearest points to `from`
|
||||
/// * `k` - the number of nearest neighbors to return
|
||||
pub fn find(&self, from: &T, k: usize) -> Vec<(usize, F)> {
|
||||
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
|
||||
if k < 1 || k > self.data.len() {
|
||||
panic!("k should be >= 1 and <= length(data)");
|
||||
}
|
||||
@@ -76,10 +77,11 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
heap.get()
|
||||
Ok(heap
|
||||
.get()
|
||||
.into_iter()
|
||||
.flat_map(|x| x.index.map(|i| (i, x.distance)))
|
||||
.collect()
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,9 +122,14 @@ mod tests {
|
||||
fn knn_find() {
|
||||
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
|
||||
|
||||
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
|
||||
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {}).unwrap();
|
||||
|
||||
let mut found_idxs1: Vec<usize> = algorithm1.find(&2, 3).iter().map(|v| v.0).collect();
|
||||
let mut found_idxs1: Vec<usize> = algorithm1
|
||||
.find(&2, 3)
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|v| v.0)
|
||||
.collect();
|
||||
found_idxs1.sort();
|
||||
|
||||
assert_eq!(vec!(0, 1, 2), found_idxs1);
|
||||
@@ -135,10 +142,11 @@ mod tests {
|
||||
vec![5., 5.],
|
||||
];
|
||||
|
||||
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
|
||||
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian()).unwrap();
|
||||
|
||||
let mut found_idxs2: Vec<usize> = algorithm2
|
||||
.find(&vec![3., 3.], 3)
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|v| v.0)
|
||||
.collect();
|
||||
|
||||
Reference in New Issue
Block a user