Commit e8d90ccd authored by Julien Jerphanion's avatar Julien Jerphanion

Refactor

Simplify logic.
Add comments to explain motives.
parent 280c0904
......@@ -18,6 +18,8 @@ ctypedef double D_t
cdef lock Scheduler scheduler
cdef D_t INF = 1e38
# NOTE: The following extern definition is used to interface
# std::nth_element, a robust partitioning algorithm, in Cython
cdef extern from *:
"""
#include <algorithm>
......@@ -66,109 +68,13 @@ cdef extern from *:
I n_features
) nogil except +
cdef cypclass Node activable:
"""A KDTree Node"""
D_t * _data_ptr
I_t * _indices_ptr
D_t * _point
I_t _n_dims
I_t _dim
I_t _start
I_t _end
bint _is_leaf
active Node _left
active Node _right
__init__(self):
self._active_result_class = WaitResult.construct
self._active_queue_class = consume BatchMailBox(scheduler)
self._left = NULL
self._right = NULL
self._is_leaf = False
void build_node(
self,
D_t * data_ptr,
I_t * indices_ptr,
I_t leaf_size,
I_t n_dims,
I_t dim,
I_t start,
I_t end,
active Counter counter,
):
cdef I_t i
cdef I_t next_dim = (dim + 1) % n_dims
cdef I_t mid = (start + end) // 2
self._data_ptr = data_ptr
self._indices_ptr = indices_ptr
self._dim = dim
self._n_dims = n_dims
self._start = start
self._end = end
if (end - start <= leaf_size):
self._is_leaf = True
counter.add(NULL, end - start)
return
partition_node_indices(data_ptr, indices_ptr, start, mid, end, dim, n_dims)
self._point = data_ptr + mid
self._left = consume Node()
self._right = consume Node()
self._left.build_node(NULL,
data_ptr, indices_ptr,
leaf_size, n_dims, next_dim,
start, mid, counter)
self._right.build_node(NULL,
data_ptr, indices_ptr,
leaf_size, n_dims, next_dim,
mid, end, counter)
void query(self,
D_t * query_points,
I_t i,
active NeighborsHeap heaps,
):
cdef:
I_t j, k, closest = -1
D_t dist = INF
D_t tmp
D_t min_distance = INF
if self._is_leaf:
for j in range(self._start, self._end):
dist = 0
for k in range(self._n_dims):
tmp = (
query_points[i * self._n_dims + k] -
self._data_ptr[self._indices_ptr[j] * self._n_dims + k]
)
dist += tmp * tmp
heaps.push(NULL, i, dist, j)
if dist < min_distance:
closest = j
min_distance = dist
return
# if query_points[ i * self._n_dims + self._dim] < self._point[self._dim]:
self._left.query(NULL, query_points, i, heaps)
# else:
self._right.query(NULL, query_points, i, heaps)
cdef cypclass Counter activable:
""" A simple Counter.
Useful for synchronisation, as it can be used as a barrier.
This can be useful for synchronisation for the caller after
triggering the actors logic as it wait for the value of
the Coutner to reach a given one before moving on.
"""
I_t _n
......@@ -191,7 +97,11 @@ cdef cypclass Counter activable:
cdef cypclass Container activable:
""" A simple wrapper of an array.
Useful for synchronisation, as it can be used as a barrier.
The wrapped array is passed by the initial caller which then
trigger the actors logic modifying it.
The initial caller can wait for specific number of update
before proceeding.
"""
I_t * _array
......@@ -210,9 +120,10 @@ cdef cypclass Container activable:
self._n_updates += 1
self._array[idx] = value
int get_n_updates(self):
int n_updates(self):
return self._n_updates
cdef inline void dual_swap(D_t* darr, I_t* iarr, I_t i1, I_t i2) nogil:
"""swap the values at inex i1 and i2 of both darr and iarr"""
cdef D_t dtmp = darr[i1]
......@@ -288,24 +199,6 @@ cdef void _simultaneous_sort(
size - pivot_idx - 1)
cdef void _sort(
D_t* dist,
I_t* idx,
I_t n_rows,
I_t size,
) nogil:
"""simultaneously sort the distances and indices"""
cdef I_t row
for row in prange(n_rows,
nogil=True,
schedule="static",
num_threads=4):
_simultaneous_sort(
dist + row * size,
idx + row * size,
size)
cdef cypclass NeighborsHeap activable:
"""A max-heap structure to keep track of distances/indices of neighbors
......@@ -317,6 +210,12 @@ cdef cypclass NeighborsHeap activable:
Taken and adapted from:
https://github.com/scikit-learn/scikit-learn/blob/e4bb9fa86b0df873ad750b6d59090843d9d23d50/sklearn/neighbors/_binary_tree.pxi#L513
The initial caller is responsible for providing the array of indices
which will be modified in place by the actors logic.
n_pushes and is_sorted can be used by the initial caller to know
when to pursue.
Parameters
----------
n_pts : int
......@@ -333,6 +232,7 @@ cdef cypclass NeighborsHeap activable:
bint _sorted
__init__(self, I_t * indices, I_t n_pts, I_t n_nbrs):
cdef I_t i
self._n_pts = n_pts
self._n_nbrs = n_nbrs
self._active_result_class = WaitResult.construct
......@@ -342,22 +242,31 @@ cdef cypclass NeighborsHeap activable:
self._n_pushes = 0
self._sorted = False
# We can't use memset here
for i in range(n_pts * n_nbrs):
self._distances[i] = INF
void __dealloc__(self):
free(self._distances)
void push(self, I_t row, D_t val, I_t i_val):
"""push (val, i_val) into the given row"""
cdef I_t i, left_child_idx, right_child_idx, swap_idx
cdef:
I_t i, left_child_idx, right_child_idx, swap_idx
# Getting the heap to use
I_t *indices = self._indices + row * self._n_nbrs
D_t *distances = self._distances + row * self._n_nbrs
self._n_pushes += 1
# check if val should be in heap
if val > self._distances[0]:
if val > distances[0]:
return
# insert val at position zero
self._distances[0] = val
self._indices[0] = i_val
distances[0] = val
indices[0] = i_val
# descend the heap, swapping values until the max heap criterion is met
i = 0
......@@ -368,42 +277,176 @@ cdef cypclass NeighborsHeap activable:
if left_child_idx >= self._n_nbrs:
break
elif right_child_idx >= self._n_nbrs:
if self._distances[left_child_idx] > val:
if distances[left_child_idx] > val:
swap_idx = left_child_idx
else:
break
elif self._distances[left_child_idx] >= self._distances[right_child_idx]:
if val < self._distances[left_child_idx]:
elif distances[left_child_idx] >= distances[right_child_idx]:
if val < distances[left_child_idx]:
swap_idx = left_child_idx
else:
break
else:
if val < self._distances[right_child_idx]:
if val < distances[right_child_idx]:
swap_idx = right_child_idx
else:
break
self._distances[i] = self._distances[swap_idx]
self._indices[i] = self._indices[swap_idx]
distances[i] = distances[swap_idx]
indices[i] = indices[swap_idx]
i = swap_idx
self._distances[i] = val
self._indices[i] = i_val
distances[i] = val
indices[i] = i_val
void sort(self):
_sort(self._distances, self._indices,
self._n_pts, self._n_nbrs)
self._sorted = False
# NOTE: Ideally we could sort results in parallel, but
# OpenMP threadpool and this runtime's aren't working
# nicely together (using prange here would create OpenMP
# threads within workers, causing unexpected behavior.)
cdef I_t row
for row in range(self._n_pts):
# We use a function here to be able to recurse
_simultaneous_sort(
self._distances + row * self._n_nbrs,
self._indices + row * self._n_nbrs,
self._n_nbrs)
self._sorted = True
int n_pushes(self):
return self._n_pushes
int is_sorted(self):
# TODO: is there a support for returning bool via
# promises? As of now returning a int for making
# use of getIntResult
return 1 if self._sorted else 0
cdef cypclass Node activable:
"""A KDTree Node
Node delegate tasks to their children Nodes.
Some Nodes are set as Leaves when they are associated
to ``leaf_size`` or less points.
Leafs are terminal Nodes and do not have children.
"""
# Reference to the head of the allocated arrays
# data gets not modified via _data_ptr
D_t * _data_ptr
I_t * _indices_ptr
# The point the Node split on
D_t * _point
I_t _n_dims
I_t _dim
# Portion of _indices covered by the Node is:
# _indices_ptr[_start:_end]
I_t _start
I_t _end
bint _is_leaf
active Node _left
active Node _right
__init__(self):
self._active_result_class = WaitResult.construct
self._active_queue_class = consume BatchMailBox(scheduler)
self._left = NULL
self._right = NULL
self._is_leaf = False
# We use this to allow using actors for initialisation
# because __init__ can't be reified.
void build_node(
self,
D_t * data_ptr,
I_t * indices_ptr,
I_t leaf_size,
I_t n_dims,
I_t dim,
I_t start,
I_t end,
active Counter counter,
):
cdef I_t i
cdef I_t next_dim = (dim + 1) % n_dims
cdef I_t mid = (start + end) // 2
self._data_ptr = data_ptr
self._indices_ptr = indices_ptr
self._dim = dim
self._n_dims = n_dims
self._start = start
self._end = end
if (end - start <= leaf_size):
self._is_leaf = True
# Adding to the global counter the number
# of samples the leaf is responsible of
counter.add(NULL, end - start)
return
# We partition the samples in two nodes on a given dimension,
# with the middle point as a pivot
partition_node_indices(data_ptr, indices_ptr, start, mid, end, dim, n_dims)
self._point = data_ptr + mid
self._left = consume Node()
self._right = consume Node()
# Recursing on both partition
self._left.build_node(NULL,
data_ptr, indices_ptr,
leaf_size, n_dims, next_dim,
start, mid, counter)
self._right.build_node(NULL,
data_ptr, indices_ptr,
leaf_size, n_dims, next_dim,
mid, end, counter)
void query(self,
D_t * query_points,
I_t i,
active NeighborsHeap heaps,
):
cdef:
I_t j, k
D_t dist
D_t tmp
if self._is_leaf:
# Computing all the euclideans distances here
for j in range(self._start, self._end):
dist = 0
for k in range(self._n_dims):
tmp = (
query_points[i * self._n_dims + k] -
self._data_ptr[self._indices_ptr[j] * self._n_dims + k]
)
dist += tmp * tmp
# The heap is doing the smart work of keeping
# the closest points for each query point i
heaps.push(NULL, i, dist, self._indices_ptr[j])
return
# TODO: one can implement a pruning strategy here
self._left.query(NULL, query_points, i, heaps)
self._right.query(NULL, query_points, i, heaps)
cdef cypclass KDTree:
"""A KDTree based on asynchronous and parallel computations.
......@@ -435,6 +478,7 @@ cdef cypclass KDTree:
cdef I_t i
cdef I_t n = X.shape[0]
cdef I_t d = X.shape[1]
cdef I_t initialised = 0
self._n = n
self._d = d
......@@ -445,15 +489,11 @@ cdef cypclass KDTree:
for i in range(n):
self._indices_ptr[i] = i
# Recurvisely building the tree here
global scheduler
scheduler = Scheduler()
self._recursive_build()
void _recursive_build(self):
cdef I_t initialised
cdef active Counter counter = consume Counter()
self._root = consume Node()
if self._root is NULL:
printf("Error consuming node\n")
......@@ -462,6 +502,9 @@ cdef cypclass KDTree:
# are reified. When using those reified methods
# a new argument is prepredend for a predicate,
# which we aren't using using here, hence the extra NULL.
#
# Also using this separate method allowing using actors
# because __init__ can't be reified.
self._root.build_node(NULL,
self._data_ptr,
self._indices_ptr,
......@@ -469,8 +512,8 @@ cdef cypclass KDTree:
dim=0, start=0, end=self._n,
counter=counter)
initialised = counter.value(NULL).getIntResult()
# Waiting for the tree construction to end
# Somewhat similar to a thread barrier
while(initialised < self._n):
initialised = counter.value(NULL).getIntResult()
......@@ -505,10 +548,10 @@ cdef cypclass KDTree:
while(completed_queries < total_n_pushes):
completed_queries = heaps.n_pushes(NULL).getIntResult()
# heaps.sort(NULL)
heaps.sort(NULL)
# while not(heaps.is_sorted(NULL).getIntResult()):
# pass
while not(heaps.is_sorted(NULL).getIntResult()):
pass
cdef public int main() nogil:
......
......@@ -16,6 +16,7 @@ if __name__ == '__main__':
tree = kdtree.KDTree(X, leaf_size=256)
closests = np.zeros((n_query, k), dtype=np.int32)
# tree.query(query_points, closests)
# There's currently a deadlock here
tree.query(query_points, closests)
\ No newline at end of file
# skl_tree = KDTree(X, leaf_size=256)
# skl_closests = skl_tree.query(query_points, k=k, return_distance=False).astype(np.int32)
\ No newline at end of file
......@@ -5,8 +5,9 @@ from sklearn.neighbors import KDTree
@pytest.mark.parametrize("n", [10, 100, 1000, 10000])
@pytest.mark.parametrize("d", [10, 100])
@pytest.mark.parametrize("k", [1, 2, 5, 10])
@pytest.mark.parametrize("leaf_size", [256, 1024])
def test_against_sklearn(n, d, leaf_size):
def test_against_sklearn(n, d, k, leaf_size):
np.random.seed(1)
X = np.random.rand(n, d)
query_points = np.random.rand(n, d)
......@@ -14,10 +15,10 @@ def test_against_sklearn(n, d, leaf_size):
tree = kdtree.KDTree(X, leaf_size=256)
skl_tree = KDTree(X, leaf_size=256)
closests = np.zeros((n, 2), dtype=np.int32)
closests = np.zeros((n, k), dtype=np.int32)
tree.query(query_points, closests)
skl_closests = skl_tree.query(query_points, return_distance=False)
skl_closests = skl_tree.query(query_points, k=k, return_distance=False).astype(np.int32)
# The back tracking part of the algorithm is not yet implemented
# hence, we test for a almost equality
assert (np.ndarray.flatten(closests) == np.ndarray.flatten(skl_closests)).mean() > 0.9
\ No newline at end of file
np.testing.assert_equal(closests, skl_closests)
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment