Commit 1dfa85af authored by Julien Jerphanion's avatar Julien Jerphanion
parent 2c760866
......@@ -129,10 +129,10 @@ cdef cypclass Node activable:
leaf_size, n_dims, next_dim,
mid, end, counter)
void get_closest(self,
D_t * query_points,
I_t i,
active Container container
void query(self,
D_t * query_points,
I_t i,
active NeighborsHeap heap.
):
cdef:
I_t j, k, closest = -1
......@@ -150,6 +150,8 @@ cdef cypclass Node activable:
)
dist += tmp * tmp
heap.push
if dist < min_distance:
closest = j
min_distance = dist
......@@ -158,9 +160,9 @@ cdef cypclass Node activable:
return
if query_points[ i * self._n_dims + self._dim] < self._point[self._dim]:
self._left.get_closest(NULL, query_points, i, container)
self._left.query(NULL, query_points, i, container)
else:
self._right.get_closest(NULL, query_points, i, container)
self._right.query(NULL, query_points, i, container)
cdef cypclass Counter activable:
......@@ -211,6 +213,173 @@ cdef cypclass Container activable:
int get_n_updates(self):
return self._n_updates
cdef inline void dual_swap(D_t* darr, I_t* iarr, I_t i1, I_t i2) nogil:
"""swap the values at inex i1 and i2 of both darr and iarr"""
cdef D_t dtmp = darr[i1]
darr[i1] = darr[i2]
darr[i2] = dtmp
cdef I_t itmp = iarr[i1]
iarr[i1] = iarr[i2]
iarr[i2] = itmp
cdef cypclass NeighborsHeap activable:
"""A max-heap structure to keep track of distances/indices of neighbors
This implements an efficient pre-allocated set of fixed-size heaps
for chasing neighbors, holding both an index and a distance.
When any row of the heap is full, adding an additional point will push
the furthest point off the heap.
Taken and adapted from:
https://github.com/scikit-learn/scikit-learn/blob/e4bb9fa86b0df873ad750b6d59090843d9d23d50/sklearn/neighbors/_binary_tree.pxi#L513
Parameters
----------
n_pts : int
the number of heaps to use
n_nbrs : int
the size of each heap.
"""
D_t *_distances
I_t *_indices
I_t _n_pts
I_t _n_nbrs
bint _sorted
__init__(self, I_t n_pts, I_t n_nbrs):
self._n_pts = n_pts
self._n_nbrs = n_nbrs
self._active_result_class = WaitResult.construct
self._active_queue_class = consume BatchMailBox(scheduler)
self._distances = <D_t *> malloc(n_pts * n_nbrs * sizeof(D_t))
self._indices = <I_t *> malloc(n_pts * n_nbrs * sizeof(I_t))
void __dealloc__(self):
free(self._indices)
free(self._distances)
void push(self, I_t row, D_t val, I_t i_val):
"""push (val, i_val) into the given row"""
cdef I_t i, left_child_idx, right_child_idx, swap_idx
# check if val should be in heap
if val > self._distances[0]:
return
# insert val at position zero
self._distances[0] = val
self._indices[0] = i_val
# descend the heap, swapping values until the max heap criterion is met
i = 0
while True:
left_child_idx = 2 * i + 1
right_child_idx = left_child_idx + 1
if left_child_idx >= self._n_nbrs:
break
elif right_child_idx >= self._n_nbrs:
if self._distances[left_child_idx] > val:
swap_idx = left_child_idx
else:
break
elif self._distances[left_child_idx] >= self._distances[right_child_idx]:
if val < self._distances[left_child_idx]:
swap_idx = left_child_idx
else:
break
else:
if val < self._distances[right_child_idx]:
swap_idx = right_child_idx
else:
break
self._distances[i] = self._distances[swap_idx]
self._indices[i] = self._indices[swap_idx]
i = swap_idx
self._distances[i] = val
self._indices[i] = i_val
void _sort(self):
"""simultaneously sort the distances and indices"""
cdef I_t row
for row in range(self._n_pts):
self._simultaneous_sort(
self._distances + row * self._n_nbrs,
self._indices + row * self._n_nbrs,
self._n_nbrs)
void _simultaneous_sort(
self,
D_t* dist,
I_t* idx,
I_t size):
"""
Perform a recursive quicksort on the dist array, simultaneously
performing the same swaps on the idx array. The equivalent in
numpy (though quite a bit slower) is
def simultaneous_sort(dist, idx):
i = np.argsort(dist)
return dist[i], idx[i]
"""
cdef I_t pivot_idx, i, store_idx
cdef D_t pivot_val
# in the small-array case, do things efficiently
if size <= 1:
pass
elif size == 2:
if dist[0] > dist[1]:
dual_swap(dist, idx, 0, 1)
elif size == 3:
if dist[0] > dist[1]:
dual_swap(dist, idx, 0, 1)
if dist[1] > dist[2]:
dual_swap(dist, idx, 1, 2)
if dist[0] > dist[1]:
dual_swap(dist, idx, 0, 1)
else:
# Determine the pivot using the median-of-three rule.
# The smallest of the three is moved to the beginning of the array,
# the middle (the pivot value) is moved to the end, and the largest
# is moved to the pivot index.
pivot_idx = size // 2
if dist[0] > dist[size - 1]:
dual_swap(dist, idx, 0, size - 1)
if dist[size - 1] > dist[pivot_idx]:
dual_swap(dist, idx, size - 1, pivot_idx)
if dist[0] > dist[size - 1]:
dual_swap(dist, idx, 0, size - 1)
pivot_val = dist[size - 1]
# partition indices about pivot. At the end of this operation,
# pivot_idx will contain the pivot value, everything to the left
# will be smaller, and everything to the right will be larger.
store_idx = 0
for i in range(size - 1):
if dist[i] < pivot_val:
dual_swap(dist, idx, i, store_idx)
store_idx += 1
dual_swap(dist, idx, store_idx, size - 1)
pivot_idx = store_idx
# recursively sort each side of the pivot
if pivot_idx > 1:
self._simultaneous_sort(dist, idx, pivot_idx)
if pivot_idx + 2 < size:
self._simultaneous_sort(dist + pivot_idx + 1,
idx + pivot_idx + 1,
size - pivot_idx - 1)
cdef cypclass KDTree:
"""A KDTree based on asynchronous and parallel computations.
......@@ -289,7 +458,7 @@ cdef cypclass KDTree:
free(self._indices_ptr)
void get_closest(self,
void query(self,
np.ndarray query_points, # IN
np.ndarray closests, # OUT
):
......@@ -297,21 +466,20 @@ cdef cypclass KDTree:
I_t completed_queries = 0
I_t i
I_t n_query = query_points.shape[0]
I_t n_neighbors = query_points.shape[1]
active Container closests_container
global scheduler
closests_container = consume Container(<I_t *> closests.data, n_query)
for i in range(n_query):
self._root.get_closest(NULL,
self._root.query(NULL,
<D_t *> query_points.data,
i, closests_container)
while(completed_queries < n_query):
completed_queries = closests_container.get_n_updates(NULL).getIntResult()
cdef public int main() nogil:
# Entry point for the compiled binary file
printf("empty public int main() nogil:")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment