Commit e8d90ccd authored by Julien Jerphanion's avatar Julien Jerphanion

Refactor

Simplify logic.
Add comments to explain motives.
parent 280c0904
...@@ -18,6 +18,8 @@ ctypedef double D_t ...@@ -18,6 +18,8 @@ ctypedef double D_t
cdef lock Scheduler scheduler cdef lock Scheduler scheduler
cdef D_t INF = 1e38 cdef D_t INF = 1e38
# NOTE: The following extern definition is used to interface
# std::nth_element, a robust partitioning algorithm, in Cython
cdef extern from *: cdef extern from *:
""" """
#include <algorithm> #include <algorithm>
...@@ -66,109 +68,13 @@ cdef extern from *: ...@@ -66,109 +68,13 @@ cdef extern from *:
I n_features I n_features
) nogil except + ) nogil except +
cdef cypclass Node activable:
"""A KDTree Node"""
D_t * _data_ptr
I_t * _indices_ptr
D_t * _point
I_t _n_dims
I_t _dim
I_t _start
I_t _end
bint _is_leaf
active Node _left
active Node _right
__init__(self):
self._active_result_class = WaitResult.construct
self._active_queue_class = consume BatchMailBox(scheduler)
self._left = NULL
self._right = NULL
self._is_leaf = False
void build_node(
self,
D_t * data_ptr,
I_t * indices_ptr,
I_t leaf_size,
I_t n_dims,
I_t dim,
I_t start,
I_t end,
active Counter counter,
):
cdef I_t i
cdef I_t next_dim = (dim + 1) % n_dims
cdef I_t mid = (start + end) // 2
self._data_ptr = data_ptr
self._indices_ptr = indices_ptr
self._dim = dim
self._n_dims = n_dims
self._start = start
self._end = end
if (end - start <= leaf_size):
self._is_leaf = True
counter.add(NULL, end - start)
return
partition_node_indices(data_ptr, indices_ptr, start, mid, end, dim, n_dims)
self._point = data_ptr + mid
self._left = consume Node()
self._right = consume Node()
self._left.build_node(NULL,
data_ptr, indices_ptr,
leaf_size, n_dims, next_dim,
start, mid, counter)
self._right.build_node(NULL,
data_ptr, indices_ptr,
leaf_size, n_dims, next_dim,
mid, end, counter)
void query(self,
D_t * query_points,
I_t i,
active NeighborsHeap heaps,
):
cdef:
I_t j, k, closest = -1
D_t dist = INF
D_t tmp
D_t min_distance = INF
if self._is_leaf:
for j in range(self._start, self._end):
dist = 0
for k in range(self._n_dims):
tmp = (
query_points[i * self._n_dims + k] -
self._data_ptr[self._indices_ptr[j] * self._n_dims + k]
)
dist += tmp * tmp
heaps.push(NULL, i, dist, j)
if dist < min_distance:
closest = j
min_distance = dist
return
# if query_points[ i * self._n_dims + self._dim] < self._point[self._dim]:
self._left.query(NULL, query_points, i, heaps)
# else:
self._right.query(NULL, query_points, i, heaps)
cdef cypclass Counter activable: cdef cypclass Counter activable:
""" A simple Counter. """ A simple Counter.
Useful for synchronisation, as it can be used as a barrier. This can be useful for synchronisation for the caller after
triggering the actors logic as it wait for the value of
the Coutner to reach a given one before moving on.
""" """
I_t _n I_t _n
...@@ -191,7 +97,11 @@ cdef cypclass Counter activable: ...@@ -191,7 +97,11 @@ cdef cypclass Counter activable:
cdef cypclass Container activable: cdef cypclass Container activable:
""" A simple wrapper of an array. """ A simple wrapper of an array.
Useful for synchronisation, as it can be used as a barrier. The wrapped array is passed by the initial caller which then
trigger the actors logic modifying it.
The initial caller can wait for specific number of update
before proceeding.
""" """
I_t * _array I_t * _array
...@@ -210,9 +120,10 @@ cdef cypclass Container activable: ...@@ -210,9 +120,10 @@ cdef cypclass Container activable:
self._n_updates += 1 self._n_updates += 1
self._array[idx] = value self._array[idx] = value
int get_n_updates(self): int n_updates(self):
return self._n_updates return self._n_updates
cdef inline void dual_swap(D_t* darr, I_t* iarr, I_t i1, I_t i2) nogil: cdef inline void dual_swap(D_t* darr, I_t* iarr, I_t i1, I_t i2) nogil:
"""swap the values at inex i1 and i2 of both darr and iarr""" """swap the values at inex i1 and i2 of both darr and iarr"""
cdef D_t dtmp = darr[i1] cdef D_t dtmp = darr[i1]
...@@ -288,24 +199,6 @@ cdef void _simultaneous_sort( ...@@ -288,24 +199,6 @@ cdef void _simultaneous_sort(
size - pivot_idx - 1) size - pivot_idx - 1)
cdef void _sort(
D_t* dist,
I_t* idx,
I_t n_rows,
I_t size,
) nogil:
"""simultaneously sort the distances and indices"""
cdef I_t row
for row in prange(n_rows,
nogil=True,
schedule="static",
num_threads=4):
_simultaneous_sort(
dist + row * size,
idx + row * size,
size)
cdef cypclass NeighborsHeap activable: cdef cypclass NeighborsHeap activable:
"""A max-heap structure to keep track of distances/indices of neighbors """A max-heap structure to keep track of distances/indices of neighbors
...@@ -317,6 +210,12 @@ cdef cypclass NeighborsHeap activable: ...@@ -317,6 +210,12 @@ cdef cypclass NeighborsHeap activable:
Taken and adapted from: Taken and adapted from:
https://github.com/scikit-learn/scikit-learn/blob/e4bb9fa86b0df873ad750b6d59090843d9d23d50/sklearn/neighbors/_binary_tree.pxi#L513 https://github.com/scikit-learn/scikit-learn/blob/e4bb9fa86b0df873ad750b6d59090843d9d23d50/sklearn/neighbors/_binary_tree.pxi#L513
The initial caller is responsible for providing the array of indices
which will be modified in place by the actors logic.
n_pushes and is_sorted can be used by the initial caller to know
when to pursue.
Parameters Parameters
---------- ----------
n_pts : int n_pts : int
...@@ -333,6 +232,7 @@ cdef cypclass NeighborsHeap activable: ...@@ -333,6 +232,7 @@ cdef cypclass NeighborsHeap activable:
bint _sorted bint _sorted
__init__(self, I_t * indices, I_t n_pts, I_t n_nbrs): __init__(self, I_t * indices, I_t n_pts, I_t n_nbrs):
cdef I_t i
self._n_pts = n_pts self._n_pts = n_pts
self._n_nbrs = n_nbrs self._n_nbrs = n_nbrs
self._active_result_class = WaitResult.construct self._active_result_class = WaitResult.construct
...@@ -342,22 +242,31 @@ cdef cypclass NeighborsHeap activable: ...@@ -342,22 +242,31 @@ cdef cypclass NeighborsHeap activable:
self._n_pushes = 0 self._n_pushes = 0
self._sorted = False self._sorted = False
# We can't use memset here
for i in range(n_pts * n_nbrs):
self._distances[i] = INF
void __dealloc__(self): void __dealloc__(self):
free(self._distances) free(self._distances)
void push(self, I_t row, D_t val, I_t i_val): void push(self, I_t row, D_t val, I_t i_val):
"""push (val, i_val) into the given row""" """push (val, i_val) into the given row"""
cdef I_t i, left_child_idx, right_child_idx, swap_idx cdef:
I_t i, left_child_idx, right_child_idx, swap_idx
# Getting the heap to use
I_t *indices = self._indices + row * self._n_nbrs
D_t *distances = self._distances + row * self._n_nbrs
self._n_pushes += 1 self._n_pushes += 1
# check if val should be in heap # check if val should be in heap
if val > self._distances[0]: if val > distances[0]:
return return
# insert val at position zero # insert val at position zero
self._distances[0] = val distances[0] = val
self._indices[0] = i_val indices[0] = i_val
# descend the heap, swapping values until the max heap criterion is met # descend the heap, swapping values until the max heap criterion is met
i = 0 i = 0
...@@ -368,42 +277,176 @@ cdef cypclass NeighborsHeap activable: ...@@ -368,42 +277,176 @@ cdef cypclass NeighborsHeap activable:
if left_child_idx >= self._n_nbrs: if left_child_idx >= self._n_nbrs:
break break
elif right_child_idx >= self._n_nbrs: elif right_child_idx >= self._n_nbrs:
if self._distances[left_child_idx] > val: if distances[left_child_idx] > val:
swap_idx = left_child_idx swap_idx = left_child_idx
else: else:
break break
elif self._distances[left_child_idx] >= self._distances[right_child_idx]: elif distances[left_child_idx] >= distances[right_child_idx]:
if val < self._distances[left_child_idx]: if val < distances[left_child_idx]:
swap_idx = left_child_idx swap_idx = left_child_idx
else: else:
break break
else: else:
if val < self._distances[right_child_idx]: if val < distances[right_child_idx]:
swap_idx = right_child_idx swap_idx = right_child_idx
else: else:
break break
self._distances[i] = self._distances[swap_idx] distances[i] = distances[swap_idx]
self._indices[i] = self._indices[swap_idx] indices[i] = indices[swap_idx]
i = swap_idx i = swap_idx
self._distances[i] = val distances[i] = val
self._indices[i] = i_val indices[i] = i_val
void sort(self): void sort(self):
_sort(self._distances, self._indices, # NOTE: Ideally we could sort results in parallel, but
self._n_pts, self._n_nbrs) # OpenMP threadpool and this runtime's aren't working
self._sorted = False # nicely together (using prange here would create OpenMP
# threads within workers, causing unexpected behavior.)
cdef I_t row
for row in range(self._n_pts):
# We use a function here to be able to recurse
_simultaneous_sort(
self._distances + row * self._n_nbrs,
self._indices + row * self._n_nbrs,
self._n_nbrs)
self._sorted = True
int n_pushes(self): int n_pushes(self):
return self._n_pushes return self._n_pushes
int is_sorted(self): int is_sorted(self):
# TODO: is there a support for returning bool via
# promises? As of now returning a int for making
# use of getIntResult
return 1 if self._sorted else 0 return 1 if self._sorted else 0
cdef cypclass Node activable:
"""A KDTree Node
Node delegate tasks to their children Nodes.
Some Nodes are set as Leaves when they are associated
to ``leaf_size`` or less points.
Leafs are terminal Nodes and do not have children.
"""
# Reference to the head of the allocated arrays
# data gets not modified via _data_ptr
D_t * _data_ptr
I_t * _indices_ptr
# The point the Node split on
D_t * _point
I_t _n_dims
I_t _dim
# Portion of _indices covered by the Node is:
# _indices_ptr[_start:_end]
I_t _start
I_t _end
bint _is_leaf
active Node _left
active Node _right
__init__(self):
self._active_result_class = WaitResult.construct
self._active_queue_class = consume BatchMailBox(scheduler)
self._left = NULL
self._right = NULL
self._is_leaf = False
# We use this to allow using actors for initialisation
# because __init__ can't be reified.
void build_node(
self,
D_t * data_ptr,
I_t * indices_ptr,
I_t leaf_size,
I_t n_dims,
I_t dim,
I_t start,
I_t end,
active Counter counter,
):
cdef I_t i
cdef I_t next_dim = (dim + 1) % n_dims
cdef I_t mid = (start + end) // 2
self._data_ptr = data_ptr
self._indices_ptr = indices_ptr
self._dim = dim
self._n_dims = n_dims
self._start = start
self._end = end
if (end - start <= leaf_size):
self._is_leaf = True
# Adding to the global counter the number
# of samples the leaf is responsible of
counter.add(NULL, end - start)
return
# We partition the samples in two nodes on a given dimension,
# with the middle point as a pivot
partition_node_indices(data_ptr, indices_ptr, start, mid, end, dim, n_dims)
self._point = data_ptr + mid
self._left = consume Node()
self._right = consume Node()
# Recursing on both partition
self._left.build_node(NULL,
data_ptr, indices_ptr,
leaf_size, n_dims, next_dim,
start, mid, counter)
self._right.build_node(NULL,
data_ptr, indices_ptr,
leaf_size, n_dims, next_dim,
mid, end, counter)
void query(self,
D_t * query_points,
I_t i,
active NeighborsHeap heaps,
):
cdef:
I_t j, k
D_t dist
D_t tmp
if self._is_leaf:
# Computing all the euclideans distances here
for j in range(self._start, self._end):
dist = 0
for k in range(self._n_dims):
tmp = (
query_points[i * self._n_dims + k] -
self._data_ptr[self._indices_ptr[j] * self._n_dims + k]
)
dist += tmp * tmp
# The heap is doing the smart work of keeping
# the closest points for each query point i
heaps.push(NULL, i, dist, self._indices_ptr[j])
return
# TODO: one can implement a pruning strategy here
self._left.query(NULL, query_points, i, heaps)
self._right.query(NULL, query_points, i, heaps)
cdef cypclass KDTree: cdef cypclass KDTree:
"""A KDTree based on asynchronous and parallel computations. """A KDTree based on asynchronous and parallel computations.
...@@ -435,6 +478,7 @@ cdef cypclass KDTree: ...@@ -435,6 +478,7 @@ cdef cypclass KDTree:
cdef I_t i cdef I_t i
cdef I_t n = X.shape[0] cdef I_t n = X.shape[0]
cdef I_t d = X.shape[1] cdef I_t d = X.shape[1]
cdef I_t initialised = 0
self._n = n self._n = n
self._d = d self._d = d
...@@ -445,15 +489,11 @@ cdef cypclass KDTree: ...@@ -445,15 +489,11 @@ cdef cypclass KDTree:
for i in range(n): for i in range(n):
self._indices_ptr[i] = i self._indices_ptr[i] = i
# Recurvisely building the tree here
global scheduler global scheduler
scheduler = Scheduler() scheduler = Scheduler()
self._recursive_build()
void _recursive_build(self):
cdef I_t initialised
cdef active Counter counter = consume Counter() cdef active Counter counter = consume Counter()
self._root = consume Node() self._root = consume Node()
if self._root is NULL: if self._root is NULL:
printf("Error consuming node\n") printf("Error consuming node\n")
...@@ -462,6 +502,9 @@ cdef cypclass KDTree: ...@@ -462,6 +502,9 @@ cdef cypclass KDTree:
# are reified. When using those reified methods # are reified. When using those reified methods
# a new argument is prepredend for a predicate, # a new argument is prepredend for a predicate,
# which we aren't using using here, hence the extra NULL. # which we aren't using using here, hence the extra NULL.
#
# Also using this separate method allowing using actors
# because __init__ can't be reified.
self._root.build_node(NULL, self._root.build_node(NULL,
self._data_ptr, self._data_ptr,
self._indices_ptr, self._indices_ptr,
...@@ -469,8 +512,8 @@ cdef cypclass KDTree: ...@@ -469,8 +512,8 @@ cdef cypclass KDTree:
dim=0, start=0, end=self._n, dim=0, start=0, end=self._n,
counter=counter) counter=counter)
# Waiting for the tree construction to end
initialised = counter.value(NULL).getIntResult() # Somewhat similar to a thread barrier
while(initialised < self._n): while(initialised < self._n):
initialised = counter.value(NULL).getIntResult() initialised = counter.value(NULL).getIntResult()
...@@ -505,10 +548,10 @@ cdef cypclass KDTree: ...@@ -505,10 +548,10 @@ cdef cypclass KDTree:
while(completed_queries < total_n_pushes): while(completed_queries < total_n_pushes):
completed_queries = heaps.n_pushes(NULL).getIntResult() completed_queries = heaps.n_pushes(NULL).getIntResult()
# heaps.sort(NULL) heaps.sort(NULL)
# while not(heaps.is_sorted(NULL).getIntResult()): while not(heaps.is_sorted(NULL).getIntResult()):
# pass pass
cdef public int main() nogil: cdef public int main() nogil:
......
...@@ -16,6 +16,7 @@ if __name__ == '__main__': ...@@ -16,6 +16,7 @@ if __name__ == '__main__':
tree = kdtree.KDTree(X, leaf_size=256) tree = kdtree.KDTree(X, leaf_size=256)
closests = np.zeros((n_query, k), dtype=np.int32) closests = np.zeros((n_query, k), dtype=np.int32)
# tree.query(query_points, closests)
# There's currently a deadlock here # skl_tree = KDTree(X, leaf_size=256)
tree.query(query_points, closests) # skl_closests = skl_tree.query(query_points, k=k, return_distance=False).astype(np.int32)
\ No newline at end of file \ No newline at end of file
...@@ -5,8 +5,9 @@ from sklearn.neighbors import KDTree ...@@ -5,8 +5,9 @@ from sklearn.neighbors import KDTree
@pytest.mark.parametrize("n", [10, 100, 1000, 10000]) @pytest.mark.parametrize("n", [10, 100, 1000, 10000])
@pytest.mark.parametrize("d", [10, 100]) @pytest.mark.parametrize("d", [10, 100])
@pytest.mark.parametrize("k", [1, 2, 5, 10])
@pytest.mark.parametrize("leaf_size", [256, 1024]) @pytest.mark.parametrize("leaf_size", [256, 1024])
def test_against_sklearn(n, d, leaf_size): def test_against_sklearn(n, d, k, leaf_size):
np.random.seed(1) np.random.seed(1)
X = np.random.rand(n, d) X = np.random.rand(n, d)
query_points = np.random.rand(n, d) query_points = np.random.rand(n, d)
...@@ -14,10 +15,10 @@ def test_against_sklearn(n, d, leaf_size): ...@@ -14,10 +15,10 @@ def test_against_sklearn(n, d, leaf_size):
tree = kdtree.KDTree(X, leaf_size=256) tree = kdtree.KDTree(X, leaf_size=256)
skl_tree = KDTree(X, leaf_size=256) skl_tree = KDTree(X, leaf_size=256)
closests = np.zeros((n, 2), dtype=np.int32) closests = np.zeros((n, k), dtype=np.int32)
tree.query(query_points, closests) tree.query(query_points, closests)
skl_closests = skl_tree.query(query_points, return_distance=False) skl_closests = skl_tree.query(query_points, k=k, return_distance=False).astype(np.int32)
# The back tracking part of the algorithm is not yet implemented # The back tracking part of the algorithm is not yet implemented
# hence, we test for a almost equality # hence, we test for a almost equality
assert (np.ndarray.flatten(closests) == np.ndarray.flatten(skl_closests)).mean() > 0.9 np.testing.assert_equal(closests, skl_closests)
\ No newline at end of file \ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment