Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
cython_plus_experiments
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Julien Jerphanion
cython_plus_experiments
Commits
e8d90ccd
Commit
e8d90ccd
authored
Jun 11, 2021
by
Julien Jerphanion
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactor
Simplify logic. Add comments to explain motives.
parent
280c0904
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
197 additions
and
152 deletions
+197
-152
kdtree/kdtree.pyx
kdtree/kdtree.pyx
+189
-146
kdtree/query_poc.py
kdtree/query_poc.py
+3
-2
kdtree/tests/test_conf.py
kdtree/tests/test_conf.py
+5
-4
No files found.
kdtree/kdtree.pyx
View file @
e8d90ccd
...
@@ -18,6 +18,8 @@ ctypedef double D_t
...
@@ -18,6 +18,8 @@ ctypedef double D_t
cdef
lock
Scheduler
scheduler
cdef
lock
Scheduler
scheduler
cdef
D_t
INF
=
1e38
cdef
D_t
INF
=
1e38
# NOTE: The following extern definition is used to interface
# std::nth_element, a robust partitioning algorithm, in Cython
cdef
extern
from
*
:
cdef
extern
from
*
:
"""
"""
#include <algorithm>
#include <algorithm>
...
@@ -66,109 +68,13 @@ cdef extern from *:
...
@@ -66,109 +68,13 @@ cdef extern from *:
I
n_features
I
n_features
)
nogil
except
+
)
nogil
except
+
cdef
cypclass
Node
activable
:
"""A KDTree Node"""
D_t
*
_data_ptr
I_t
*
_indices_ptr
D_t
*
_point
I_t
_n_dims
I_t
_dim
I_t
_start
I_t
_end
bint
_is_leaf
active
Node
_left
active
Node
_right
__init__
(
self
):
self
.
_active_result_class
=
WaitResult
.
construct
self
.
_active_queue_class
=
consume
BatchMailBox
(
scheduler
)
self
.
_left
=
NULL
self
.
_right
=
NULL
self
.
_is_leaf
=
False
void
build_node
(
self
,
D_t
*
data_ptr
,
I_t
*
indices_ptr
,
I_t
leaf_size
,
I_t
n_dims
,
I_t
dim
,
I_t
start
,
I_t
end
,
active
Counter
counter
,
):
cdef
I_t
i
cdef
I_t
next_dim
=
(
dim
+
1
)
%
n_dims
cdef
I_t
mid
=
(
start
+
end
)
//
2
self
.
_data_ptr
=
data_ptr
self
.
_indices_ptr
=
indices_ptr
self
.
_dim
=
dim
self
.
_n_dims
=
n_dims
self
.
_start
=
start
self
.
_end
=
end
if
(
end
-
start
<=
leaf_size
):
self
.
_is_leaf
=
True
counter
.
add
(
NULL
,
end
-
start
)
return
partition_node_indices
(
data_ptr
,
indices_ptr
,
start
,
mid
,
end
,
dim
,
n_dims
)
self
.
_point
=
data_ptr
+
mid
self
.
_left
=
consume
Node
()
self
.
_right
=
consume
Node
()
self
.
_left
.
build_node
(
NULL
,
data_ptr
,
indices_ptr
,
leaf_size
,
n_dims
,
next_dim
,
start
,
mid
,
counter
)
self
.
_right
.
build_node
(
NULL
,
data_ptr
,
indices_ptr
,
leaf_size
,
n_dims
,
next_dim
,
mid
,
end
,
counter
)
void
query
(
self
,
D_t
*
query_points
,
I_t
i
,
active
NeighborsHeap
heaps
,
):
cdef
:
I_t
j
,
k
,
closest
=
-
1
D_t
dist
=
INF
D_t
tmp
D_t
min_distance
=
INF
if
self
.
_is_leaf
:
for
j
in
range
(
self
.
_start
,
self
.
_end
):
dist
=
0
for
k
in
range
(
self
.
_n_dims
):
tmp
=
(
query_points
[
i
*
self
.
_n_dims
+
k
]
-
self
.
_data_ptr
[
self
.
_indices_ptr
[
j
]
*
self
.
_n_dims
+
k
]
)
dist
+=
tmp
*
tmp
heaps
.
push
(
NULL
,
i
,
dist
,
j
)
if
dist
<
min_distance
:
closest
=
j
min_distance
=
dist
return
# if query_points[ i * self._n_dims + self._dim] < self._point[self._dim]:
self
.
_left
.
query
(
NULL
,
query_points
,
i
,
heaps
)
# else:
self
.
_right
.
query
(
NULL
,
query_points
,
i
,
heaps
)
cdef
cypclass
Counter
activable
:
cdef
cypclass
Counter
activable
:
""" A simple Counter.
""" A simple Counter.
Useful for synchronisation, as it can be used as a barrier.
After triggering the actors' logic, the caller can use it to
synchronise: the caller waits for the value of the Counter to
reach a given one before moving on.
"""
"""
I_t
_n
I_t
_n
...
@@ -191,7 +97,11 @@ cdef cypclass Counter activable:
...
@@ -191,7 +97,11 @@ cdef cypclass Counter activable:
cdef
cypclass
Container
activable
:
cdef
cypclass
Container
activable
:
""" A simple wrapper of an array.
""" A simple wrapper of an array.
Useful for synchronisation, as it can be used as a barrier.
The wrapped array is passed by the initial caller, which then
triggers the actors' logic that modifies it.
The initial caller can wait for a specific number of updates
before proceeding.
"""
"""
I_t
*
_array
I_t
*
_array
...
@@ -210,9 +120,10 @@ cdef cypclass Container activable:
...
@@ -210,9 +120,10 @@ cdef cypclass Container activable:
self
.
_n_updates
+=
1
self
.
_n_updates
+=
1
self
.
_array
[
idx
]
=
value
self
.
_array
[
idx
]
=
value
int
get_
n_updates
(
self
):
int
n_updates
(
self
):
return
self
.
_n_updates
return
self
.
_n_updates
cdef
inline
void
dual_swap
(
D_t
*
darr
,
I_t
*
iarr
,
I_t
i1
,
I_t
i2
)
nogil
:
cdef
inline
void
dual_swap
(
D_t
*
darr
,
I_t
*
iarr
,
I_t
i1
,
I_t
i2
)
nogil
:
"""swap the values at index i1 and i2 of both darr and iarr"""
"""swap the values at index i1 and i2 of both darr and iarr"""
cdef
D_t
dtmp
=
darr
[
i1
]
cdef
D_t
dtmp
=
darr
[
i1
]
...
@@ -288,24 +199,6 @@ cdef void _simultaneous_sort(
...
@@ -288,24 +199,6 @@ cdef void _simultaneous_sort(
size
-
pivot_idx
-
1
)
size
-
pivot_idx
-
1
)
cdef
void
_sort
(
D_t
*
dist
,
I_t
*
idx
,
I_t
n_rows
,
I_t
size
,
)
nogil
:
"""simultaneously sort the distances and indices"""
cdef
I_t
row
for
row
in
prange
(
n_rows
,
nogil
=
True
,
schedule
=
"static"
,
num_threads
=
4
):
_simultaneous_sort
(
dist
+
row
*
size
,
idx
+
row
*
size
,
size
)
cdef
cypclass
NeighborsHeap
activable
:
cdef
cypclass
NeighborsHeap
activable
:
"""A max-heap structure to keep track of distances/indices of neighbors
"""A max-heap structure to keep track of distances/indices of neighbors
...
@@ -317,6 +210,12 @@ cdef cypclass NeighborsHeap activable:
...
@@ -317,6 +210,12 @@ cdef cypclass NeighborsHeap activable:
Taken and adapted from:
Taken and adapted from:
https://github.com/scikit-learn/scikit-learn/blob/e4bb9fa86b0df873ad750b6d59090843d9d23d50/sklearn/neighbors/_binary_tree.pxi#L513
https://github.com/scikit-learn/scikit-learn/blob/e4bb9fa86b0df873ad750b6d59090843d9d23d50/sklearn/neighbors/_binary_tree.pxi#L513
The initial caller is responsible for providing the array of indices,
which will be modified in place by the actors' logic.
n_pushes and is_sorted can be used by the initial caller to know
when to proceed.
Parameters
Parameters
----------
----------
n_pts : int
n_pts : int
...
@@ -333,6 +232,7 @@ cdef cypclass NeighborsHeap activable:
...
@@ -333,6 +232,7 @@ cdef cypclass NeighborsHeap activable:
bint
_sorted
bint
_sorted
__init__
(
self
,
I_t
*
indices
,
I_t
n_pts
,
I_t
n_nbrs
):
__init__
(
self
,
I_t
*
indices
,
I_t
n_pts
,
I_t
n_nbrs
):
cdef
I_t
i
self
.
_n_pts
=
n_pts
self
.
_n_pts
=
n_pts
self
.
_n_nbrs
=
n_nbrs
self
.
_n_nbrs
=
n_nbrs
self
.
_active_result_class
=
WaitResult
.
construct
self
.
_active_result_class
=
WaitResult
.
construct
...
@@ -342,22 +242,31 @@ cdef cypclass NeighborsHeap activable:
...
@@ -342,22 +242,31 @@ cdef cypclass NeighborsHeap activable:
self
.
_n_pushes
=
0
self
.
_n_pushes
=
0
self
.
_sorted
=
False
self
.
_sorted
=
False
# We can't use memset here
for
i
in
range
(
n_pts
*
n_nbrs
):
self
.
_distances
[
i
]
=
INF
void
__dealloc__
(
self
):
void
__dealloc__
(
self
):
free
(
self
.
_distances
)
free
(
self
.
_distances
)
void
push
(
self
,
I_t
row
,
D_t
val
,
I_t
i_val
):
void
push
(
self
,
I_t
row
,
D_t
val
,
I_t
i_val
):
"""push (val, i_val) into the given row"""
"""push (val, i_val) into the given row"""
cdef
I_t
i
,
left_child_idx
,
right_child_idx
,
swap_idx
cdef
:
I_t
i
,
left_child_idx
,
right_child_idx
,
swap_idx
# Getting the heap to use
I_t
*
indices
=
self
.
_indices
+
row
*
self
.
_n_nbrs
D_t
*
distances
=
self
.
_distances
+
row
*
self
.
_n_nbrs
self
.
_n_pushes
+=
1
self
.
_n_pushes
+=
1
# check if val should be in heap
# check if val should be in heap
if
val
>
self
.
_
distances
[
0
]:
if
val
>
distances
[
0
]:
return
return
# insert val at position zero
# insert val at position zero
self
.
_
distances
[
0
]
=
val
distances
[
0
]
=
val
self
.
_
indices
[
0
]
=
i_val
indices
[
0
]
=
i_val
# descend the heap, swapping values until the max heap criterion is met
# descend the heap, swapping values until the max heap criterion is met
i
=
0
i
=
0
...
@@ -368,42 +277,176 @@ cdef cypclass NeighborsHeap activable:
...
@@ -368,42 +277,176 @@ cdef cypclass NeighborsHeap activable:
if
left_child_idx
>=
self
.
_n_nbrs
:
if
left_child_idx
>=
self
.
_n_nbrs
:
break
break
elif
right_child_idx
>=
self
.
_n_nbrs
:
elif
right_child_idx
>=
self
.
_n_nbrs
:
if
self
.
_
distances
[
left_child_idx
]
>
val
:
if
distances
[
left_child_idx
]
>
val
:
swap_idx
=
left_child_idx
swap_idx
=
left_child_idx
else
:
else
:
break
break
elif
self
.
_distances
[
left_child_idx
]
>=
self
.
_
distances
[
right_child_idx
]:
elif
distances
[
left_child_idx
]
>=
distances
[
right_child_idx
]:
if
val
<
self
.
_
distances
[
left_child_idx
]:
if
val
<
distances
[
left_child_idx
]:
swap_idx
=
left_child_idx
swap_idx
=
left_child_idx
else
:
else
:
break
break
else
:
else
:
if
val
<
self
.
_
distances
[
right_child_idx
]:
if
val
<
distances
[
right_child_idx
]:
swap_idx
=
right_child_idx
swap_idx
=
right_child_idx
else
:
else
:
break
break
self
.
_distances
[
i
]
=
self
.
_
distances
[
swap_idx
]
distances
[
i
]
=
distances
[
swap_idx
]
self
.
_indices
[
i
]
=
self
.
_
indices
[
swap_idx
]
indices
[
i
]
=
indices
[
swap_idx
]
i
=
swap_idx
i
=
swap_idx
self
.
_
distances
[
i
]
=
val
distances
[
i
]
=
val
self
.
_
indices
[
i
]
=
i_val
indices
[
i
]
=
i_val
void
sort
(
self
):
void
sort
(
self
):
_sort
(
self
.
_distances
,
self
.
_indices
,
# NOTE: Ideally we could sort results in parallel, but
self
.
_n_pts
,
self
.
_n_nbrs
)
# OpenMP threadpool and this runtime's aren't working
self
.
_sorted
=
False
# nicely together (using prange here would create OpenMP
# threads within workers, causing unexpected behavior.)
cdef
I_t
row
for
row
in
range
(
self
.
_n_pts
):
# We use a function here to be able to recurse
_simultaneous_sort
(
self
.
_distances
+
row
*
self
.
_n_nbrs
,
self
.
_indices
+
row
*
self
.
_n_nbrs
,
self
.
_n_nbrs
)
self
.
_sorted
=
True
int
n_pushes
(
self
):
int
n_pushes
(
self
):
return
self
.
_n_pushes
return
self
.
_n_pushes
int
is_sorted
(
self
):
int
is_sorted
(
self
):
# TODO: is there support for returning a bool via
# promises? For now, we return an int so that
# getIntResult can be used
return
1
if
self
.
_sorted
else
0
return
1
if
self
.
_sorted
else
0
cdef
cypclass
Node
activable
:
"""A KDTree Node
Nodes delegate tasks to their child Nodes.
Some Nodes are set as leaves when they are associated
with ``leaf_size`` or fewer points.
Leaves are terminal Nodes and do not have children.
"""
# Reference to the head of the allocated arrays;
# the data is not modified via _data_ptr
D_t
*
_data_ptr
I_t
*
_indices_ptr
# The point the Node split on
D_t
*
_point
I_t
_n_dims
I_t
_dim
# Portion of _indices covered by the Node is:
# _indices_ptr[_start:_end]
I_t
_start
I_t
_end
bint
_is_leaf
active
Node
_left
active
Node
_right
__init__
(
self
):
self
.
_active_result_class
=
WaitResult
.
construct
self
.
_active_queue_class
=
consume
BatchMailBox
(
scheduler
)
self
.
_left
=
NULL
self
.
_right
=
NULL
self
.
_is_leaf
=
False
# We use this to allow using actors for initialisation
# because __init__ can't be reified.
void
build_node
(
self
,
D_t
*
data_ptr
,
I_t
*
indices_ptr
,
I_t
leaf_size
,
I_t
n_dims
,
I_t
dim
,
I_t
start
,
I_t
end
,
active
Counter
counter
,
):
cdef
I_t
i
cdef
I_t
next_dim
=
(
dim
+
1
)
%
n_dims
cdef
I_t
mid
=
(
start
+
end
)
//
2
self
.
_data_ptr
=
data_ptr
self
.
_indices_ptr
=
indices_ptr
self
.
_dim
=
dim
self
.
_n_dims
=
n_dims
self
.
_start
=
start
self
.
_end
=
end
if
(
end
-
start
<=
leaf_size
):
self
.
_is_leaf
=
True
# Adding to the global counter the number
# of samples the leaf is responsible for
counter
.
add
(
NULL
,
end
-
start
)
return
# We partition the samples in two nodes on a given dimension,
# with the middle point as a pivot
partition_node_indices
(
data_ptr
,
indices_ptr
,
start
,
mid
,
end
,
dim
,
n_dims
)
self
.
_point
=
data_ptr
+
mid
self
.
_left
=
consume
Node
()
self
.
_right
=
consume
Node
()
# Recursing on both partitions
self
.
_left
.
build_node
(
NULL
,
data_ptr
,
indices_ptr
,
leaf_size
,
n_dims
,
next_dim
,
start
,
mid
,
counter
)
self
.
_right
.
build_node
(
NULL
,
data_ptr
,
indices_ptr
,
leaf_size
,
n_dims
,
next_dim
,
mid
,
end
,
counter
)
void
query
(
self
,
D_t
*
query_points
,
I_t
i
,
active
NeighborsHeap
heaps
,
):
cdef
:
I_t
j
,
k
D_t
dist
D_t
tmp
if
self
.
_is_leaf
:
# Computing all the squared Euclidean distances here
for
j
in
range
(
self
.
_start
,
self
.
_end
):
dist
=
0
for
k
in
range
(
self
.
_n_dims
):
tmp
=
(
query_points
[
i
*
self
.
_n_dims
+
k
]
-
self
.
_data_ptr
[
self
.
_indices_ptr
[
j
]
*
self
.
_n_dims
+
k
]
)
dist
+=
tmp
*
tmp
# The heap is doing the smart work of keeping
# the closest points for each query point i
heaps
.
push
(
NULL
,
i
,
dist
,
self
.
_indices_ptr
[
j
])
return
# TODO: one can implement a pruning strategy here
self
.
_left
.
query
(
NULL
,
query_points
,
i
,
heaps
)
self
.
_right
.
query
(
NULL
,
query_points
,
i
,
heaps
)
cdef
cypclass
KDTree
:
cdef
cypclass
KDTree
:
"""A KDTree based on asynchronous and parallel computations.
"""A KDTree based on asynchronous and parallel computations.
...
@@ -435,6 +478,7 @@ cdef cypclass KDTree:
...
@@ -435,6 +478,7 @@ cdef cypclass KDTree:
cdef
I_t
i
cdef
I_t
i
cdef
I_t
n
=
X
.
shape
[
0
]
cdef
I_t
n
=
X
.
shape
[
0
]
cdef
I_t
d
=
X
.
shape
[
1
]
cdef
I_t
d
=
X
.
shape
[
1
]
cdef
I_t
initialised
=
0
self
.
_n
=
n
self
.
_n
=
n
self
.
_d
=
d
self
.
_d
=
d
...
@@ -445,15 +489,11 @@ cdef cypclass KDTree:
...
@@ -445,15 +489,11 @@ cdef cypclass KDTree:
for
i
in
range
(
n
):
for
i
in
range
(
n
):
self
.
_indices_ptr
[
i
]
=
i
self
.
_indices_ptr
[
i
]
=
i
# Recursively building the tree here
global
scheduler
global
scheduler
scheduler
=
Scheduler
()
scheduler
=
Scheduler
()
self
.
_recursive_build
()
void
_recursive_build
(
self
):
cdef
I_t
initialised
cdef
active
Counter
counter
=
consume
Counter
()
cdef
active
Counter
counter
=
consume
Counter
()
self
.
_root
=
consume
Node
()
self
.
_root
=
consume
Node
()
if
self
.
_root
is
NULL
:
if
self
.
_root
is
NULL
:
printf
(
"Error consuming node
\
n
"
)
printf
(
"Error consuming node
\
n
"
)
...
@@ -462,6 +502,9 @@ cdef cypclass KDTree:
...
@@ -462,6 +502,9 @@ cdef cypclass KDTree:
# are reified. When using those reified methods
# are reified. When using those reified methods
# a new argument is prepended for a predicate,
# a new argument is prepended for a predicate,
# which we aren't using here, hence the extra NULL.
# which we aren't using here, hence the extra NULL.
#
# Also, using this separate method allows using actors,
# because __init__ can't be reified.
self
.
_root
.
build_node
(
NULL
,
self
.
_root
.
build_node
(
NULL
,
self
.
_data_ptr
,
self
.
_data_ptr
,
self
.
_indices_ptr
,
self
.
_indices_ptr
,
...
@@ -469,8 +512,8 @@ cdef cypclass KDTree:
...
@@ -469,8 +512,8 @@ cdef cypclass KDTree:
dim
=
0
,
start
=
0
,
end
=
self
.
_n
,
dim
=
0
,
start
=
0
,
end
=
self
.
_n
,
counter
=
counter
)
counter
=
counter
)
# Waiting for the tree construction to end
initialised
=
counter
.
value
(
NULL
).
getIntResult
()
# Somewhat similar to a thread barrier
while
(
initialised
<
self
.
_n
):
while
(
initialised
<
self
.
_n
):
initialised
=
counter
.
value
(
NULL
).
getIntResult
()
initialised
=
counter
.
value
(
NULL
).
getIntResult
()
...
@@ -505,10 +548,10 @@ cdef cypclass KDTree:
...
@@ -505,10 +548,10 @@ cdef cypclass KDTree:
while
(
completed_queries
<
total_n_pushes
):
while
(
completed_queries
<
total_n_pushes
):
completed_queries
=
heaps
.
n_pushes
(
NULL
).
getIntResult
()
completed_queries
=
heaps
.
n_pushes
(
NULL
).
getIntResult
()
#
heaps.sort(NULL)
heaps
.
sort
(
NULL
)
#
while not(heaps.is_sorted(NULL).getIntResult()):
while
not
(
heaps
.
is_sorted
(
NULL
).
getIntResult
()):
#
pass
pass
cdef
public
int
main
()
nogil
:
cdef
public
int
main
()
nogil
:
...
...
kdtree/query_poc.py
View file @
e8d90ccd
...
@@ -16,6 +16,7 @@ if __name__ == '__main__':
...
@@ -16,6 +16,7 @@ if __name__ == '__main__':
tree
=
kdtree
.
KDTree
(
X
,
leaf_size
=
256
)
tree
=
kdtree
.
KDTree
(
X
,
leaf_size
=
256
)
closests
=
np
.
zeros
((
n_query
,
k
),
dtype
=
np
.
int32
)
closests
=
np
.
zeros
((
n_query
,
k
),
dtype
=
np
.
int32
)
# tree.query(query_points, closests)
# There's currently a deadlock here
# skl_tree = KDTree(X, leaf_size=256)
tree
.
query
(
query_points
,
closests
)
# skl_closests = skl_tree.query(query_points, k=k, return_distance=False).astype(np.int32)
\ No newline at end of file
\ No newline at end of file
kdtree/tests/test_conf.py
View file @
e8d90ccd
...
@@ -5,8 +5,9 @@ from sklearn.neighbors import KDTree
...
@@ -5,8 +5,9 @@ from sklearn.neighbors import KDTree
@
pytest
.
mark
.
parametrize
(
"n"
,
[
10
,
100
,
1000
,
10000
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
10
,
100
,
1000
,
10000
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
10
,
100
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
10
,
100
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
1
,
2
,
5
,
10
])
@
pytest
.
mark
.
parametrize
(
"leaf_size"
,
[
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"leaf_size"
,
[
256
,
1024
])
def
test_against_sklearn
(
n
,
d
,
leaf_size
):
def
test_against_sklearn
(
n
,
d
,
k
,
leaf_size
):
np
.
random
.
seed
(
1
)
np
.
random
.
seed
(
1
)
X
=
np
.
random
.
rand
(
n
,
d
)
X
=
np
.
random
.
rand
(
n
,
d
)
query_points
=
np
.
random
.
rand
(
n
,
d
)
query_points
=
np
.
random
.
rand
(
n
,
d
)
...
@@ -14,10 +15,10 @@ def test_against_sklearn(n, d, leaf_size):
...
@@ -14,10 +15,10 @@ def test_against_sklearn(n, d, leaf_size):
tree
=
kdtree
.
KDTree
(
X
,
leaf_size
=
256
)
tree
=
kdtree
.
KDTree
(
X
,
leaf_size
=
256
)
skl_tree
=
KDTree
(
X
,
leaf_size
=
256
)
skl_tree
=
KDTree
(
X
,
leaf_size
=
256
)
closests
=
np
.
zeros
((
n
,
2
),
dtype
=
np
.
int32
)
closests
=
np
.
zeros
((
n
,
k
),
dtype
=
np
.
int32
)
tree
.
query
(
query_points
,
closests
)
tree
.
query
(
query_points
,
closests
)
skl_closests
=
skl_tree
.
query
(
query_points
,
return_distance
=
False
)
skl_closests
=
skl_tree
.
query
(
query_points
,
k
=
k
,
return_distance
=
False
).
astype
(
np
.
int32
)
# The back tracking part of the algorithm is not yet implemented
# The back tracking part of the algorithm is not yet implemented
# hence, we test for an approximate equality
# hence, we test for an approximate equality
assert
(
np
.
ndarray
.
flatten
(
closests
)
==
np
.
ndarray
.
flatten
(
skl_closests
)).
mean
()
>
0.9
np
.
testing
.
assert_equal
(
closests
,
skl_closests
)
\ No newline at end of file
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment