Commit 4689a19c authored by Julien Jerphanion's avatar Julien Jerphanion

Use Heap to query multiples neighbours

parent 1dfa85af
......@@ -9,6 +9,8 @@ from runtime.runtime cimport BatchMailBox, NullResult, Scheduler, WaitResult
from libc.stdio cimport printf
from libc.stdlib cimport malloc, free
from cython.parallel import prange
## Types declaration
ctypedef int I_t
ctypedef double D_t
......@@ -132,7 +134,7 @@ cdef cypclass Node activable:
void query(self,
D_t * query_points,
I_t i,
active NeighborsHeap heap.
active NeighborsHeap heaps,
):
cdef:
I_t j, k, closest = -1
......@@ -150,19 +152,17 @@ cdef cypclass Node activable:
)
dist += tmp * tmp
heap.push
heaps.push(NULL, i, dist, j)
if dist < min_distance:
closest = j
min_distance = dist
container.update(NULL, i, closest)
return
if query_points[ i * self._n_dims + self._dim] < self._point[self._dim]:
self._left.query(NULL, query_points, i, container)
else:
self._right.query(NULL, query_points, i, container)
# if query_points[ i * self._n_dims + self._dim] < self._point[self._dim]:
self._left.query(NULL, query_points, i, heaps)
# else:
self._right.query(NULL, query_points, i, heaps)
cdef cypclass Counter activable:
......@@ -224,6 +224,88 @@ cdef inline void dual_swap(D_t* darr, I_t* iarr, I_t i1, I_t i2) nogil:
iarr[i2] = itmp
cdef void _simultaneous_sort(
D_t* dist,
I_t* idx,
I_t size
) nogil:
"""
Perform a recursive quicksort on the dist array, simultaneously
performing the same swaps on the idx array. The equivalent in
numpy (though quite a bit slower) is
def simultaneous_sort(dist, idx):
i = np.argsort(dist)
return dist[i], idx[i]
"""
cdef I_t pivot_idx, i, store_idx
cdef D_t pivot_val
# in the small-array case, do things efficiently
if size <= 1:
pass
elif size == 2:
if dist[0] > dist[1]:
dual_swap(dist, idx, 0, 1)
elif size == 3:
if dist[0] > dist[1]:
dual_swap(dist, idx, 0, 1)
if dist[1] > dist[2]:
dual_swap(dist, idx, 1, 2)
if dist[0] > dist[1]:
dual_swap(dist, idx, 0, 1)
else:
# Determine the pivot using the median-of-three rule.
# The smallest of the three is moved to the beginning of the array,
# the middle (the pivot value) is moved to the end, and the largest
# is moved to the pivot index.
pivot_idx = size // 2
if dist[0] > dist[size - 1]:
dual_swap(dist, idx, 0, size - 1)
if dist[size - 1] > dist[pivot_idx]:
dual_swap(dist, idx, size - 1, pivot_idx)
if dist[0] > dist[size - 1]:
dual_swap(dist, idx, 0, size - 1)
pivot_val = dist[size - 1]
# partition indices about pivot. At the end of this operation,
# pivot_idx will contain the pivot value, everything to the left
# will be smaller, and everything to the right will be larger.
store_idx = 0
for i in range(size - 1):
if dist[i] < pivot_val:
dual_swap(dist, idx, i, store_idx)
store_idx += 1
dual_swap(dist, idx, store_idx, size - 1)
pivot_idx = store_idx
# recursively sort each side of the pivot
if pivot_idx > 1:
_simultaneous_sort(dist, idx, pivot_idx)
if pivot_idx + 2 < size:
_simultaneous_sort(dist + pivot_idx + 1,
idx + pivot_idx + 1,
size - pivot_idx - 1)
cdef void _sort(
D_t* dist,
I_t* idx,
I_t n_rows,
I_t size,
) nogil:
"""simultaneously sort the distances and indices"""
cdef I_t row
for row in prange(n_rows,
nogil=True,
schedule="static",
num_threads=4):
_simultaneous_sort(
dist + row * size,
idx + row * size,
size)
cdef cypclass NeighborsHeap activable:
"""A max-heap structure to keep track of distances/indices of neighbors
......@@ -247,25 +329,28 @@ cdef cypclass NeighborsHeap activable:
I_t _n_pts
I_t _n_nbrs
I_t _n_pushes
bint _sorted
__init__(self, I_t n_pts, I_t n_nbrs):
__init__(self, I_t * indices, I_t n_pts, I_t n_nbrs):
self._n_pts = n_pts
self._n_nbrs = n_nbrs
self._active_result_class = WaitResult.construct
self._active_queue_class = consume BatchMailBox(scheduler)
self._distances = <D_t *> malloc(n_pts * n_nbrs * sizeof(D_t))
self._indices = <I_t *> malloc(n_pts * n_nbrs * sizeof(I_t))
self._indices = indices
self._n_pushes = 0
self._sorted = False
void __dealloc__(self):
free(self._indices)
free(self._distances)
void push(self, I_t row, D_t val, I_t i_val):
"""push (val, i_val) into the given row"""
cdef I_t i, left_child_idx, right_child_idx, swap_idx
self._n_pushes += 1
# check if val should be in heap
if val > self._distances[0]:
return
......@@ -307,78 +392,17 @@ cdef cypclass NeighborsHeap activable:
self._indices[i] = i_val
void _sort(self):
"""simultaneously sort the distances and indices"""
cdef I_t row
for row in range(self._n_pts):
self._simultaneous_sort(
self._distances + row * self._n_nbrs,
self._indices + row * self._n_nbrs,
self._n_nbrs)
void sort(self):
_sort(self._distances, self._indices,
self._n_pts, self._n_nbrs)
self._sorted = False
void _simultaneous_sort(
self,
D_t* dist,
I_t* idx,
I_t size):
"""
Perform a recursive quicksort on the dist array, simultaneously
performing the same swaps on the idx array. The equivalent in
numpy (though quite a bit slower) is
def simultaneous_sort(dist, idx):
i = np.argsort(dist)
return dist[i], idx[i]
"""
cdef I_t pivot_idx, i, store_idx
cdef D_t pivot_val
# in the small-array case, do things efficiently
if size <= 1:
pass
elif size == 2:
if dist[0] > dist[1]:
dual_swap(dist, idx, 0, 1)
elif size == 3:
if dist[0] > dist[1]:
dual_swap(dist, idx, 0, 1)
if dist[1] > dist[2]:
dual_swap(dist, idx, 1, 2)
if dist[0] > dist[1]:
dual_swap(dist, idx, 0, 1)
else:
# Determine the pivot using the median-of-three rule.
# The smallest of the three is moved to the beginning of the array,
# the middle (the pivot value) is moved to the end, and the largest
# is moved to the pivot index.
pivot_idx = size // 2
if dist[0] > dist[size - 1]:
dual_swap(dist, idx, 0, size - 1)
if dist[size - 1] > dist[pivot_idx]:
dual_swap(dist, idx, size - 1, pivot_idx)
if dist[0] > dist[size - 1]:
dual_swap(dist, idx, 0, size - 1)
pivot_val = dist[size - 1]
int n_pushes(self):
return self._n_pushes
# partition indices about pivot. At the end of this operation,
# pivot_idx will contain the pivot value, everything to the left
# will be smaller, and everything to the right will be larger.
store_idx = 0
for i in range(size - 1):
if dist[i] < pivot_val:
dual_swap(dist, idx, i, store_idx)
store_idx += 1
dual_swap(dist, idx, store_idx, size - 1)
pivot_idx = store_idx
int is_sorted(self):
return 1 if self._sorted else 0
# recursively sort each side of the pivot
if pivot_idx > 1:
self._simultaneous_sort(dist, idx, pivot_idx)
if pivot_idx + 2 < size:
self._simultaneous_sort(dist + pivot_idx + 1,
idx + pivot_idx + 1,
size - pivot_idx - 1)
cdef cypclass KDTree:
......@@ -466,18 +490,25 @@ cdef cypclass KDTree:
I_t completed_queries = 0
I_t i
I_t n_query = query_points.shape[0]
I_t n_neighbors = query_points.shape[1]
active Container closests_container
I_t n_neighbors = closests.shape[1]
I_t total_n_pushes = n_query * self._n
active NeighborsHeap heaps
closests_container = consume Container(<I_t *> closests.data, n_query)
heaps = consume NeighborsHeap(<I_t *> closests.data,
n_query,
n_neighbors)
for i in range(n_query):
self._root.query(NULL,
<D_t *> query_points.data,
i, closests_container)
self._root.query(NULL, <D_t *> query_points.data, i, heaps)
while(completed_queries < total_n_pushes):
completed_queries = heaps.n_pushes(NULL).getIntResult()
# heaps.sort(NULL)
while(completed_queries < n_query):
completed_queries = closests_container.get_n_updates(NULL).getIntResult()
# while not(heaps.is_sorted(NULL).getIntResult()):
# pass
cdef public int main() nogil:
......
import numpy as np
import pytest
import kdtree
from sklearn.neighbors import KDTree
if __name__ == '__main__':
n = 1000
n_query = 100
d = 10
k = 10
np.random.seed(1)
X = np.random.rand(n, d)
query_points = np.random.rand(n_query, d)
tree = kdtree.KDTree(X, leaf_size=256)
closests = np.zeros((n_query, k), dtype=np.int32)
# There's currently a deadlock here
tree.query(query_points, closests)
\ No newline at end of file
......@@ -14,8 +14,8 @@ def test_against_sklearn(n, d, leaf_size):
tree = kdtree.KDTree(X, leaf_size=256)
skl_tree = KDTree(X, leaf_size=256)
closests = np.zeros((n), dtype=np.int32)
tree.get_closest(query_points, closests)
closests = np.zeros((n, 2), dtype=np.int32)
tree.query(query_points, closests)
skl_closests = skl_tree.query(query_points, return_distance=False)
# The back tracking part of the algorithm is not yet implemented
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment