Commit 37e26188 authored by Jakub Kicinski's avatar Jakub Kicinski

Merge branch 'tls-rx-follow-ups-to-rx-work'

Jakub Kicinski says:

====================
tls: rx: follow ups to rx work

A selection of unrelated changes. First some selftest polishing.
Next a change to rcvtimeo handling for locking based on an exchange
with Eric. Follow up to Paolo's comments from yesterday. Last but
not least a fix to a false positive warning, turns out I've been
testing with DEBUG_NET=n this whole time.
====================

Link: https://lore.kernel.org/r/20220727031524.358216-1-kuba@kernel.orgSigned-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents efe3e6b5 e20691fa
...@@ -480,7 +480,7 @@ void tls_strp_done(struct tls_strparser *strp) ...@@ -480,7 +480,7 @@ void tls_strp_done(struct tls_strparser *strp)
int __init tls_strp_dev_init(void) int __init tls_strp_dev_init(void)
{ {
tls_strp_wq = create_singlethread_workqueue("kstrp"); tls_strp_wq = create_workqueue("tls-strp");
if (unlikely(!tls_strp_wq)) if (unlikely(!tls_strp_wq))
return -ENOMEM; return -ENOMEM;
......
...@@ -1283,11 +1283,14 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, ...@@ -1283,11 +1283,14 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
static int static int
tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
bool released, long timeo) bool released)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
DEFINE_WAIT_FUNC(wait, woken_wake_function); DEFINE_WAIT_FUNC(wait, woken_wake_function);
long timeo;
timeo = sock_rcvtimeo(sk, nonblock);
while (!tls_strp_msg_ready(ctx)) { while (!tls_strp_msg_ready(ctx)) {
if (!sk_psock_queue_empty(psock)) if (!sk_psock_queue_empty(psock))
...@@ -1308,7 +1311,7 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, ...@@ -1308,7 +1311,7 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
if (sock_flag(sk, SOCK_DONE)) if (sock_flag(sk, SOCK_DONE))
return 0; return 0;
if (nonblock || !timeo) if (!timeo)
return -EAGAIN; return -EAGAIN;
released = true; released = true;
...@@ -1842,8 +1845,8 @@ tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot, ...@@ -1842,8 +1845,8 @@ tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
return sk_flush_backlog(sk); return sk_flush_backlog(sk);
} }
static long tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx, static int tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx,
bool nonblock) bool nonblock)
{ {
long timeo; long timeo;
int err; int err;
...@@ -1874,7 +1877,7 @@ static long tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx, ...@@ -1874,7 +1877,7 @@ static long tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx,
WRITE_ONCE(ctx->reader_present, 1); WRITE_ONCE(ctx->reader_present, 1);
return timeo; return 0;
err_unlock: err_unlock:
release_sock(sk); release_sock(sk);
...@@ -1913,8 +1916,7 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1913,8 +1916,7 @@ int tls_sw_recvmsg(struct sock *sk,
struct tls_msg *tlm; struct tls_msg *tlm;
ssize_t copied = 0; ssize_t copied = 0;
bool async = false; bool async = false;
int target, err = 0; int target, err;
long timeo;
bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
bool is_peek = flags & MSG_PEEK; bool is_peek = flags & MSG_PEEK;
bool released = true; bool released = true;
...@@ -1925,9 +1927,9 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1925,9 +1927,9 @@ int tls_sw_recvmsg(struct sock *sk,
return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
psock = sk_psock_get(sk); psock = sk_psock_get(sk);
timeo = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT); err = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT);
if (timeo < 0) if (err < 0)
return timeo; return err;
bpf_strp_enabled = sk_psock_strp_enabled(psock); bpf_strp_enabled = sk_psock_strp_enabled(psock);
/* If crypto failed the connection is broken */ /* If crypto failed the connection is broken */
...@@ -1954,8 +1956,8 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1954,8 +1956,8 @@ int tls_sw_recvmsg(struct sock *sk,
struct tls_decrypt_arg darg; struct tls_decrypt_arg darg;
int to_decrypt, chunk; int to_decrypt, chunk;
err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, released, err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT,
timeo); released);
if (err <= 0) { if (err <= 0) {
if (psock) { if (psock) {
chunk = sk_msg_recvmsg(sk, psock, msg, len, chunk = sk_msg_recvmsg(sk, psock, msg, len,
...@@ -2024,7 +2026,7 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -2024,7 +2026,7 @@ int tls_sw_recvmsg(struct sock *sk,
bool partially_consumed = chunk > len; bool partially_consumed = chunk > len;
struct sk_buff *skb = darg.skb; struct sk_buff *skb = darg.skb;
DEBUG_NET_WARN_ON_ONCE(darg.skb == tls_strp_msg(ctx)); DEBUG_NET_WARN_ON_ONCE(darg.skb == ctx->strp.anchor);
if (async) { if (async) {
/* TLS 1.2-only, to_decrypt must be text len */ /* TLS 1.2-only, to_decrypt must be text len */
...@@ -2131,13 +2133,12 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, ...@@ -2131,13 +2133,12 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
struct tls_msg *tlm; struct tls_msg *tlm;
struct sk_buff *skb; struct sk_buff *skb;
ssize_t copied = 0; ssize_t copied = 0;
int err = 0;
long timeo;
int chunk; int chunk;
int err;
timeo = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK); err = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK);
if (timeo < 0) if (err < 0)
return timeo; return err;
if (!skb_queue_empty(&ctx->rx_list)) { if (!skb_queue_empty(&ctx->rx_list)) {
skb = __skb_dequeue(&ctx->rx_list); skb = __skb_dequeue(&ctx->rx_list);
...@@ -2145,7 +2146,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, ...@@ -2145,7 +2146,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
struct tls_decrypt_arg darg; struct tls_decrypt_arg darg;
err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK, err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
true, timeo); true);
if (err <= 0) if (err <= 0)
goto splice_read_end; goto splice_read_end;
......
...@@ -644,12 +644,14 @@ TEST_F(tls, splice_from_pipe2) ...@@ -644,12 +644,14 @@ TEST_F(tls, splice_from_pipe2)
int p2[2]; int p2[2];
int p[2]; int p[2];
memrnd(mem_send, sizeof(mem_send));
ASSERT_GE(pipe(p), 0); ASSERT_GE(pipe(p), 0);
ASSERT_GE(pipe(p2), 0); ASSERT_GE(pipe(p2), 0);
EXPECT_GE(write(p[1], mem_send, 8000), 0); EXPECT_EQ(write(p[1], mem_send, 8000), 8000);
EXPECT_GE(splice(p[0], NULL, self->fd, NULL, 8000, 0), 0); EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, 8000, 0), 8000);
EXPECT_GE(write(p2[1], mem_send + 8000, 8000), 0); EXPECT_EQ(write(p2[1], mem_send + 8000, 8000), 8000);
EXPECT_GE(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 0); EXPECT_EQ(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 8000);
EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len); EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0); EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
} }
...@@ -683,10 +685,12 @@ TEST_F(tls, splice_to_pipe) ...@@ -683,10 +685,12 @@ TEST_F(tls, splice_to_pipe)
char mem_recv[TLS_PAYLOAD_MAX_LEN]; char mem_recv[TLS_PAYLOAD_MAX_LEN];
int p[2]; int p[2];
memrnd(mem_send, sizeof(mem_send));
ASSERT_GE(pipe(p), 0); ASSERT_GE(pipe(p), 0);
EXPECT_GE(send(self->fd, mem_send, send_len, 0), 0); EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
EXPECT_GE(splice(self->cfd, NULL, p[1], NULL, send_len, 0), 0); EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), send_len);
EXPECT_GE(read(p[0], mem_recv, send_len), 0); EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0); EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
} }
...@@ -875,6 +879,8 @@ TEST_F(tls, multiple_send_single_recv) ...@@ -875,6 +879,8 @@ TEST_F(tls, multiple_send_single_recv)
char recv_mem[2 * 10]; char recv_mem[2 * 10];
char send_mem[10]; char send_mem[10];
memrnd(send_mem, sizeof(send_mem));
EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0); EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0); EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
memset(recv_mem, 0, total_len); memset(recv_mem, 0, total_len);
...@@ -891,6 +897,8 @@ TEST_F(tls, single_send_multiple_recv_non_align) ...@@ -891,6 +897,8 @@ TEST_F(tls, single_send_multiple_recv_non_align)
char recv_mem[recv_len * 2]; char recv_mem[recv_len * 2];
char send_mem[total_len]; char send_mem[total_len];
memrnd(send_mem, sizeof(send_mem));
EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0); EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
memset(recv_mem, 0, total_len); memset(recv_mem, 0, total_len);
...@@ -936,10 +944,10 @@ TEST_F(tls, recv_peek) ...@@ -936,10 +944,10 @@ TEST_F(tls, recv_peek)
char buf[15]; char buf[15];
EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len); EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1); EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_PEEK), send_len);
EXPECT_EQ(memcmp(test_str, buf, send_len), 0); EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
memset(buf, 0, sizeof(buf)); memset(buf, 0, sizeof(buf));
EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1); EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
EXPECT_EQ(memcmp(test_str, buf, send_len), 0); EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment