Commit 2eaa8575 authored by David S. Miller's avatar David S. Miller

Merge branch 'net-tls-fix-scatter-gather-list-issues'

Jakub Kicinski says:

====================
net: tls: fix scatter-gather list issues

This series kicked of by a syzbot report fixes three issues around
scatter gather handling in the TLS code. First patch fixes a use-
-after-free situation which may occur if record was freed on error.
This could have already happened in BPF paths, and patch 2 now makes
the same condition occur in non-BPF code.

Patch 2 fixes the problem spotted by syzbot. If encryption failed
we have to clean the end markings from scatter gather list. As
suggested by John the patch frees the record entirely and caller
may retry copying data from user space buffer again.

Third patch fixes a bug in the TLS 1.3 code spotted while working
on patch 2. TLS 1.3 may effectively overflow the SG list which
leads to the BUG() in sg_page() being triggered.

Patch 4 adds a test case which triggers this bug reliably.

Next two patches are small cleanups of dead code and code which
makes dangerous assumptions.

Last but not least two minor improvements to the sockmap tests.

Tested:
 - bpf/test_sockmap
 - net/tls
 - syzbot repro (which used error injection, hence no direct
   selftest is added to preserve it).
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 81b6b964 e5dc9dd3
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include <net/strparser.h> #include <net/strparser.h>
#define MAX_MSG_FRAGS MAX_SKB_FRAGS #define MAX_MSG_FRAGS MAX_SKB_FRAGS
#define NR_MSG_FRAG_IDS (MAX_MSG_FRAGS + 1)
enum __sk_action { enum __sk_action {
__SK_DROP = 0, __SK_DROP = 0,
...@@ -29,13 +30,15 @@ struct sk_msg_sg { ...@@ -29,13 +30,15 @@ struct sk_msg_sg {
u32 size; u32 size;
u32 copybreak; u32 copybreak;
unsigned long copy; unsigned long copy;
/* The extra element is used for chaining the front and sections when /* The extra two elements:
* the list becomes partitioned (e.g. end < start). The crypto APIs * 1) used for chaining the front and sections when the list becomes
* require the chaining. * partitioned (e.g. end < start). The crypto APIs require the
* chaining;
* 2) to chain tailer SG entries after the message.
*/ */
struct scatterlist data[MAX_MSG_FRAGS + 1]; struct scatterlist data[MAX_MSG_FRAGS + 2];
}; };
static_assert(BITS_PER_LONG >= MAX_MSG_FRAGS); static_assert(BITS_PER_LONG >= NR_MSG_FRAG_IDS);
/* UAPI in filter.c depends on struct sk_msg_sg being first element. */ /* UAPI in filter.c depends on struct sk_msg_sg being first element. */
struct sk_msg { struct sk_msg {
...@@ -142,13 +145,13 @@ static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes) ...@@ -142,13 +145,13 @@ static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
static inline u32 sk_msg_iter_dist(u32 start, u32 end) static inline u32 sk_msg_iter_dist(u32 start, u32 end)
{ {
return end >= start ? end - start : end + (MAX_MSG_FRAGS - start); return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
} }
#define sk_msg_iter_var_prev(var) \ #define sk_msg_iter_var_prev(var) \
do { \ do { \
if (var == 0) \ if (var == 0) \
var = MAX_MSG_FRAGS - 1; \ var = NR_MSG_FRAG_IDS - 1; \
else \ else \
var--; \ var--; \
} while (0) } while (0)
...@@ -156,7 +159,7 @@ static inline u32 sk_msg_iter_dist(u32 start, u32 end) ...@@ -156,7 +159,7 @@ static inline u32 sk_msg_iter_dist(u32 start, u32 end)
#define sk_msg_iter_var_next(var) \ #define sk_msg_iter_var_next(var) \
do { \ do { \
var++; \ var++; \
if (var == MAX_MSG_FRAGS) \ if (var == NR_MSG_FRAG_IDS) \
var = 0; \ var = 0; \
} while (0) } while (0)
...@@ -173,9 +176,9 @@ static inline void sk_msg_clear_meta(struct sk_msg *msg) ...@@ -173,9 +176,9 @@ static inline void sk_msg_clear_meta(struct sk_msg *msg)
static inline void sk_msg_init(struct sk_msg *msg) static inline void sk_msg_init(struct sk_msg *msg)
{ {
BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != MAX_MSG_FRAGS); BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
memset(msg, 0, sizeof(*msg)); memset(msg, 0, sizeof(*msg));
sg_init_marker(msg->sg.data, MAX_MSG_FRAGS); sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
} }
static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src, static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
...@@ -196,14 +199,11 @@ static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src) ...@@ -196,14 +199,11 @@ static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
static inline bool sk_msg_full(const struct sk_msg *msg) static inline bool sk_msg_full(const struct sk_msg *msg)
{ {
return (msg->sg.end == msg->sg.start) && msg->sg.size; return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
} }
static inline u32 sk_msg_elem_used(const struct sk_msg *msg) static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
{ {
if (sk_msg_full(msg))
return MAX_MSG_FRAGS;
return sk_msg_iter_dist(msg->sg.start, msg->sg.end); return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
} }
......
...@@ -100,7 +100,6 @@ struct tls_rec { ...@@ -100,7 +100,6 @@ struct tls_rec {
struct list_head list; struct list_head list;
int tx_ready; int tx_ready;
int tx_flags; int tx_flags;
int inplace_crypto;
struct sk_msg msg_plaintext; struct sk_msg msg_plaintext;
struct sk_msg msg_encrypted; struct sk_msg msg_encrypted;
...@@ -377,7 +376,7 @@ int tls_push_sg(struct sock *sk, struct tls_context *ctx, ...@@ -377,7 +376,7 @@ int tls_push_sg(struct sock *sk, struct tls_context *ctx,
int flags); int flags);
int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
int flags); int flags);
bool tls_free_partial_record(struct sock *sk, struct tls_context *ctx); void tls_free_partial_record(struct sock *sk, struct tls_context *ctx);
static inline struct tls_msg *tls_msg(struct sk_buff *skb) static inline struct tls_msg *tls_msg(struct sk_buff *skb)
{ {
......
...@@ -2299,7 +2299,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start, ...@@ -2299,7 +2299,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
WARN_ON_ONCE(last_sge == first_sge); WARN_ON_ONCE(last_sge == first_sge);
shift = last_sge > first_sge ? shift = last_sge > first_sge ?
last_sge - first_sge - 1 : last_sge - first_sge - 1 :
MAX_SKB_FRAGS - first_sge + last_sge - 1; NR_MSG_FRAG_IDS - first_sge + last_sge - 1;
if (!shift) if (!shift)
goto out; goto out;
...@@ -2308,8 +2308,8 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start, ...@@ -2308,8 +2308,8 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
do { do {
u32 move_from; u32 move_from;
if (i + shift >= MAX_MSG_FRAGS) if (i + shift >= NR_MSG_FRAG_IDS)
move_from = i + shift - MAX_MSG_FRAGS; move_from = i + shift - NR_MSG_FRAG_IDS;
else else
move_from = i + shift; move_from = i + shift;
if (move_from == msg->sg.end) if (move_from == msg->sg.end)
...@@ -2323,7 +2323,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start, ...@@ -2323,7 +2323,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
} while (1); } while (1);
msg->sg.end = msg->sg.end - shift > msg->sg.end ? msg->sg.end = msg->sg.end - shift > msg->sg.end ?
msg->sg.end - shift + MAX_MSG_FRAGS : msg->sg.end - shift + NR_MSG_FRAG_IDS :
msg->sg.end - shift; msg->sg.end - shift;
out: out:
msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset; msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset;
......
...@@ -421,7 +421,7 @@ static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb) ...@@ -421,7 +421,7 @@ static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
copied = skb->len; copied = skb->len;
msg->sg.start = 0; msg->sg.start = 0;
msg->sg.size = copied; msg->sg.size = copied;
msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge; msg->sg.end = num_sge;
msg->skb = skb; msg->skb = skb;
sk_psock_queue_msg(psock, msg); sk_psock_queue_msg(psock, msg);
......
...@@ -301,7 +301,7 @@ EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir); ...@@ -301,7 +301,7 @@ EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock, static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
struct sk_msg *msg, int *copied, int flags) struct sk_msg *msg, int *copied, int flags)
{ {
bool cork = false, enospc = msg->sg.start == msg->sg.end; bool cork = false, enospc = sk_msg_full(msg);
struct sock *sk_redir; struct sock *sk_redir;
u32 tosend, delta = 0; u32 tosend, delta = 0;
int ret; int ret;
......
...@@ -209,24 +209,15 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, ...@@ -209,24 +209,15 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
return tls_push_sg(sk, ctx, sg, offset, flags); return tls_push_sg(sk, ctx, sg, offset, flags);
} }
bool tls_free_partial_record(struct sock *sk, struct tls_context *ctx) void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
{ {
struct scatterlist *sg; struct scatterlist *sg;
sg = ctx->partially_sent_record; for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
if (!sg)
return false;
while (1) {
put_page(sg_page(sg)); put_page(sg_page(sg));
sk_mem_uncharge(sk, sg->length); sk_mem_uncharge(sk, sg->length);
if (sg_is_last(sg))
break;
sg++;
} }
ctx->partially_sent_record = NULL; ctx->partially_sent_record = NULL;
return true;
} }
static void tls_write_space(struct sock *sk) static void tls_write_space(struct sock *sk)
......
...@@ -710,8 +710,7 @@ static int tls_push_record(struct sock *sk, int flags, ...@@ -710,8 +710,7 @@ static int tls_push_record(struct sock *sk, int flags,
} }
i = msg_pl->sg.start; i = msg_pl->sg.start;
sg_chain(rec->sg_aead_in, 2, rec->inplace_crypto ? sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);
&msg_en->sg.data[i] : &msg_pl->sg.data[i]);
i = msg_en->sg.end; i = msg_en->sg.end;
sk_msg_iter_var_prev(i); sk_msg_iter_var_prev(i);
...@@ -771,8 +770,14 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, ...@@ -771,8 +770,14 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
policy = !(flags & MSG_SENDPAGE_NOPOLICY); policy = !(flags & MSG_SENDPAGE_NOPOLICY);
psock = sk_psock_get(sk); psock = sk_psock_get(sk);
if (!psock || !policy) if (!psock || !policy) {
return tls_push_record(sk, flags, record_type); err = tls_push_record(sk, flags, record_type);
if (err) {
*copied -= sk_msg_free(sk, msg);
tls_free_open_rec(sk);
}
return err;
}
more_data: more_data:
enospc = sk_msg_full(msg); enospc = sk_msg_full(msg);
if (psock->eval == __SK_NONE) { if (psock->eval == __SK_NONE) {
...@@ -970,8 +975,6 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) ...@@ -970,8 +975,6 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
if (ret) if (ret)
goto fallback_to_reg_send; goto fallback_to_reg_send;
rec->inplace_crypto = 0;
num_zc++; num_zc++;
copied += try_to_copy; copied += try_to_copy;
...@@ -984,7 +987,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) ...@@ -984,7 +987,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
num_async++; num_async++;
else if (ret == -ENOMEM) else if (ret == -ENOMEM)
goto wait_for_memory; goto wait_for_memory;
else if (ret == -ENOSPC) else if (ctx->open_rec && ret == -ENOSPC)
goto rollback_iter; goto rollback_iter;
else if (ret != -EAGAIN) else if (ret != -EAGAIN)
goto send_end; goto send_end;
...@@ -1053,11 +1056,12 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) ...@@ -1053,11 +1056,12 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
ret = sk_stream_wait_memory(sk, &timeo); ret = sk_stream_wait_memory(sk, &timeo);
if (ret) { if (ret) {
trim_sgl: trim_sgl:
tls_trim_both_msgs(sk, orig_size); if (ctx->open_rec)
tls_trim_both_msgs(sk, orig_size);
goto send_end; goto send_end;
} }
if (msg_en->sg.size < required_size) if (ctx->open_rec && msg_en->sg.size < required_size)
goto alloc_encrypted; goto alloc_encrypted;
} }
...@@ -1169,7 +1173,6 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page, ...@@ -1169,7 +1173,6 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
tls_ctx->pending_open_record_frags = true; tls_ctx->pending_open_record_frags = true;
if (full_record || eor || sk_msg_full(msg_pl)) { if (full_record || eor || sk_msg_full(msg_pl)) {
rec->inplace_crypto = 0;
ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
record_type, &copied, flags); record_type, &copied, flags);
if (ret) { if (ret) {
...@@ -1190,11 +1193,13 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page, ...@@ -1190,11 +1193,13 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
wait_for_memory: wait_for_memory:
ret = sk_stream_wait_memory(sk, &timeo); ret = sk_stream_wait_memory(sk, &timeo);
if (ret) { if (ret) {
tls_trim_both_msgs(sk, msg_pl->sg.size); if (ctx->open_rec)
tls_trim_both_msgs(sk, msg_pl->sg.size);
goto sendpage_end; goto sendpage_end;
} }
goto alloc_payload; if (ctx->open_rec)
goto alloc_payload;
} }
if (num_async) { if (num_async) {
...@@ -2084,7 +2089,8 @@ void tls_sw_release_resources_tx(struct sock *sk) ...@@ -2084,7 +2089,8 @@ void tls_sw_release_resources_tx(struct sock *sk)
/* Free up un-sent records in tx_list. First, free /* Free up un-sent records in tx_list. First, free
* the partially sent record if any at head of tx_list. * the partially sent record if any at head of tx_list.
*/ */
if (tls_free_partial_record(sk, tls_ctx)) { if (tls_ctx->partially_sent_record) {
tls_free_partial_record(sk, tls_ctx);
rec = list_first_entry(&ctx->tx_list, rec = list_first_entry(&ctx->tx_list,
struct tls_rec, list); struct tls_rec, list);
list_del(&rec->list); list_del(&rec->list);
......
...@@ -240,14 +240,14 @@ static int sockmap_init_sockets(int verbose) ...@@ -240,14 +240,14 @@ static int sockmap_init_sockets(int verbose)
addr.sin_port = htons(S1_PORT); addr.sin_port = htons(S1_PORT);
err = bind(s1, (struct sockaddr *)&addr, sizeof(addr)); err = bind(s1, (struct sockaddr *)&addr, sizeof(addr));
if (err < 0) { if (err < 0) {
perror("bind s1 failed()\n"); perror("bind s1 failed()");
return errno; return errno;
} }
addr.sin_port = htons(S2_PORT); addr.sin_port = htons(S2_PORT);
err = bind(s2, (struct sockaddr *)&addr, sizeof(addr)); err = bind(s2, (struct sockaddr *)&addr, sizeof(addr));
if (err < 0) { if (err < 0) {
perror("bind s2 failed()\n"); perror("bind s2 failed()");
return errno; return errno;
} }
...@@ -255,14 +255,14 @@ static int sockmap_init_sockets(int verbose) ...@@ -255,14 +255,14 @@ static int sockmap_init_sockets(int verbose)
addr.sin_port = htons(S1_PORT); addr.sin_port = htons(S1_PORT);
err = listen(s1, 32); err = listen(s1, 32);
if (err < 0) { if (err < 0) {
perror("listen s1 failed()\n"); perror("listen s1 failed()");
return errno; return errno;
} }
addr.sin_port = htons(S2_PORT); addr.sin_port = htons(S2_PORT);
err = listen(s2, 32); err = listen(s2, 32);
if (err < 0) { if (err < 0) {
perror("listen s1 failed()\n"); perror("listen s1 failed()");
return errno; return errno;
} }
...@@ -270,14 +270,14 @@ static int sockmap_init_sockets(int verbose) ...@@ -270,14 +270,14 @@ static int sockmap_init_sockets(int verbose)
addr.sin_port = htons(S1_PORT); addr.sin_port = htons(S1_PORT);
err = connect(c1, (struct sockaddr *)&addr, sizeof(addr)); err = connect(c1, (struct sockaddr *)&addr, sizeof(addr));
if (err < 0 && errno != EINPROGRESS) { if (err < 0 && errno != EINPROGRESS) {
perror("connect c1 failed()\n"); perror("connect c1 failed()");
return errno; return errno;
} }
addr.sin_port = htons(S2_PORT); addr.sin_port = htons(S2_PORT);
err = connect(c2, (struct sockaddr *)&addr, sizeof(addr)); err = connect(c2, (struct sockaddr *)&addr, sizeof(addr));
if (err < 0 && errno != EINPROGRESS) { if (err < 0 && errno != EINPROGRESS) {
perror("connect c2 failed()\n"); perror("connect c2 failed()");
return errno; return errno;
} else if (err < 0) { } else if (err < 0) {
err = 0; err = 0;
...@@ -286,13 +286,13 @@ static int sockmap_init_sockets(int verbose) ...@@ -286,13 +286,13 @@ static int sockmap_init_sockets(int verbose)
/* Accept Connecrtions */ /* Accept Connecrtions */
p1 = accept(s1, NULL, NULL); p1 = accept(s1, NULL, NULL);
if (p1 < 0) { if (p1 < 0) {
perror("accept s1 failed()\n"); perror("accept s1 failed()");
return errno; return errno;
} }
p2 = accept(s2, NULL, NULL); p2 = accept(s2, NULL, NULL);
if (p2 < 0) { if (p2 < 0) {
perror("accept s1 failed()\n"); perror("accept s1 failed()");
return errno; return errno;
} }
...@@ -332,6 +332,10 @@ static int msg_loop_sendpage(int fd, int iov_length, int cnt, ...@@ -332,6 +332,10 @@ static int msg_loop_sendpage(int fd, int iov_length, int cnt,
int i, fp; int i, fp;
file = fopen(".sendpage_tst.tmp", "w+"); file = fopen(".sendpage_tst.tmp", "w+");
if (!file) {
perror("create file for sendpage");
return 1;
}
for (i = 0; i < iov_length * cnt; i++, k++) for (i = 0; i < iov_length * cnt; i++, k++)
fwrite(&k, sizeof(char), 1, file); fwrite(&k, sizeof(char), 1, file);
fflush(file); fflush(file);
...@@ -339,12 +343,17 @@ static int msg_loop_sendpage(int fd, int iov_length, int cnt, ...@@ -339,12 +343,17 @@ static int msg_loop_sendpage(int fd, int iov_length, int cnt,
fclose(file); fclose(file);
fp = open(".sendpage_tst.tmp", O_RDONLY); fp = open(".sendpage_tst.tmp", O_RDONLY);
if (fp < 0) {
perror("reopen file for sendpage");
return 1;
}
clock_gettime(CLOCK_MONOTONIC, &s->start); clock_gettime(CLOCK_MONOTONIC, &s->start);
for (i = 0; i < cnt; i++) { for (i = 0; i < cnt; i++) {
int sent = sendfile(fd, fp, NULL, iov_length); int sent = sendfile(fd, fp, NULL, iov_length);
if (!drop && sent < 0) { if (!drop && sent < 0) {
perror("send loop error:"); perror("send loop error");
close(fp); close(fp);
return sent; return sent;
} else if (drop && sent >= 0) { } else if (drop && sent >= 0) {
...@@ -463,7 +472,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, ...@@ -463,7 +472,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
int sent = sendmsg(fd, &msg, flags); int sent = sendmsg(fd, &msg, flags);
if (!drop && sent < 0) { if (!drop && sent < 0) {
perror("send loop error:"); perror("send loop error");
goto out_errno; goto out_errno;
} else if (drop && sent >= 0) { } else if (drop && sent >= 0) {
printf("send loop error expected: %i\n", sent); printf("send loop error expected: %i\n", sent);
...@@ -499,7 +508,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, ...@@ -499,7 +508,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
total_bytes -= txmsg_pop_total; total_bytes -= txmsg_pop_total;
err = clock_gettime(CLOCK_MONOTONIC, &s->start); err = clock_gettime(CLOCK_MONOTONIC, &s->start);
if (err < 0) if (err < 0)
perror("recv start time: "); perror("recv start time");
while (s->bytes_recvd < total_bytes) { while (s->bytes_recvd < total_bytes) {
if (txmsg_cork) { if (txmsg_cork) {
timeout.tv_sec = 0; timeout.tv_sec = 0;
...@@ -543,7 +552,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, ...@@ -543,7 +552,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
if (recv < 0) { if (recv < 0) {
if (errno != EWOULDBLOCK) { if (errno != EWOULDBLOCK) {
clock_gettime(CLOCK_MONOTONIC, &s->end); clock_gettime(CLOCK_MONOTONIC, &s->end);
perror("recv failed()\n"); perror("recv failed()");
goto out_errno; goto out_errno;
} }
} }
...@@ -557,7 +566,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, ...@@ -557,7 +566,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
errno = msg_verify_data(&msg, recv, chunk_sz); errno = msg_verify_data(&msg, recv, chunk_sz);
if (errno) { if (errno) {
perror("data verify msg failed\n"); perror("data verify msg failed");
goto out_errno; goto out_errno;
} }
if (recvp) { if (recvp) {
...@@ -565,7 +574,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, ...@@ -565,7 +574,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
recvp, recvp,
chunk_sz); chunk_sz);
if (errno) { if (errno) {
perror("data verify msg_peek failed\n"); perror("data verify msg_peek failed");
goto out_errno; goto out_errno;
} }
} }
...@@ -654,7 +663,7 @@ static int sendmsg_test(struct sockmap_options *opt) ...@@ -654,7 +663,7 @@ static int sendmsg_test(struct sockmap_options *opt)
err = 0; err = 0;
exit(err ? 1 : 0); exit(err ? 1 : 0);
} else if (rxpid == -1) { } else if (rxpid == -1) {
perror("msg_loop_rx: "); perror("msg_loop_rx");
return errno; return errno;
} }
...@@ -681,7 +690,7 @@ static int sendmsg_test(struct sockmap_options *opt) ...@@ -681,7 +690,7 @@ static int sendmsg_test(struct sockmap_options *opt)
s.bytes_recvd, recvd_Bps, recvd_Bps/giga); s.bytes_recvd, recvd_Bps, recvd_Bps/giga);
exit(err ? 1 : 0); exit(err ? 1 : 0);
} else if (txpid == -1) { } else if (txpid == -1) {
perror("msg_loop_tx: "); perror("msg_loop_tx");
return errno; return errno;
} }
...@@ -715,7 +724,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt) ...@@ -715,7 +724,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt)
/* Ping/Pong data from client to server */ /* Ping/Pong data from client to server */
sc = send(c1, buf, sizeof(buf), 0); sc = send(c1, buf, sizeof(buf), 0);
if (sc < 0) { if (sc < 0) {
perror("send failed()\n"); perror("send failed()");
return sc; return sc;
} }
...@@ -748,7 +757,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt) ...@@ -748,7 +757,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt)
rc = recv(i, buf, sizeof(buf), 0); rc = recv(i, buf, sizeof(buf), 0);
if (rc < 0) { if (rc < 0) {
if (errno != EWOULDBLOCK) { if (errno != EWOULDBLOCK) {
perror("recv failed()\n"); perror("recv failed()");
return rc; return rc;
} }
} }
...@@ -760,7 +769,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt) ...@@ -760,7 +769,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt)
sc = send(i, buf, rc, 0); sc = send(i, buf, rc, 0);
if (sc < 0) { if (sc < 0) {
perror("send failed()\n"); perror("send failed()");
return sc; return sc;
} }
} }
......
...@@ -45,7 +45,7 @@ static int get_stats(int fd, __u16 count, __u32 raddr) ...@@ -45,7 +45,7 @@ static int get_stats(int fd, __u16 count, __u32 raddr)
printf("\nXDP RTT data:\n"); printf("\nXDP RTT data:\n");
if (bpf_map_lookup_elem(fd, &raddr, &pinginfo)) { if (bpf_map_lookup_elem(fd, &raddr, &pinginfo)) {
perror("bpf_map_lookup elem: "); perror("bpf_map_lookup elem");
return 1; return 1;
} }
......
...@@ -268,6 +268,38 @@ TEST_F(tls, sendmsg_single) ...@@ -268,6 +268,38 @@ TEST_F(tls, sendmsg_single)
EXPECT_EQ(memcmp(buf, test_str, send_len), 0); EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
} }
#define MAX_FRAGS 64
#define SEND_LEN 13
TEST_F(tls, sendmsg_fragmented)
{
char const *test_str = "test_sendmsg";
char buf[SEND_LEN * MAX_FRAGS];
struct iovec vec[MAX_FRAGS];
struct msghdr msg;
int i, frags;
for (frags = 1; frags <= MAX_FRAGS; frags++) {
for (i = 0; i < frags; i++) {
vec[i].iov_base = (char *)test_str;
vec[i].iov_len = SEND_LEN;
}
memset(&msg, 0, sizeof(struct msghdr));
msg.msg_iov = vec;
msg.msg_iovlen = frags;
EXPECT_EQ(sendmsg(self->fd, &msg, 0), SEND_LEN * frags);
EXPECT_EQ(recv(self->cfd, buf, SEND_LEN * frags, MSG_WAITALL),
SEND_LEN * frags);
for (i = 0; i < frags; i++)
EXPECT_EQ(memcmp(buf + SEND_LEN * i,
test_str, SEND_LEN), 0);
}
}
#undef MAX_FRAGS
#undef SEND_LEN
TEST_F(tls, sendmsg_large) TEST_F(tls, sendmsg_large)
{ {
void *mem = malloc(16384); void *mem = malloc(16384);
...@@ -694,6 +726,34 @@ TEST_F(tls, recv_lowat) ...@@ -694,6 +726,34 @@ TEST_F(tls, recv_lowat)
EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0); EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
} }
TEST_F(tls, recv_rcvbuf)
{
char send_mem[4096];
char recv_mem[4096];
int rcv_buf = 1024;
memset(send_mem, 0x1c, sizeof(send_mem));
EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVBUF,
&rcv_buf, sizeof(rcv_buf)), 0);
EXPECT_EQ(send(self->fd, send_mem, 512, 0), 512);
memset(recv_mem, 0, sizeof(recv_mem));
EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0), 512);
EXPECT_EQ(memcmp(send_mem, recv_mem, 512), 0);
if (self->notls)
return;
EXPECT_EQ(send(self->fd, send_mem, 4096, 0), 4096);
memset(recv_mem, 0, sizeof(recv_mem));
EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0), -1);
EXPECT_EQ(errno, EMSGSIZE);
EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0), -1);
EXPECT_EQ(errno, EMSGSIZE);
}
TEST_F(tls, bidir) TEST_F(tls, bidir)
{ {
char const *test_str = "test_read"; char const *test_str = "test_read";
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment