Julien Jerphanion / cython_plus_experiments

Commit 4689a19c, authored Jun 11, 2021 by Julien Jerphanion
Parent: 1dfa85af

    Use Heap to query multiple neighbours

Showing 3 changed files, with 150 additions and 98 deletions.
kdtree/kdtree.pyx            +127  -96
kdtree/query_poc.py           +21   -0
kdtree/tests/test_conf.py      +2   -2
kdtree/kdtree.pyx (+127 -96)

@@ -9,6 +9,8 @@ from runtime.runtime cimport BatchMailBox, NullResult, Scheduler, WaitResult
 from libc.stdio cimport printf
 from libc.stdlib cimport malloc, free
+from cython.parallel import prange
 
 ## Types declaration
 ctypedef int I_t
 ctypedef double D_t
@@ -132,7 +134,7 @@ cdef cypclass Node activable:
     void query(self,
         D_t* query_points,
         I_t i,
-        active NeighborsHeap heap,
+        active NeighborsHeap heaps,
     ):
         cdef:
             I_t j, k, closest = -1
@@ -150,19 +152,17 @@ cdef cypclass Node activable:
                 )
                 dist += tmp * tmp
-                heap.push(NULL, i, dist, j)
+                heaps.push(NULL, i, dist, j)
                 if dist < min_distance:
                     closest = j
                     min_distance = dist
-            container.update(NULL, i, closest)
             return
 
-        if query_points[i * self._n_dims + self._dim] < self._point[self._dim]:
-            self._left.query(NULL, query_points, i, container)
-        else:
-            self._right.query(NULL, query_points, i, container)
+        # if query_points[ i * self._n_dims + self._dim] < self._point[self._dim]:
+        self._left.query(NULL, query_points, i, heaps)
+        # else:
+        self._right.query(NULL, query_points, i, heaps)
 
 
 cdef cypclass Counter activable:
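For context on the changed call: during traversal, each candidate point's squared distance is now pushed into the per-query heap, and only the k best candidates survive. Below is a minimal Python sketch of that bounded-selection idea, using heapq rather than the cypclass NeighborsHeap; the helper name knn_push is illustrative, not part of the commit.

    import heapq

    def knn_push(heap, dist, idx, k):
        # Keep only the k smallest distances seen so far. heapq is a
        # min-heap, so distances are negated to make the current worst
        # neighbour sit at the root, ready to be replaced.
        if len(heap) < k:
            heapq.heappush(heap, (-dist, idx))
        elif dist < -heap[0][0]:
            heapq.heapreplace(heap, (-dist, idx))

    # Toy usage: squared distances of five 1-D points to a query at 0.25.
    points = [0.1, 0.9, 0.3, 0.2, 0.7]
    heap = []
    for j, p in enumerate(points):
        knn_push(heap, (p - 0.25) ** 2, j, k=3)
    # Prints the three closest points (indices 2, 3 and 0) with their
    # squared distances, in increasing order of distance.
    print(sorted((-d, j) for d, j in heap))

Negating the distances turns Python's min-heap into the max-heap needed here, so the current worst retained neighbour can be evicted in O(log k) per push.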
@@ -224,6 +224,88 @@ cdef inline void dual_swap(D_t* darr, I_t* iarr, I_t i1, I_t i2) nogil:
     iarr[i2] = itmp
 
+cdef void _simultaneous_sort(
+    D_t* dist,
+    I_t* idx,
+    I_t size
+) nogil:
+    """
+    Perform a recursive quicksort on the dist array, simultaneously
+    performing the same swaps on the idx array.  The equivalent in
+    numpy (though quite a bit slower) is
+
+    def simultaneous_sort(dist, idx):
+        i = np.argsort(dist)
+        return dist[i], idx[i]
+    """
+    cdef I_t pivot_idx, i, store_idx
+    cdef D_t pivot_val
+
+    # in the small-array case, do things efficiently
+    if size <= 1:
+        pass
+    elif size == 2:
+        if dist[0] > dist[1]:
+            dual_swap(dist, idx, 0, 1)
+    elif size == 3:
+        if dist[0] > dist[1]:
+            dual_swap(dist, idx, 0, 1)
+        if dist[1] > dist[2]:
+            dual_swap(dist, idx, 1, 2)
+            if dist[0] > dist[1]:
+                dual_swap(dist, idx, 0, 1)
+    else:
+        # Determine the pivot using the median-of-three rule.
+        # The smallest of the three is moved to the beginning of the array,
+        # the middle (the pivot value) is moved to the end, and the largest
+        # is moved to the pivot index.
+        pivot_idx = size // 2
+        if dist[0] > dist[size - 1]:
+            dual_swap(dist, idx, 0, size - 1)
+        if dist[size - 1] > dist[pivot_idx]:
+            dual_swap(dist, idx, size - 1, pivot_idx)
+            if dist[0] > dist[size - 1]:
+                dual_swap(dist, idx, 0, size - 1)
+        pivot_val = dist[size - 1]
+
+        # partition indices about pivot.  At the end of this operation,
+        # pivot_idx will contain the pivot value, everything to the left
+        # will be smaller, and everything to the right will be larger.
+        store_idx = 0
+        for i in range(size - 1):
+            if dist[i] < pivot_val:
+                dual_swap(dist, idx, i, store_idx)
+                store_idx += 1
+        dual_swap(dist, idx, store_idx, size - 1)
+        pivot_idx = store_idx
+
+        # recursively sort each side of the pivot
+        if pivot_idx > 1:
+            _simultaneous_sort(dist, idx, pivot_idx)
+        if pivot_idx + 2 < size:
+            _simultaneous_sort(dist + pivot_idx + 1,
+                               idx + pivot_idx + 1,
+                               size - pivot_idx - 1)
+
+
+cdef void _sort(
+    D_t* dist,
+    I_t* idx,
+    I_t n_rows,
+    I_t size,
+) nogil:
+    """simultaneously sort the distances and indices"""
+    cdef I_t row
+    for row in prange(n_rows, nogil=True, schedule="static", num_threads=4):
+        _simultaneous_sort(dist + row * size, idx + row * size, size)
 
 
 cdef cypclass NeighborsHeap activable:
     """A max-heap structure to keep track of distances/indices of neighbors
@@ -247,25 +329,28 @@ cdef cypclass NeighborsHeap activable:
     I_t _n_pts
     I_t _n_nbrs
+    I_t _n_pushes
     bint _sorted
 
-    __init__(self, I_t n_pts, I_t n_nbrs):
+    __init__(self, I_t* indices, I_t n_pts, I_t n_nbrs):
         self._n_pts = n_pts
         self._n_nbrs = n_nbrs
         self._active_result_class = WaitResult.construct
         self._active_queue_class = consume BatchMailBox(scheduler)
         self._distances = <D_t *> malloc(n_pts * n_nbrs * sizeof(D_t))
-        self._indices = <I_t *> malloc(n_pts * n_nbrs * sizeof(I_t))
+        self._indices = indices
+        self._n_pushes = 0
+        self._sorted = False
 
     void __dealloc__(self):
-        free(self._indices)
         free(self._distances)
 
     void push(self, I_t row, D_t val, I_t i_val):
         """push (val, i_val) into the given row"""
         cdef I_t i, left_child_idx, right_child_idx, swap_idx
+        self._n_pushes += 1
 
         # check if val should be in heap
         if val > self._distances[0]:
             return
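NeighborsHeap now writes neighbour indices straight into a caller-provided buffer (the closests array's data) and only allocates the distance array itself, which is why __dealloc__ no longer frees _indices. The push itself is the usual fixed-size max-heap sift-down: when the new distance beats the current worst (the root, column 0), the root is evicted and the value sinks to its final position. A plain NumPy sketch of that mechanism, assuming row-major (n_pts, n_nbrs) buffers; heap_push here is illustrative, not the cypclass method:

    import numpy as np

    def heap_push(distances, indices, row, val, i_val):
        # Push (val, i_val) into the fixed-size max-heap stored in
        # distances[row] / indices[row]; column 0 holds the current worst
        # neighbour, so larger values are rejected immediately.
        dist = distances[row]
        idx = indices[row]
        size = dist.shape[0]
        if val >= dist[0]:
            return
        # Sift the new value down from the root.
        i = 0
        while True:
            left = 2 * i + 1
            right = left + 1
            if left >= size:
                break
            swap = left
            if right < size and dist[right] > dist[left]:
                swap = right
            if dist[swap] <= val:
                break
            dist[i], idx[i] = dist[swap], idx[swap]
            i = swap
        dist[i], idx[i] = val, i_val

    distances = np.full((2, 3), np.inf)
    indices = np.zeros((2, 3), dtype=np.int32)
    for j, d in enumerate([0.9, 0.1, 0.5, 0.3]):
        heap_push(distances, indices, 0, d, j)
    print(np.sort(distances[0]))  # the three smallest values: 0.1, 0.3, 0.5

Because the root always holds the largest retained distance, the first comparison rejects most candidates without touching the rest of the row.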
@@ -307,78 +392,17 @@ cdef cypclass NeighborsHeap activable:
             self._indices[i] = i_val
 
-    void _sort(self):
+    void sort(self):
         """simultaneously sort the distances and indices"""
-        cdef I_t row
-        for row in range(self._n_pts):
-            self._simultaneous_sort(
-                self._distances + row * self._n_nbrs,
-                self._indices + row * self._n_nbrs,
-                self._n_nbrs
-            )
-
-    void _simultaneous_sort(
-        self,
-        D_t* dist,
-        I_t* idx,
-        I_t size
-    ):
-        """
-        Perform a recursive quicksort on the dist array, simultaneously
-        performing the same swaps on the idx array.  The equivalent in
-        numpy (though quite a bit slower) is
-
-        def simultaneous_sort(dist, idx):
-            i = np.argsort(dist)
-            return dist[i], idx[i]
-        """
-        cdef I_t pivot_idx, i, store_idx
-        cdef D_t pivot_val
-
-        # in the small-array case, do things efficiently
-        if size <= 1:
-            pass
-        elif size == 2:
-            if dist[0] > dist[1]:
-                dual_swap(dist, idx, 0, 1)
-        elif size == 3:
-            if dist[0] > dist[1]:
-                dual_swap(dist, idx, 0, 1)
-            if dist[1] > dist[2]:
-                dual_swap(dist, idx, 1, 2)
-                if dist[0] > dist[1]:
-                    dual_swap(dist, idx, 0, 1)
-        else:
-            # Determine the pivot using the median-of-three rule.
-            # The smallest of the three is moved to the beginning of the array,
-            # the middle (the pivot value) is moved to the end, and the largest
-            # is moved to the pivot index.
-            pivot_idx = size // 2
-            if dist[0] > dist[size - 1]:
-                dual_swap(dist, idx, 0, size - 1)
-            if dist[size - 1] > dist[pivot_idx]:
-                dual_swap(dist, idx, size - 1, pivot_idx)
-                if dist[0] > dist[size - 1]:
-                    dual_swap(dist, idx, 0, size - 1)
-            pivot_val = dist[size - 1]
-
-            # partition indices about pivot.  At the end of this operation,
-            # pivot_idx will contain the pivot value, everything to the left
-            # will be smaller, and everything to the right will be larger.
-            store_idx = 0
-            for i in range(size - 1):
-                if dist[i] < pivot_val:
-                    dual_swap(dist, idx, i, store_idx)
-                    store_idx += 1
-            dual_swap(dist, idx, store_idx, size - 1)
-            pivot_idx = store_idx
-
-            # recursively sort each side of the pivot
-            if pivot_idx > 1:
-                self._simultaneous_sort(dist, idx, pivot_idx)
-            if pivot_idx + 2 < size:
-                self._simultaneous_sort(dist + pivot_idx + 1,
-                                        idx + pivot_idx + 1,
-                                        size - pivot_idx - 1)
+        _sort(self._distances, self._indices,
+              self._n_pts, self._n_nbrs)
+        self._sorted = False
+
+    int n_pushes(self):
+        return self._n_pushes
+
+    int is_sorted(self):
+        return 1 if self._sorted else 0
 
 
 cdef cypclass KDTree:
@@ -466,18 +490,25 @@ cdef cypclass KDTree:
             I_t completed_queries = 0
             I_t i
             I_t n_query = query_points.shape[0]
-            I_t n_neighbors = query_points.shape[1]
-            active Container closests_container
+            I_t n_neighbors = closests.shape[1]
+            I_t total_n_pushes = n_query * self._n
+            active NeighborsHeap heaps
 
-        closests_container = consume Container(<I_t *> closests.data, n_query)
+        heaps = consume NeighborsHeap(<I_t *> closests.data, n_query, n_neighbors)
 
         for i in range(n_query):
-            self._root.query(NULL,
-                             <D_t *> query_points.data, i, closests_container)
+            self._root.query(NULL, <D_t *> query_points.data, i, heaps)
 
-        while(completed_queries < n_query):
-            completed_queries = closests_container.get_n_updates(NULL).getIntResult()
+        while(completed_queries < total_n_pushes):
+            completed_queries = heaps.n_pushes(NULL).getIntResult()
+
+        # heaps.sort(NULL)
+        # while not(heaps.is_sorted(NULL).getIntResult()):
+        #     pass
 
 
 cdef public int main() nogil:
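Completion is detected by polling: with the pruning condition commented out in Node.query, every query visits every indexed point once, so the caller knows the expected push count up front (total_n_pushes = n_query * self._n, where self._n is presumably the number of indexed points) and spins until the heap's counter reaches it. A rough Python analogue of that synchronisation pattern, using threads instead of the actor runtime (all names here are illustrative):

    import threading
    import time

    # Each "query" task performs a fixed number of pushes; the caller knows
    # the expected total up front and simply polls until every push has landed.
    n_query, n_points = 4, 100
    expected = n_query * n_points
    counter = {"pushes": 0}
    lock = threading.Lock()

    def run_query():
        for _ in range(n_points):
            with lock:
                counter["pushes"] += 1   # stands in for heaps.push(...)

    threads = [threading.Thread(target=run_query) for _ in range(n_query)]
    for t in threads:
        t.start()
    while counter["pushes"] < expected:  # busy-wait, as in the commit
        time.sleep(0.001)
    for t in threads:
        t.join()
    print(counter["pushes"])  # 400

The commented-out is_sorted loop follows the same pattern: once sorting is wired up, the caller would poll a flag instead of a counter.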
kdtree/query_poc.py (new file, mode 100644, +21 -0)

+import numpy as np
+import pytest
+
+import kdtree
+
+from sklearn.neighbors import KDTree
+
+if __name__ == '__main__':
+    n = 1000
+    n_query = 100
+    d = 10
+    k = 10
+
+    np.random.seed(1)
+    X = np.random.rand(n, d)
+    query_points = np.random.rand(n_query, d)
+
+    tree = kdtree.KDTree(X, leaf_size=256)
+    closests = np.zeros((n_query, k), dtype=np.int32)
+
+    # There's currently a deadlock here
+    tree.query(query_points, closests)
\ No newline at end of file
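For reference, the scikit-learn tree imported above provides the behaviour this proof of concept aims to reproduce; a short comparison sketch with the same data (an illustration, not part of the commit):

    import numpy as np
    from sklearn.neighbors import KDTree

    np.random.seed(1)
    X = np.random.rand(1000, 10)
    query_points = np.random.rand(100, 10)

    skl_tree = KDTree(X, leaf_size=256)
    # Indices of the 10 nearest neighbours of each query point, shape (100, 10);
    # this is the reference result the POC's `closests` buffer should match.
    skl_closests = skl_tree.query(query_points, k=10, return_distance=False)
    print(skl_closests.shape)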
kdtree/tests/test_conf.py (+2 -2)

@@ -14,8 +14,8 @@ def test_against_sklearn(n, d, leaf_size):
     tree = kdtree.KDTree(X, leaf_size=256)
     skl_tree = KDTree(X, leaf_size=256)
 
-    closests = np.zeros((n), dtype=np.int32)
+    closests = np.zeros((n, 2), dtype=np.int32)
 
-    tree.get_closest(query_points, closests)
+    tree.query(query_points, closests)
     skl_closests = skl_tree.query(query_points, return_distance=False)
 
     # The back tracking part of the algorithm is not yet implemented