Commit e03410e8 authored by Julien Jerphanion's avatar Julien Jerphanion

Add a test

parent 06debbda
import numpy as np
import kdtree
if __name__ == "__main__":
np.random.seed(1)
n, d = 1000000, 2
golden_ratio = (1 + 5 ** 0.5) / 2
X = np.zeros((n, d))
for i in range(n):
X[i, 0] = (i / golden_ratio) % 1
X[i, 1] = i / n
tree = kdtree.KDTree(X, depth=2)
query_points = np.random.rand(100000, 2)
closests = np.zeros((query_points.shape[0]), dtype=np.int32)
tree.get_closest(query_points, closests)
\ No newline at end of file
import numpy as np
import pytest
import kdtree
from sklearn.neighbors import KDTree
@pytest.mark.parametrize("n", [10, 100, 1000, 10000])
@pytest.mark.parametrize("d", [10, 100])
@pytest.mark.parametrize("leaf_size", [256, 1024])
def test_against_sklearn(n, d, leaf_size):
np.random.seed(1)
X = np.random.rand(n, d)
query_points = np.random.rand(n, d)
tree = kdtree.KDTree(X, leaf_size=256)
skl_tree = KDTree(X, leaf_size=256)
closests = np.zeros((n), dtype=np.int32)
tree.get_closest(query_points, closests)
skl_closests = skl_tree.query(query_points, return_distance=False)
# The back tracking part of the algorithm is not yet implemented
# hence, we test for a almost equality
assert (np.ndarray.flatten(closests) == np.ndarray.flatten(skl_closests)).mean() > 0.9
\ No newline at end of file
......@@ -3,4 +3,5 @@ jupyter==1.0.0
line-profiler==3.1.0
matplotlib==3.4.1
numpy==1.20.2
pytest
-e git+https://lab.nexedi.com/nexedi/cython.git@fd3a224472d75f7c6107828c1e4b9587f3990a46#egg=Cython
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment