Commit a0019cd7 authored by Jiri Olsa's avatar Jiri Olsa Committed by Alexei Starovoitov

lib/sort: Add priv pointer to swap function

Adding support to have priv pointer in swap callback function.

Following the initial change on cmp callback functions [1]
and adding SWAP_WRAPPER macro to identify sort call of sort_r.
Signed-off-by: default avatarJiri Olsa <jolsa@kernel.org>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Reviewed-by: default avatarMasami Hiramatsu <mhiramat@kernel.org>
Link: https://lore.kernel.org/bpf/20220316122419.933957-2-jolsa@kernel.org

[1] 4333fb96 ("media: lib/sort.c: implement sort() variant taking context argument")
parent 245d9496
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
void sort_r(void *base, size_t num, size_t size, void sort_r(void *base, size_t num, size_t size,
cmp_r_func_t cmp_func, cmp_r_func_t cmp_func,
swap_func_t swap_func, swap_r_func_t swap_func,
const void *priv); const void *priv);
void sort(void *base, size_t num, size_t size, void sort(void *base, size_t num, size_t size,
......
...@@ -226,6 +226,7 @@ struct callback_head { ...@@ -226,6 +226,7 @@ struct callback_head {
typedef void (*rcu_callback_t)(struct rcu_head *head); typedef void (*rcu_callback_t)(struct rcu_head *head);
typedef void (*call_rcu_func_t)(struct rcu_head *head, rcu_callback_t func); typedef void (*call_rcu_func_t)(struct rcu_head *head, rcu_callback_t func);
typedef void (*swap_r_func_t)(void *a, void *b, int size, const void *priv);
typedef void (*swap_func_t)(void *a, void *b, int size); typedef void (*swap_func_t)(void *a, void *b, int size);
typedef int (*cmp_r_func_t)(const void *a, const void *b, const void *priv); typedef int (*cmp_r_func_t)(const void *a, const void *b, const void *priv);
......
...@@ -122,16 +122,27 @@ static void swap_bytes(void *a, void *b, size_t n) ...@@ -122,16 +122,27 @@ static void swap_bytes(void *a, void *b, size_t n)
* a pointer, but small integers make for the smallest compare * a pointer, but small integers make for the smallest compare
* instructions. * instructions.
*/ */
#define SWAP_WORDS_64 (swap_func_t)0 #define SWAP_WORDS_64 (swap_r_func_t)0
#define SWAP_WORDS_32 (swap_func_t)1 #define SWAP_WORDS_32 (swap_r_func_t)1
#define SWAP_BYTES (swap_func_t)2 #define SWAP_BYTES (swap_r_func_t)2
#define SWAP_WRAPPER (swap_r_func_t)3
struct wrapper {
cmp_func_t cmp;
swap_func_t swap;
};
/* /*
* The function pointer is last to make tail calls most efficient if the * The function pointer is last to make tail calls most efficient if the
* compiler decides not to inline this function. * compiler decides not to inline this function.
*/ */
static void do_swap(void *a, void *b, size_t size, swap_func_t swap_func) static void do_swap(void *a, void *b, size_t size, swap_r_func_t swap_func, const void *priv)
{ {
if (swap_func == SWAP_WRAPPER) {
((const struct wrapper *)priv)->swap(a, b, (int)size);
return;
}
if (swap_func == SWAP_WORDS_64) if (swap_func == SWAP_WORDS_64)
swap_words_64(a, b, size); swap_words_64(a, b, size);
else if (swap_func == SWAP_WORDS_32) else if (swap_func == SWAP_WORDS_32)
...@@ -139,7 +150,7 @@ static void do_swap(void *a, void *b, size_t size, swap_func_t swap_func) ...@@ -139,7 +150,7 @@ static void do_swap(void *a, void *b, size_t size, swap_func_t swap_func)
else if (swap_func == SWAP_BYTES) else if (swap_func == SWAP_BYTES)
swap_bytes(a, b, size); swap_bytes(a, b, size);
else else
swap_func(a, b, (int)size); swap_func(a, b, (int)size, priv);
} }
#define _CMP_WRAPPER ((cmp_r_func_t)0L) #define _CMP_WRAPPER ((cmp_r_func_t)0L)
...@@ -147,7 +158,7 @@ static void do_swap(void *a, void *b, size_t size, swap_func_t swap_func) ...@@ -147,7 +158,7 @@ static void do_swap(void *a, void *b, size_t size, swap_func_t swap_func)
static int do_cmp(const void *a, const void *b, cmp_r_func_t cmp, const void *priv) static int do_cmp(const void *a, const void *b, cmp_r_func_t cmp, const void *priv)
{ {
if (cmp == _CMP_WRAPPER) if (cmp == _CMP_WRAPPER)
return ((cmp_func_t)(priv))(a, b); return ((const struct wrapper *)priv)->cmp(a, b);
return cmp(a, b, priv); return cmp(a, b, priv);
} }
...@@ -198,7 +209,7 @@ static size_t parent(size_t i, unsigned int lsbit, size_t size) ...@@ -198,7 +209,7 @@ static size_t parent(size_t i, unsigned int lsbit, size_t size)
*/ */
void sort_r(void *base, size_t num, size_t size, void sort_r(void *base, size_t num, size_t size,
cmp_r_func_t cmp_func, cmp_r_func_t cmp_func,
swap_func_t swap_func, swap_r_func_t swap_func,
const void *priv) const void *priv)
{ {
/* pre-scale counters for performance */ /* pre-scale counters for performance */
...@@ -208,6 +219,10 @@ void sort_r(void *base, size_t num, size_t size, ...@@ -208,6 +219,10 @@ void sort_r(void *base, size_t num, size_t size,
if (!a) /* num < 2 || size == 0 */ if (!a) /* num < 2 || size == 0 */
return; return;
/* called from 'sort' without swap function, let's pick the default */
if (swap_func == SWAP_WRAPPER && !((struct wrapper *)priv)->swap)
swap_func = NULL;
if (!swap_func) { if (!swap_func) {
if (is_aligned(base, size, 8)) if (is_aligned(base, size, 8))
swap_func = SWAP_WORDS_64; swap_func = SWAP_WORDS_64;
...@@ -230,7 +245,7 @@ void sort_r(void *base, size_t num, size_t size, ...@@ -230,7 +245,7 @@ void sort_r(void *base, size_t num, size_t size,
if (a) /* Building heap: sift down --a */ if (a) /* Building heap: sift down --a */
a -= size; a -= size;
else if (n -= size) /* Sorting: Extract root to --n */ else if (n -= size) /* Sorting: Extract root to --n */
do_swap(base, base + n, size, swap_func); do_swap(base, base + n, size, swap_func, priv);
else /* Sort complete */ else /* Sort complete */
break; break;
...@@ -257,7 +272,7 @@ void sort_r(void *base, size_t num, size_t size, ...@@ -257,7 +272,7 @@ void sort_r(void *base, size_t num, size_t size,
c = b; /* Where "a" belongs */ c = b; /* Where "a" belongs */
while (b != a) { /* Shift it into place */ while (b != a) { /* Shift it into place */
b = parent(b, lsbit, size); b = parent(b, lsbit, size);
do_swap(base + b, base + c, size, swap_func); do_swap(base + b, base + c, size, swap_func, priv);
} }
} }
} }
...@@ -267,6 +282,11 @@ void sort(void *base, size_t num, size_t size, ...@@ -267,6 +282,11 @@ void sort(void *base, size_t num, size_t size,
cmp_func_t cmp_func, cmp_func_t cmp_func,
swap_func_t swap_func) swap_func_t swap_func)
{ {
return sort_r(base, num, size, _CMP_WRAPPER, swap_func, cmp_func); struct wrapper w = {
.cmp = cmp_func,
.swap = swap_func,
};
return sort_r(base, num, size, _CMP_WRAPPER, SWAP_WRAPPER, &w);
} }
EXPORT_SYMBOL(sort); EXPORT_SYMBOL(sort);
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment