Commit ee1d67a4 authored by Julien Jerphanion

Remove extra interface for partial sort

parent c7339e68
......@@ -4,7 +4,6 @@ cimport numpy as np
np.import_array()
from cython.view cimport array as cvarray
from runtime.runtime cimport BatchMailBox, NullResult, Scheduler
from libc.stdio cimport printf
......@@ -36,7 +35,7 @@ cdef extern from *:
}
};
template<class D, class I>
void partition_node_indices_inner(
void partition_node_indices(
const D *data,
I *node_indices,
const I &split_dim,
......@@ -51,7 +50,7 @@ cdef extern from *:
index_comparator);
}
"""
void partition_node_indices_inner[D, I](
void partition_node_indices[D, I](
D *data,
I *node_indices,
I split_dim,
......@@ -59,58 +58,6 @@ cdef extern from *:
I n_features,
I n_points) nogil except +
cdef I_t partition_node_indices(
D_t *data,
I_t *node_indices,
I_t split_dim,
I_t split_index,
I_t n_features,
I_t n_points) nogil except -1:
"""Partition the points of a node around split_index along split_dim.
This is a thin wrapper: it forwards all six arguments unchanged to
``partition_node_indices_inner`` and then returns 0.
Upon return, the values in node_indices will be rearranged such that
(assuming numpy-style indexing):
data[node_indices[0:split_index], split_dim]
<= data[node_indices[split_index], split_dim]
and
data[node_indices[split_index], split_dim]
<= data[node_indices[split_index:n_points], split_dim]
The algorithm is described as essentially a partial in-place quicksort
around a set pivot.
NOTE(review): the actual algorithm lives in ``partition_node_indices_inner``,
whose body is not shown here -- confirm this description against it.
Parameters
----------
data : D_t pointer
Pointer to a 2D array of the training data, of shape [N, n_features].
N must be greater than any of the values in node_indices.
node_indices : I_t pointer
Pointer to a 1D array of length n_points. This lists the indices of
each of the points within the current node. This will be modified
in-place.
split_dim : I_t
the dimension on which to split. This will usually be computed via
the routine ``find_node_split_dim``.
split_index : I_t
the index within node_indices around which to split the points.
Passing n_points // 2 splits the node into two halves of (nearly)
equal size.
n_features : I_t
the number of features (i.e. columns) in the 2D array pointed by data.
n_points : I_t
the length of node_indices, i.e. the number of points in the current
node. It equals the size of the whole dataset only when the node
covers every point.
Returns
-------
status : I_t
integer exit status; this wrapper always returns 0 once the inner
call has completed. The value -1 is reserved by the ``except -1``
clause of the signature to signal an error to the caller. On return,
the contents of node_indices are modified as noted above.
"""
partition_node_indices_inner(
data,
node_indices,
split_dim,
split_index,
n_features,
n_points)
return 0
cdef cypclass Node activable:
"""A KDTree Node"""
......@@ -138,6 +85,7 @@ cdef cypclass Node activable:
cdef I_t i
cdef I_t next_dim = (dim + 1) % n_dims
cdef I_t nn = end - start
cdef I_t n_mid = nn // 2
cdef I_t split_index = (start + end) // 2
self.n_dims = n_dims
......@@ -149,7 +97,7 @@ cdef cypclass Node activable:
dim, start, end, split_index)
partition_node_indices(points + start,
indices + start, dim, nn // 2, n_dims, nn)
indices + start, dim, n_mid, n_dims, nn)
self.point = points + split_index
......@@ -190,11 +138,6 @@ cdef cypclass KDTree:
np.ndarray data_arr
np.ndarray idx_array_arr
# TODO: use memoryview from the user-provided numpy array
# and pointers for backend implementation.
# D_t[:, ::1] data
# I_t[::1] idx_array
active Node root
D_t *points
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment