Commit 44d520eb authored by Daniel Borkmann's avatar Daniel Borkmann

Merge branch 'bpf-sk-msg-peek'

John Fastabend says:

====================
This adds support for the MSG_PEEK flag when redirecting into
an ingress psock sk_msg queue.

The first patch adds some base support to the helpers, then the
feature, and finally we add an option for the test suite to do
a duplicate MSG_PEEK call on every recv to test the feature.

With duplicate MSG_PEEK call all tests continue to PASS.
====================
Acked-by: default avatarAlexei Starovoitov <ast@kernel.org>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
parents 3f4c3127 753fb2ee
...@@ -187,18 +187,21 @@ static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src) ...@@ -187,18 +187,21 @@ static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
sk_msg_init(src); sk_msg_init(src);
} }
static inline bool sk_msg_full(const struct sk_msg *msg)
{
return (msg->sg.end == msg->sg.start) && msg->sg.size;
}
static inline u32 sk_msg_elem_used(const struct sk_msg *msg) static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
{ {
if (sk_msg_full(msg))
return MAX_MSG_FRAGS;
return msg->sg.end >= msg->sg.start ? return msg->sg.end >= msg->sg.start ?
msg->sg.end - msg->sg.start : msg->sg.end - msg->sg.start :
msg->sg.end + (MAX_MSG_FRAGS - msg->sg.start); msg->sg.end + (MAX_MSG_FRAGS - msg->sg.start);
} }
static inline bool sk_msg_full(const struct sk_msg *msg)
{
return (msg->sg.end == msg->sg.start) && msg->sg.size;
}
static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which) static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
{ {
return &msg->sg.data[which]; return &msg->sg.data[which];
......
...@@ -2089,7 +2089,7 @@ int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes, ...@@ -2089,7 +2089,7 @@ int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len); int nonblock, int flags, int *addr_len);
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock, int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
struct msghdr *msg, int len); struct msghdr *msg, int len, int flags);
/* Call BPF_SOCK_OPS program that returns an int. If the return value /* Call BPF_SOCK_OPS program that returns an int. If the return value
* is < 0, then the BPF op failed (for example if the loaded BPF * is < 0, then the BPF op failed (for example if the loaded BPF
......
...@@ -39,17 +39,19 @@ static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock, ...@@ -39,17 +39,19 @@ static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
} }
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock, int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
struct msghdr *msg, int len) struct msghdr *msg, int len, int flags)
{ {
struct iov_iter *iter = &msg->msg_iter; struct iov_iter *iter = &msg->msg_iter;
int peek = flags & MSG_PEEK;
int i, ret, copied = 0; int i, ret, copied = 0;
struct sk_msg *msg_rx;
msg_rx = list_first_entry_or_null(&psock->ingress_msg,
struct sk_msg, list);
while (copied != len) { while (copied != len) {
struct scatterlist *sge; struct scatterlist *sge;
struct sk_msg *msg_rx;
msg_rx = list_first_entry_or_null(&psock->ingress_msg,
struct sk_msg, list);
if (unlikely(!msg_rx)) if (unlikely(!msg_rx))
break; break;
...@@ -70,22 +72,30 @@ int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock, ...@@ -70,22 +72,30 @@ int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
} }
copied += copy; copied += copy;
sge->offset += copy; if (likely(!peek)) {
sge->length -= copy; sge->offset += copy;
sk_mem_uncharge(sk, copy); sge->length -= copy;
msg_rx->sg.size -= copy; sk_mem_uncharge(sk, copy);
if (!sge->length) { msg_rx->sg.size -= copy;
i++;
if (i == MAX_SKB_FRAGS) if (!sge->length) {
i = 0; sk_msg_iter_var_next(i);
if (!msg_rx->skb) if (!msg_rx->skb)
put_page(page); put_page(page);
}
} else {
sk_msg_iter_var_next(i);
} }
if (copied == len) if (copied == len)
break; break;
} while (i != msg_rx->sg.end); } while (i != msg_rx->sg.end);
if (unlikely(peek)) {
msg_rx = list_next_entry(msg_rx, list);
continue;
}
msg_rx->sg.start = i; msg_rx->sg.start = i;
if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) { if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
list_del(&msg_rx->list); list_del(&msg_rx->list);
...@@ -93,6 +103,8 @@ int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock, ...@@ -93,6 +103,8 @@ int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
consume_skb(msg_rx->skb); consume_skb(msg_rx->skb);
kfree(msg_rx); kfree(msg_rx);
} }
msg_rx = list_first_entry_or_null(&psock->ingress_msg,
struct sk_msg, list);
} }
return copied; return copied;
...@@ -115,7 +127,7 @@ int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -115,7 +127,7 @@ int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
lock_sock(sk); lock_sock(sk);
msg_bytes_ready: msg_bytes_ready:
copied = __tcp_bpf_recvmsg(sk, psock, msg, len); copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
if (!copied) { if (!copied) {
int data, err = 0; int data, err = 0;
long timeo; long timeo;
......
...@@ -1478,7 +1478,8 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1478,7 +1478,8 @@ int tls_sw_recvmsg(struct sock *sk,
skb = tls_wait_data(sk, psock, flags, timeo, &err); skb = tls_wait_data(sk, psock, flags, timeo, &err);
if (!skb) { if (!skb) {
if (psock) { if (psock) {
int ret = __tcp_bpf_recvmsg(sk, psock, msg, len); int ret = __tcp_bpf_recvmsg(sk, psock,
msg, len, flags);
if (ret > 0) { if (ret > 0) {
copied += ret; copied += ret;
......
...@@ -80,6 +80,7 @@ int txmsg_end; ...@@ -80,6 +80,7 @@ int txmsg_end;
int txmsg_ingress; int txmsg_ingress;
int txmsg_skb; int txmsg_skb;
int ktls; int ktls;
int peek_flag;
static const struct option long_options[] = { static const struct option long_options[] = {
{"help", no_argument, NULL, 'h' }, {"help", no_argument, NULL, 'h' },
...@@ -102,6 +103,7 @@ static const struct option long_options[] = { ...@@ -102,6 +103,7 @@ static const struct option long_options[] = {
{"txmsg_ingress", no_argument, &txmsg_ingress, 1 }, {"txmsg_ingress", no_argument, &txmsg_ingress, 1 },
{"txmsg_skb", no_argument, &txmsg_skb, 1 }, {"txmsg_skb", no_argument, &txmsg_skb, 1 },
{"ktls", no_argument, &ktls, 1 }, {"ktls", no_argument, &ktls, 1 },
{"peek", no_argument, &peek_flag, 1 },
{0, 0, NULL, 0 } {0, 0, NULL, 0 }
}; };
...@@ -352,33 +354,40 @@ static int msg_loop_sendpage(int fd, int iov_length, int cnt, ...@@ -352,33 +354,40 @@ static int msg_loop_sendpage(int fd, int iov_length, int cnt,
return 0; return 0;
} }
static int msg_loop(int fd, int iov_count, int iov_length, int cnt, static void msg_free_iov(struct msghdr *msg)
struct msg_stats *s, bool tx,
struct sockmap_options *opt)
{ {
struct msghdr msg = {0}; int i;
int err, i, flags = MSG_NOSIGNAL;
for (i = 0; i < msg->msg_iovlen; i++)
free(msg->msg_iov[i].iov_base);
free(msg->msg_iov);
msg->msg_iov = NULL;
msg->msg_iovlen = 0;
}
static int msg_alloc_iov(struct msghdr *msg,
int iov_count, int iov_length,
bool data, bool xmit)
{
unsigned char k = 0;
struct iovec *iov; struct iovec *iov;
unsigned char k; int i;
bool data_test = opt->data_test;
bool drop = opt->drop_expected;
iov = calloc(iov_count, sizeof(struct iovec)); iov = calloc(iov_count, sizeof(struct iovec));
if (!iov) if (!iov)
return errno; return errno;
k = 0;
for (i = 0; i < iov_count; i++) { for (i = 0; i < iov_count; i++) {
unsigned char *d = calloc(iov_length, sizeof(char)); unsigned char *d = calloc(iov_length, sizeof(char));
if (!d) { if (!d) {
fprintf(stderr, "iov_count %i/%i OOM\n", i, iov_count); fprintf(stderr, "iov_count %i/%i OOM\n", i, iov_count);
goto out_errno; goto unwind_iov;
} }
iov[i].iov_base = d; iov[i].iov_base = d;
iov[i].iov_len = iov_length; iov[i].iov_len = iov_length;
if (data_test && tx) { if (data && xmit) {
int j; int j;
for (j = 0; j < iov_length; j++) for (j = 0; j < iov_length; j++)
...@@ -386,9 +395,60 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, ...@@ -386,9 +395,60 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
} }
} }
msg.msg_iov = iov; msg->msg_iov = iov;
msg.msg_iovlen = iov_count; msg->msg_iovlen = iov_count;
k = 0;
return 0;
unwind_iov:
for (i--; i >= 0 ; i--)
free(msg->msg_iov[i].iov_base);
return -ENOMEM;
}
static int msg_verify_data(struct msghdr *msg, int size, int chunk_sz)
{
int i, j, bytes_cnt = 0;
unsigned char k = 0;
for (i = 0; i < msg->msg_iovlen; i++) {
unsigned char *d = msg->msg_iov[i].iov_base;
for (j = 0;
j < msg->msg_iov[i].iov_len && size; j++) {
if (d[j] != k++) {
fprintf(stderr,
"detected data corruption @iov[%i]:%i %02x != %02x, %02x ?= %02x\n",
i, j, d[j], k - 1, d[j+1], k);
return -EIO;
}
bytes_cnt++;
if (bytes_cnt == chunk_sz) {
k = 0;
bytes_cnt = 0;
}
size--;
}
}
return 0;
}
static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
struct msg_stats *s, bool tx,
struct sockmap_options *opt)
{
struct msghdr msg = {0}, msg_peek = {0};
int err, i, flags = MSG_NOSIGNAL;
bool drop = opt->drop_expected;
bool data = opt->data_test;
err = msg_alloc_iov(&msg, iov_count, iov_length, data, tx);
if (err)
goto out_errno;
if (peek_flag) {
err = msg_alloc_iov(&msg_peek, iov_count, iov_length, data, tx);
if (err)
goto out_errno;
}
if (tx) { if (tx) {
clock_gettime(CLOCK_MONOTONIC, &s->start); clock_gettime(CLOCK_MONOTONIC, &s->start);
...@@ -408,19 +468,12 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, ...@@ -408,19 +468,12 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
} }
clock_gettime(CLOCK_MONOTONIC, &s->end); clock_gettime(CLOCK_MONOTONIC, &s->end);
} else { } else {
int slct, recv, max_fd = fd; int slct, recvp = 0, recv, max_fd = fd;
int fd_flags = O_NONBLOCK; int fd_flags = O_NONBLOCK;
struct timeval timeout; struct timeval timeout;
float total_bytes; float total_bytes;
int bytes_cnt = 0;
int chunk_sz;
fd_set w; fd_set w;
if (opt->sendpage)
chunk_sz = iov_length * cnt;
else
chunk_sz = iov_length * iov_count;
fcntl(fd, fd_flags); fcntl(fd, fd_flags);
total_bytes = (float)iov_count * (float)iov_length * (float)cnt; total_bytes = (float)iov_count * (float)iov_length * (float)cnt;
err = clock_gettime(CLOCK_MONOTONIC, &s->start); err = clock_gettime(CLOCK_MONOTONIC, &s->start);
...@@ -452,6 +505,19 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, ...@@ -452,6 +505,19 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
goto out_errno; goto out_errno;
} }
errno = 0;
if (peek_flag) {
flags |= MSG_PEEK;
recvp = recvmsg(fd, &msg_peek, flags);
if (recvp < 0) {
if (errno != EWOULDBLOCK) {
clock_gettime(CLOCK_MONOTONIC, &s->end);
goto out_errno;
}
}
flags = 0;
}
recv = recvmsg(fd, &msg, flags); recv = recvmsg(fd, &msg, flags);
if (recv < 0) { if (recv < 0) {
if (errno != EWOULDBLOCK) { if (errno != EWOULDBLOCK) {
...@@ -463,27 +529,23 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, ...@@ -463,27 +529,23 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
s->bytes_recvd += recv; s->bytes_recvd += recv;
if (data_test) { if (data) {
int j; int chunk_sz = opt->sendpage ?
iov_length * cnt :
for (i = 0; i < msg.msg_iovlen; i++) { iov_length * iov_count;
unsigned char *d = iov[i].iov_base;
errno = msg_verify_data(&msg, recv, chunk_sz);
for (j = 0; if (errno) {
j < iov[i].iov_len && recv; j++) { perror("data verify msg failed\n");
if (d[j] != k++) { goto out_errno;
errno = -EIO; }
fprintf(stderr, if (recvp) {
"detected data corruption @iov[%i]:%i %02x != %02x, %02x ?= %02x\n", errno = msg_verify_data(&msg_peek,
i, j, d[j], k - 1, d[j+1], k); recvp,
goto out_errno; chunk_sz);
} if (errno) {
bytes_cnt++; perror("data verify msg_peek failed\n");
if (bytes_cnt == chunk_sz) { goto out_errno;
k = 0;
bytes_cnt = 0;
}
recv--;
} }
} }
} }
...@@ -491,14 +553,12 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, ...@@ -491,14 +553,12 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
clock_gettime(CLOCK_MONOTONIC, &s->end); clock_gettime(CLOCK_MONOTONIC, &s->end);
} }
for (i = 0; i < iov_count; i++) msg_free_iov(&msg);
free(iov[i].iov_base); msg_free_iov(&msg_peek);
free(iov); return err;
return 0;
out_errno: out_errno:
for (i = 0; i < iov_count; i++) msg_free_iov(&msg);
free(iov[i].iov_base); msg_free_iov(&msg_peek);
free(iov);
return errno; return errno;
} }
...@@ -565,9 +625,10 @@ static int sendmsg_test(struct sockmap_options *opt) ...@@ -565,9 +625,10 @@ static int sendmsg_test(struct sockmap_options *opt)
} }
if (opt->verbose) if (opt->verbose)
fprintf(stdout, fprintf(stdout,
"rx_sendmsg: TX: %zuB %fB/s %fGB/s RX: %zuB %fB/s %fGB/s\n", "rx_sendmsg: TX: %zuB %fB/s %fGB/s RX: %zuB %fB/s %fGB/s %s\n",
s.bytes_sent, sent_Bps, sent_Bps/giga, s.bytes_sent, sent_Bps, sent_Bps/giga,
s.bytes_recvd, recvd_Bps, recvd_Bps/giga); s.bytes_recvd, recvd_Bps, recvd_Bps/giga,
peek_flag ? "(peek_msg)" : "");
if (err && txmsg_cork) if (err && txmsg_cork)
err = 0; err = 0;
exit(err ? 1 : 0); exit(err ? 1 : 0);
...@@ -999,6 +1060,8 @@ static void test_options(char *options) ...@@ -999,6 +1060,8 @@ static void test_options(char *options)
strncat(options, "skb,", OPTSTRING); strncat(options, "skb,", OPTSTRING);
if (ktls) if (ktls)
strncat(options, "ktls,", OPTSTRING); strncat(options, "ktls,", OPTSTRING);
if (peek_flag)
strncat(options, "peek,", OPTSTRING);
} }
static int __test_exec(int cgrp, int test, struct sockmap_options *opt) static int __test_exec(int cgrp, int test, struct sockmap_options *opt)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment