Commit 8a59f9d1 authored by Cong Wang's avatar Cong Wang Committed by Alexei Starovoitov

sock: Introduce sk->sk_prot->psock_update_sk_prot()

Currently sockmap calls into each protocol to update the struct
proto and replace it. This certainly won't work when the protocol
is implemented as a module, for example, AF_UNIX.

Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
protocol can implement its own way to replace the struct proto.
This also helps get rid of symbol dependencies on CONFIG_INET.
Signed-off-by: default avatarCong Wang <cong.wang@bytedance.com>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20210331023237.41094-11-xiyou.wangcong@gmail.com
parent a7ba4558
...@@ -99,6 +99,7 @@ struct sk_psock { ...@@ -99,6 +99,7 @@ struct sk_psock {
void (*saved_close)(struct sock *sk, long timeout); void (*saved_close)(struct sock *sk, long timeout);
void (*saved_write_space)(struct sock *sk); void (*saved_write_space)(struct sock *sk);
void (*saved_data_ready)(struct sock *sk); void (*saved_data_ready)(struct sock *sk);
int (*psock_update_sk_prot)(struct sock *sk, bool restore);
struct proto *sk_proto; struct proto *sk_proto;
struct mutex work_mutex; struct mutex work_mutex;
struct sk_psock_work_state work_state; struct sk_psock_work_state work_state;
...@@ -395,25 +396,12 @@ static inline void sk_psock_cork_free(struct sk_psock *psock) ...@@ -395,25 +396,12 @@ static inline void sk_psock_cork_free(struct sk_psock *psock)
} }
} }
static inline void sk_psock_update_proto(struct sock *sk,
struct sk_psock *psock,
struct proto *ops)
{
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, ops);
}
static inline void sk_psock_restore_proto(struct sock *sk, static inline void sk_psock_restore_proto(struct sock *sk,
struct sk_psock *psock) struct sk_psock *psock)
{ {
sk->sk_prot->unhash = psock->saved_unhash; sk->sk_prot->unhash = psock->saved_unhash;
if (inet_csk_has_ulp(sk)) { if (psock->psock_update_sk_prot)
tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space); psock->psock_update_sk_prot(sk, true);
} else {
sk->sk_write_space = psock->saved_write_space;
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
}
} }
static inline void sk_psock_set_state(struct sk_psock *psock, static inline void sk_psock_set_state(struct sk_psock *psock,
......
...@@ -1184,6 +1184,9 @@ struct proto { ...@@ -1184,6 +1184,9 @@ struct proto {
void (*unhash)(struct sock *sk); void (*unhash)(struct sock *sk);
void (*rehash)(struct sock *sk); void (*rehash)(struct sock *sk);
int (*get_port)(struct sock *sk, unsigned short snum); int (*get_port)(struct sock *sk, unsigned short snum);
#ifdef CONFIG_BPF_SYSCALL
int (*psock_update_sk_prot)(struct sock *sk, bool restore);
#endif
/* Keeping track of sockets in use */ /* Keeping track of sockets in use */
#ifdef CONFIG_PROC_FS #ifdef CONFIG_PROC_FS
......
...@@ -2203,6 +2203,7 @@ struct sk_psock; ...@@ -2203,6 +2203,7 @@ struct sk_psock;
#ifdef CONFIG_BPF_SYSCALL #ifdef CONFIG_BPF_SYSCALL
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock); struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
int tcp_bpf_update_proto(struct sock *sk, bool restore);
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk); void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
#endif /* CONFIG_BPF_SYSCALL */ #endif /* CONFIG_BPF_SYSCALL */
......
...@@ -518,6 +518,7 @@ static inline struct sk_buff *udp_rcv_segment(struct sock *sk, ...@@ -518,6 +518,7 @@ static inline struct sk_buff *udp_rcv_segment(struct sock *sk,
#ifdef CONFIG_BPF_SYSCALL #ifdef CONFIG_BPF_SYSCALL
struct sk_psock; struct sk_psock;
struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock); struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
int udp_bpf_update_proto(struct sock *sk, bool restore);
#endif #endif
#endif /* _UDP_H */ #endif /* _UDP_H */
...@@ -562,11 +562,6 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node) ...@@ -562,11 +562,6 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)
write_lock_bh(&sk->sk_callback_lock); write_lock_bh(&sk->sk_callback_lock);
if (inet_csk_has_ulp(sk)) {
psock = ERR_PTR(-EINVAL);
goto out;
}
if (sk->sk_user_data) { if (sk->sk_user_data) {
psock = ERR_PTR(-EBUSY); psock = ERR_PTR(-EBUSY);
goto out; goto out;
......
...@@ -185,26 +185,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw) ...@@ -185,26 +185,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock) static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{ {
struct proto *prot; if (!sk->sk_prot->psock_update_sk_prot)
switch (sk->sk_type) {
case SOCK_STREAM:
prot = tcp_bpf_get_proto(sk, psock);
break;
case SOCK_DGRAM:
prot = udp_bpf_get_proto(sk, psock);
break;
default:
return -EINVAL; return -EINVAL;
} psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
return sk->sk_prot->psock_update_sk_prot(sk, false);
if (IS_ERR(prot))
return PTR_ERR(prot);
sk_psock_update_proto(sk, psock, prot);
return 0;
} }
static struct sk_psock *sock_map_psock_get_checked(struct sock *sk) static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
...@@ -556,7 +540,7 @@ static bool sock_map_redirect_allowed(const struct sock *sk) ...@@ -556,7 +540,7 @@ static bool sock_map_redirect_allowed(const struct sock *sk)
static bool sock_map_sk_is_suitable(const struct sock *sk) static bool sock_map_sk_is_suitable(const struct sock *sk)
{ {
return sk_is_tcp(sk) || sk_is_udp(sk); return !!sk->sk_prot->psock_update_sk_prot;
} }
static bool sock_map_sk_state_allowed(const struct sock *sk) static bool sock_map_sk_state_allowed(const struct sock *sk)
......
...@@ -595,20 +595,38 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops) ...@@ -595,20 +595,38 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP; ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
} }
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock) int tcp_bpf_update_proto(struct sock *sk, bool restore)
{ {
struct sk_psock *psock = sk_psock(sk);
int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4; int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE; int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
if (restore) {
if (inet_csk_has_ulp(sk)) {
tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
} else {
sk->sk_write_space = psock->saved_write_space;
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
}
return 0;
}
if (inet_csk_has_ulp(sk))
return -EINVAL;
if (sk->sk_family == AF_INET6) { if (sk->sk_family == AF_INET6) {
if (tcp_bpf_assert_proto_ops(psock->sk_proto)) if (tcp_bpf_assert_proto_ops(psock->sk_proto))
return ERR_PTR(-EINVAL); return -EINVAL;
tcp_bpf_check_v6_needs_rebuild(psock->sk_proto); tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
} }
return &tcp_bpf_prots[family][config]; /* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
return 0;
} }
EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);
/* If a child got cloned from a listening socket that had tcp_bpf /* If a child got cloned from a listening socket that had tcp_bpf
* protocol callbacks installed, we need to restore the callbacks to * protocol callbacks installed, we need to restore the callbacks to
......
...@@ -2806,6 +2806,9 @@ struct proto tcp_prot = { ...@@ -2806,6 +2806,9 @@ struct proto tcp_prot = {
.hash = inet_hash, .hash = inet_hash,
.unhash = inet_unhash, .unhash = inet_unhash,
.get_port = inet_csk_get_port, .get_port = inet_csk_get_port,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = tcp_bpf_update_proto,
#endif
.enter_memory_pressure = tcp_enter_memory_pressure, .enter_memory_pressure = tcp_enter_memory_pressure,
.leave_memory_pressure = tcp_leave_memory_pressure, .leave_memory_pressure = tcp_leave_memory_pressure,
.stream_memory_free = tcp_stream_memory_free, .stream_memory_free = tcp_stream_memory_free,
......
...@@ -2849,6 +2849,9 @@ struct proto udp_prot = { ...@@ -2849,6 +2849,9 @@ struct proto udp_prot = {
.unhash = udp_lib_unhash, .unhash = udp_lib_unhash,
.rehash = udp_v4_rehash, .rehash = udp_v4_rehash,
.get_port = udp_v4_get_port, .get_port = udp_v4_get_port,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = udp_bpf_update_proto,
#endif
.memory_allocated = &udp_memory_allocated, .memory_allocated = &udp_memory_allocated,
.sysctl_mem = sysctl_udp_mem, .sysctl_mem = sysctl_udp_mem,
.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min), .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
......
...@@ -41,12 +41,23 @@ static int __init udp_bpf_v4_build_proto(void) ...@@ -41,12 +41,23 @@ static int __init udp_bpf_v4_build_proto(void)
} }
core_initcall(udp_bpf_v4_build_proto); core_initcall(udp_bpf_v4_build_proto);
struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock) int udp_bpf_update_proto(struct sock *sk, bool restore)
{ {
int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6; int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
struct sk_psock *psock = sk_psock(sk);
if (restore) {
sk->sk_write_space = psock->saved_write_space;
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
return 0;
}
if (sk->sk_family == AF_INET6) if (sk->sk_family == AF_INET6)
udp_bpf_check_v6_needs_rebuild(psock->sk_proto); udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
return &udp_bpf_prots[family]; /* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
return 0;
} }
EXPORT_SYMBOL_GPL(udp_bpf_update_proto);
...@@ -2139,6 +2139,9 @@ struct proto tcpv6_prot = { ...@@ -2139,6 +2139,9 @@ struct proto tcpv6_prot = {
.hash = inet6_hash, .hash = inet6_hash,
.unhash = inet_unhash, .unhash = inet_unhash,
.get_port = inet_csk_get_port, .get_port = inet_csk_get_port,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = tcp_bpf_update_proto,
#endif
.enter_memory_pressure = tcp_enter_memory_pressure, .enter_memory_pressure = tcp_enter_memory_pressure,
.leave_memory_pressure = tcp_leave_memory_pressure, .leave_memory_pressure = tcp_leave_memory_pressure,
.stream_memory_free = tcp_stream_memory_free, .stream_memory_free = tcp_stream_memory_free,
......
...@@ -1713,6 +1713,9 @@ struct proto udpv6_prot = { ...@@ -1713,6 +1713,9 @@ struct proto udpv6_prot = {
.unhash = udp_lib_unhash, .unhash = udp_lib_unhash,
.rehash = udp_v6_rehash, .rehash = udp_v6_rehash,
.get_port = udp_v6_get_port, .get_port = udp_v6_get_port,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = udp_bpf_update_proto,
#endif
.memory_allocated = &udp_memory_allocated, .memory_allocated = &udp_memory_allocated,
.sysctl_mem = sysctl_udp_mem, .sysctl_mem = sysctl_udp_mem,
.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min), .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment