Commit e2859601 authored by David S. Miller

Merge branch 'net-tls-Annotate-lockless-access-to-sk_prot'

Jakub Sitnicki says:

====================
net/tls: Annotate lockless access to sk_prot

We have recently noticed that there is a case of lockless read/write to
sk->sk_prot [0]. sockmap code on psock tear-down writes to sk->sk_prot,
while holding sk_callback_lock. Concurrently, tcp can access it, usually
to read out the sk_prot pointer and invoke one of the ops,
sk->sk_prot->handler().

The lockless write (lockless in regard to concurrent reads) happens on the
following paths:

tcp_bpf_{recvmsg|sendmsg} / sock_map_unref
  sk_psock_put
    sk_psock_drop
      sk_psock_restore_proto
        WRITE_ONCE(sk->sk_prot, proto)

To prevent load/store tearing [1], and to make tooling aware of intentional
shared access [2], we need to annotate sites that access sk_prot with
READ_ONCE/WRITE_ONCE.
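
As an illustration of the pattern, here is a minimal userspace sketch
(not kernel code: READ_ONCE/WRITE_ONCE below are simplified stand-ins
for the kernel macros, and every other name is invented for the
example):

#include <pthread.h>
#include <stdio.h>

/* Simplified stand-ins for the kernel macros: a volatile access stops
 * the compiler from tearing, fusing, or re-issuing the load/store. */
#define READ_ONCE(x)		(*(const volatile __typeof__(x) *)&(x))
#define WRITE_ONCE(x, val)	(*(volatile __typeof__(x) *)&(x) = (val))

struct proto {
	void (*handler)(void);
};

static void tcp_handler(void) { puts("tcp"); }
static void tls_handler(void) { puts("tls"); }

static struct proto tcp_proto = { .handler = tcp_handler };
static struct proto tls_proto = { .handler = tls_handler };

struct sock {
	struct proto *sk_prot;
};

static struct sock sk = { .sk_prot = &tcp_proto };

/* Writer side: swap the ops vector, as update_sk_prot() does. */
static void *writer(void *arg)
{
	WRITE_ONCE(sk.sk_prot, &tls_proto);
	return NULL;
}

/* Reader side: load the pointer exactly once, then call through the
 * local copy; we observe the old or the new proto, never a mix. */
static void *reader(void *arg)
{
	struct proto *prot = READ_ONCE(sk.sk_prot);

	prot->handler();
	return NULL;
}

int main(void)
{
	pthread_t w, r;

	pthread_create(&w, NULL, writer, NULL);
	pthread_create(&r, NULL, reader, NULL);
	pthread_join(w, NULL);
	pthread_join(r, NULL);
	return 0;
}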

This series kicks off the effort to do it, starting with net/tls.

[0] https://lore.kernel.org/bpf/a6bf279e-a998-84ab-4371-cd6c1ccbca5d@gmail.com/
[1] https://lwn.net/Articles/793253/
[2] https://github.com/google/ktsan/wiki/READ_ONCE-and-WRITE_ONCE
====================
Signed-off-by: David S. Miller <davem@davemloft.net>
parents 26922c0e d5bee737
@@ -366,7 +366,7 @@ static int tls_do_allocation(struct sock *sk,
 	if (!offload_ctx->open_record) {
 		if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
 						   sk->sk_allocation))) {
-			sk->sk_prot->enter_memory_pressure(sk);
+			READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
 			sk_stream_moderate_sndbuf(sk);
 			return -ENOMEM;
 		}
...
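
The hunk above is the read side of the pattern. For a routine that
needs the proto pointer more than once, the idiomatic form is to load
it into a local first. A sketch assuming the kernel's struct sock,
struct proto, and READ_ONCE (the helper name is hypothetical):

/* Hypothetical helper: one READ_ONCE, then only the local copy is
 * used, so the check of the op and the indirect call see the same
 * proto even if a concurrent WRITE_ONCE() swaps sk->sk_prot in
 * between. */
static void sketch_enter_memory_pressure(struct sock *sk)
{
	const struct proto *prot = READ_ONCE(sk->sk_prot);

	if (prot->enter_memory_pressure)
		prot->enter_memory_pressure(sk);
}
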
@@ -63,13 +63,14 @@ static DEFINE_MUTEX(tcpv4_prot_mutex);
 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
 static struct proto_ops tls_sw_proto_ops;
 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
-			 struct proto *base);
+			 const struct proto *base);
 
 void update_sk_prot(struct sock *sk, struct tls_context *ctx)
 {
 	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
 
-	sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf];
+	WRITE_ONCE(sk->sk_prot,
+		   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
 }
 
 int wait_on_pending_writer(struct sock *sk, long *timeo)
@@ -312,7 +313,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
 	write_lock_bh(&sk->sk_callback_lock);
 	if (free_ctx)
 		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
-	sk->sk_prot = ctx->sk_proto;
+	WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
 	if (sk->sk_write_space == tls_write_space)
 		sk->sk_write_space = ctx->sk_write_space;
 	write_unlock_bh(&sk->sk_callback_lock);
@@ -621,38 +622,39 @@ struct tls_context *tls_ctx_create(struct sock *sk)
 
 	mutex_init(&ctx->tx_lock);
 	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
-	ctx->sk_proto = sk->sk_prot;
+	ctx->sk_proto = READ_ONCE(sk->sk_prot);
 	return ctx;
 }
 
 static void tls_build_proto(struct sock *sk)
 {
 	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
+	const struct proto *prot = READ_ONCE(sk->sk_prot);
 
 	/* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
 	if (ip_ver == TLSV6 &&
-	    unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
+	    unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
 		mutex_lock(&tcpv6_prot_mutex);
-		if (likely(sk->sk_prot != saved_tcpv6_prot)) {
-			build_protos(tls_prots[TLSV6], sk->sk_prot);
-			smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
+		if (likely(prot != saved_tcpv6_prot)) {
+			build_protos(tls_prots[TLSV6], prot);
+			smp_store_release(&saved_tcpv6_prot, prot);
 		}
 		mutex_unlock(&tcpv6_prot_mutex);
 	}
 
 	if (ip_ver == TLSV4 &&
-	    unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) {
+	    unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
 		mutex_lock(&tcpv4_prot_mutex);
-		if (likely(sk->sk_prot != saved_tcpv4_prot)) {
-			build_protos(tls_prots[TLSV4], sk->sk_prot);
-			smp_store_release(&saved_tcpv4_prot, sk->sk_prot);
+		if (likely(prot != saved_tcpv4_prot)) {
+			build_protos(tls_prots[TLSV4], prot);
+			smp_store_release(&saved_tcpv4_prot, prot);
 		}
 		mutex_unlock(&tcpv4_prot_mutex);
 	}
 }
 
 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
-			 struct proto *base)
+			 const struct proto *base)
 {
 	prot[TLS_BASE][TLS_BASE] = *base;
 	prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
...
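
The tls_build_proto() hunk also shows why a single READ_ONCE() into a
local matters beyond tearing: the double-checked rebuild must compare
and publish one consistent pointer value. A self-contained userspace
sketch of that double-checked pattern (all names are invented for the
example; the __atomic builtins stand in for smp_load_acquire and
smp_store_release):

#include <pthread.h>

struct proto { int id; };

static struct proto built_protos;          /* result of the last rebuild */
static const struct proto *saved_prot;     /* base we last built from */
static pthread_mutex_t prot_mutex = PTHREAD_MUTEX_INITIALIZER;

/* Stand-in for build_protos(): derive the table from '*base'. */
static void rebuild(const struct proto *base)
{
	built_protos = *base;
}

/* 'prot' was loaded from the shared pointer exactly once by the
 * caller; re-using it for the unlocked check, the locked re-check,
 * and the release-publish keeps all three consistent even if the
 * shared pointer changes concurrently. */
static void build_once(const struct proto *prot)
{
	if (prot != __atomic_load_n(&saved_prot, __ATOMIC_ACQUIRE)) {
		pthread_mutex_lock(&prot_mutex);
		if (prot != saved_prot) {
			rebuild(prot);
			__atomic_store_n(&saved_prot, prot, __ATOMIC_RELEASE);
		}
		pthread_mutex_unlock(&prot_mutex);
	}
}

int main(void)
{
	static const struct proto tcp_base = { .id = 1 };

	build_once(&tcp_base);	/* first call rebuilds */
	build_once(&tcp_base);	/* second call is a fast no-op */
	return 0;
}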