Commit e8d90ccd authored by Julien Jerphanion's avatar Julien Jerphanion

Refactor

Simplify logic.
Add comments to explain motives.
parent 280c0904
This diff is collapsed.
...@@ -16,6 +16,7 @@ if __name__ == '__main__': ...@@ -16,6 +16,7 @@ if __name__ == '__main__':
tree = kdtree.KDTree(X, leaf_size=256) tree = kdtree.KDTree(X, leaf_size=256)
closests = np.zeros((n_query, k), dtype=np.int32) closests = np.zeros((n_query, k), dtype=np.int32)
# tree.query(query_points, closests)
# There's currently a deadlock here # skl_tree = KDTree(X, leaf_size=256)
tree.query(query_points, closests) # skl_closests = skl_tree.query(query_points, k=k, return_distance=False).astype(np.int32)
\ No newline at end of file \ No newline at end of file
...@@ -5,8 +5,9 @@ from sklearn.neighbors import KDTree ...@@ -5,8 +5,9 @@ from sklearn.neighbors import KDTree
@pytest.mark.parametrize("n", [10, 100, 1000, 10000]) @pytest.mark.parametrize("n", [10, 100, 1000, 10000])
@pytest.mark.parametrize("d", [10, 100]) @pytest.mark.parametrize("d", [10, 100])
@pytest.mark.parametrize("k", [1, 2, 5, 10])
@pytest.mark.parametrize("leaf_size", [256, 1024]) @pytest.mark.parametrize("leaf_size", [256, 1024])
def test_against_sklearn(n, d, leaf_size): def test_against_sklearn(n, d, k, leaf_size):
np.random.seed(1) np.random.seed(1)
X = np.random.rand(n, d) X = np.random.rand(n, d)
query_points = np.random.rand(n, d) query_points = np.random.rand(n, d)
...@@ -14,10 +15,10 @@ def test_against_sklearn(n, d, leaf_size): ...@@ -14,10 +15,10 @@ def test_against_sklearn(n, d, leaf_size):
tree = kdtree.KDTree(X, leaf_size=256) tree = kdtree.KDTree(X, leaf_size=256)
skl_tree = KDTree(X, leaf_size=256) skl_tree = KDTree(X, leaf_size=256)
closests = np.zeros((n, 2), dtype=np.int32) closests = np.zeros((n, k), dtype=np.int32)
tree.query(query_points, closests) tree.query(query_points, closests)
skl_closests = skl_tree.query(query_points, return_distance=False) skl_closests = skl_tree.query(query_points, k=k, return_distance=False).astype(np.int32)
# The back tracking part of the algorithm is not yet implemented # The back tracking part of the algorithm is not yet implemented
# hence, we test for a almost equality # hence, we test for a almost equality
assert (np.ndarray.flatten(closests) == np.ndarray.flatten(skl_closests)).mean() > 0.9 np.testing.assert_equal(closests, skl_closests)
\ No newline at end of file \ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment