Commit a94b5aae authored by Alexei Starovoitov's avatar Alexei Starovoitov

Merge branch 'sock_map: fix ->poll() and update selftests'

Cong Wang says:

====================
This patchset fixes ->poll() for sockets in sockmap and updates
selftests accordingly with select(). Please check each patch
for more details.

Fixes: c50524ec ("Merge branch 'sockmap: add sockmap support for unix datagram socket'")
Fixes: 89d69c5d ("Merge branch 'sockmap: introduce BPF_SK_SKB_VERDICT and support UDP'")
Acked-by: default avatarJohn Fastabend <john.fastabend@gmail.com>

---
v4: add a comment in udp_poll()

v3: drop sk_psock_get_checked()
    reuse tcp_bpf_sock_is_readable()

v2: rename and reuse ->stream_memory_read()
    fix a compile error in sk_psock_get_checked()

Cong Wang (3):
  net: rename ->stream_memory_read to ->sock_is_readable
  skmsg: extract and reuse sk_msg_is_readable()
  net: implement ->sock_is_readable() for UDP and AF_UNIX

====================
Reviewed-by: default avatarJakub Sitnicki <jakub@cloudflare.com>
Acked-by: default avatarJohn Fastabend <john.fastabend@gmail.com>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parents cd9733f5 67b82150
...@@ -128,6 +128,7 @@ int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from, ...@@ -128,6 +128,7 @@ int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
struct sk_msg *msg, u32 bytes); struct sk_msg *msg, u32 bytes);
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg, int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
int len, int flags); int len, int flags);
bool sk_msg_is_readable(struct sock *sk);
static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes) static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
{ {
......
...@@ -1208,7 +1208,7 @@ struct proto { ...@@ -1208,7 +1208,7 @@ struct proto {
#endif #endif
bool (*stream_memory_free)(const struct sock *sk, int wake); bool (*stream_memory_free)(const struct sock *sk, int wake);
bool (*stream_memory_read)(const struct sock *sk); bool (*sock_is_readable)(struct sock *sk);
/* Memory pressure */ /* Memory pressure */
void (*enter_memory_pressure)(struct sock *sk); void (*enter_memory_pressure)(struct sock *sk);
void (*leave_memory_pressure)(struct sock *sk); void (*leave_memory_pressure)(struct sock *sk);
...@@ -2820,4 +2820,10 @@ void sock_set_sndtimeo(struct sock *sk, s64 secs); ...@@ -2820,4 +2820,10 @@ void sock_set_sndtimeo(struct sock *sk, s64 secs);
int sock_bind_add(struct sock *sk, struct sockaddr *addr, int addr_len); int sock_bind_add(struct sock *sk, struct sockaddr *addr, int addr_len);
static inline bool sk_is_readable(struct sock *sk)
{
if (sk->sk_prot->sock_is_readable)
return sk->sk_prot->sock_is_readable(sk);
return false;
}
#endif /* _SOCK_H */ #endif /* _SOCK_H */
...@@ -375,7 +375,7 @@ void tls_sw_release_resources_rx(struct sock *sk); ...@@ -375,7 +375,7 @@ void tls_sw_release_resources_rx(struct sock *sk);
void tls_sw_free_ctx_rx(struct tls_context *tls_ctx); void tls_sw_free_ctx_rx(struct tls_context *tls_ctx);
int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len); int nonblock, int flags, int *addr_len);
bool tls_sw_stream_read(const struct sock *sk); bool tls_sw_sock_is_readable(struct sock *sk);
ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
struct pipe_inode_info *pipe, struct pipe_inode_info *pipe,
size_t len, unsigned int flags); size_t len, unsigned int flags);
......
...@@ -474,6 +474,20 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg, ...@@ -474,6 +474,20 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
} }
EXPORT_SYMBOL_GPL(sk_msg_recvmsg); EXPORT_SYMBOL_GPL(sk_msg_recvmsg);
bool sk_msg_is_readable(struct sock *sk)
{
struct sk_psock *psock;
bool empty = true;
rcu_read_lock();
psock = sk_psock(sk);
if (likely(psock))
empty = list_empty(&psock->ingress_msg);
rcu_read_unlock();
return !empty;
}
EXPORT_SYMBOL_GPL(sk_msg_is_readable);
static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk, static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk,
struct sk_buff *skb) struct sk_buff *skb)
{ {
......
...@@ -486,10 +486,7 @@ static bool tcp_stream_is_readable(struct sock *sk, int target) ...@@ -486,10 +486,7 @@ static bool tcp_stream_is_readable(struct sock *sk, int target)
{ {
if (tcp_epollin_ready(sk, target)) if (tcp_epollin_ready(sk, target))
return true; return true;
return sk_is_readable(sk);
if (sk->sk_prot->stream_memory_read)
return sk->sk_prot->stream_memory_read(sk);
return false;
} }
/* /*
......
...@@ -150,19 +150,6 @@ int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, ...@@ -150,19 +150,6 @@ int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir); EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
#ifdef CONFIG_BPF_SYSCALL #ifdef CONFIG_BPF_SYSCALL
static bool tcp_bpf_stream_read(const struct sock *sk)
{
struct sk_psock *psock;
bool empty = true;
rcu_read_lock();
psock = sk_psock(sk);
if (likely(psock))
empty = list_empty(&psock->ingress_msg);
rcu_read_unlock();
return !empty;
}
static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock, static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
long timeo) long timeo)
{ {
...@@ -491,7 +478,7 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS], ...@@ -491,7 +478,7 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
prot[TCP_BPF_BASE].unhash = sock_map_unhash; prot[TCP_BPF_BASE].unhash = sock_map_unhash;
prot[TCP_BPF_BASE].close = sock_map_close; prot[TCP_BPF_BASE].close = sock_map_close;
prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg; prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read; prot[TCP_BPF_BASE].sock_is_readable = sk_msg_is_readable;
prot[TCP_BPF_TX] = prot[TCP_BPF_BASE]; prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg; prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
......
...@@ -2867,6 +2867,9 @@ __poll_t udp_poll(struct file *file, struct socket *sock, poll_table *wait) ...@@ -2867,6 +2867,9 @@ __poll_t udp_poll(struct file *file, struct socket *sock, poll_table *wait)
!(sk->sk_shutdown & RCV_SHUTDOWN) && first_packet_length(sk) == -1) !(sk->sk_shutdown & RCV_SHUTDOWN) && first_packet_length(sk) == -1)
mask &= ~(EPOLLIN | EPOLLRDNORM); mask &= ~(EPOLLIN | EPOLLRDNORM);
/* psock ingress_msg queue should not contain any bad checksum frames */
if (sk_is_readable(sk))
mask |= EPOLLIN | EPOLLRDNORM;
return mask; return mask;
} }
......
...@@ -114,6 +114,7 @@ static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base) ...@@ -114,6 +114,7 @@ static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
*prot = *base; *prot = *base;
prot->close = sock_map_close; prot->close = sock_map_close;
prot->recvmsg = udp_bpf_recvmsg; prot->recvmsg = udp_bpf_recvmsg;
prot->sock_is_readable = sk_msg_is_readable;
} }
static void udp_bpf_check_v6_needs_rebuild(struct proto *ops) static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
......
...@@ -681,12 +681,12 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], ...@@ -681,12 +681,12 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE]; prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg; prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg;
prot[TLS_BASE][TLS_SW].stream_memory_read = tls_sw_stream_read; prot[TLS_BASE][TLS_SW].sock_is_readable = tls_sw_sock_is_readable;
prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close; prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close;
prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE]; prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg; prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg;
prot[TLS_SW][TLS_SW].stream_memory_read = tls_sw_stream_read; prot[TLS_SW][TLS_SW].sock_is_readable = tls_sw_sock_is_readable;
prot[TLS_SW][TLS_SW].close = tls_sk_proto_close; prot[TLS_SW][TLS_SW].close = tls_sk_proto_close;
#ifdef CONFIG_TLS_DEVICE #ifdef CONFIG_TLS_DEVICE
......
...@@ -2026,7 +2026,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, ...@@ -2026,7 +2026,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
return copied ? : err; return copied ? : err;
} }
bool tls_sw_stream_read(const struct sock *sk) bool tls_sw_sock_is_readable(struct sock *sk)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
......
...@@ -3052,6 +3052,8 @@ static __poll_t unix_poll(struct file *file, struct socket *sock, poll_table *wa ...@@ -3052,6 +3052,8 @@ static __poll_t unix_poll(struct file *file, struct socket *sock, poll_table *wa
/* readable? */ /* readable? */
if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) if (!skb_queue_empty_lockless(&sk->sk_receive_queue))
mask |= EPOLLIN | EPOLLRDNORM; mask |= EPOLLIN | EPOLLRDNORM;
if (sk_is_readable(sk))
mask |= EPOLLIN | EPOLLRDNORM;
/* Connection-based need to check for termination and startup */ /* Connection-based need to check for termination and startup */
if ((sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) && if ((sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) &&
...@@ -3091,6 +3093,8 @@ static __poll_t unix_dgram_poll(struct file *file, struct socket *sock, ...@@ -3091,6 +3093,8 @@ static __poll_t unix_dgram_poll(struct file *file, struct socket *sock,
/* readable? */ /* readable? */
if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) if (!skb_queue_empty_lockless(&sk->sk_receive_queue))
mask |= EPOLLIN | EPOLLRDNORM; mask |= EPOLLIN | EPOLLRDNORM;
if (sk_is_readable(sk))
mask |= EPOLLIN | EPOLLRDNORM;
/* Connection-based need to check for termination and startup */ /* Connection-based need to check for termination and startup */
if (sk->sk_type == SOCK_SEQPACKET) { if (sk->sk_type == SOCK_SEQPACKET) {
......
...@@ -102,6 +102,7 @@ static void unix_dgram_bpf_rebuild_protos(struct proto *prot, const struct proto ...@@ -102,6 +102,7 @@ static void unix_dgram_bpf_rebuild_protos(struct proto *prot, const struct proto
*prot = *base; *prot = *base;
prot->close = sock_map_close; prot->close = sock_map_close;
prot->recvmsg = unix_bpf_recvmsg; prot->recvmsg = unix_bpf_recvmsg;
prot->sock_is_readable = sk_msg_is_readable;
} }
static void unix_stream_bpf_rebuild_protos(struct proto *prot, static void unix_stream_bpf_rebuild_protos(struct proto *prot,
...@@ -110,6 +111,7 @@ static void unix_stream_bpf_rebuild_protos(struct proto *prot, ...@@ -110,6 +111,7 @@ static void unix_stream_bpf_rebuild_protos(struct proto *prot,
*prot = *base; *prot = *base;
prot->close = sock_map_close; prot->close = sock_map_close;
prot->recvmsg = unix_bpf_recvmsg; prot->recvmsg = unix_bpf_recvmsg;
prot->sock_is_readable = sk_msg_is_readable;
prot->unhash = sock_map_unhash; prot->unhash = sock_map_unhash;
} }
......
...@@ -949,7 +949,6 @@ static void redir_to_connected(int family, int sotype, int sock_mapfd, ...@@ -949,7 +949,6 @@ static void redir_to_connected(int family, int sotype, int sock_mapfd,
int err, n; int err, n;
u32 key; u32 key;
char b; char b;
int retries = 100;
zero_verdict_count(verd_mapfd); zero_verdict_count(verd_mapfd);
...@@ -1002,17 +1001,11 @@ static void redir_to_connected(int family, int sotype, int sock_mapfd, ...@@ -1002,17 +1001,11 @@ static void redir_to_connected(int family, int sotype, int sock_mapfd,
goto close_peer1; goto close_peer1;
if (pass != 1) if (pass != 1)
FAIL("%s: want pass count 1, have %d", log_prefix, pass); FAIL("%s: want pass count 1, have %d", log_prefix, pass);
again: n = recv_timeout(c0, &b, 1, 0, IO_TIMEOUT_SEC);
n = read(c0, &b, 1); if (n < 0)
if (n < 0) { FAIL_ERRNO("%s: recv_timeout", log_prefix);
if (errno == EAGAIN && retries--) {
usleep(1000);
goto again;
}
FAIL_ERRNO("%s: read", log_prefix);
}
if (n == 0) if (n == 0)
FAIL("%s: incomplete read", log_prefix); FAIL("%s: incomplete recv", log_prefix);
close_peer1: close_peer1:
xclose(p1); xclose(p1);
...@@ -1571,7 +1564,6 @@ static void unix_redir_to_connected(int sotype, int sock_mapfd, ...@@ -1571,7 +1564,6 @@ static void unix_redir_to_connected(int sotype, int sock_mapfd,
const char *log_prefix = redir_mode_str(mode); const char *log_prefix = redir_mode_str(mode);
int c0, c1, p0, p1; int c0, c1, p0, p1;
unsigned int pass; unsigned int pass;
int retries = 100;
int err, n; int err, n;
int sfd[2]; int sfd[2];
u32 key; u32 key;
...@@ -1606,17 +1598,11 @@ static void unix_redir_to_connected(int sotype, int sock_mapfd, ...@@ -1606,17 +1598,11 @@ static void unix_redir_to_connected(int sotype, int sock_mapfd,
if (pass != 1) if (pass != 1)
FAIL("%s: want pass count 1, have %d", log_prefix, pass); FAIL("%s: want pass count 1, have %d", log_prefix, pass);
again: n = recv_timeout(mode == REDIR_INGRESS ? p0 : c0, &b, 1, 0, IO_TIMEOUT_SEC);
n = read(mode == REDIR_INGRESS ? p0 : c0, &b, 1); if (n < 0)
if (n < 0) { FAIL_ERRNO("%s: recv_timeout", log_prefix);
if (errno == EAGAIN && retries--) {
usleep(1000);
goto again;
}
FAIL_ERRNO("%s: read", log_prefix);
}
if (n == 0) if (n == 0)
FAIL("%s: incomplete read", log_prefix); FAIL("%s: incomplete recv", log_prefix);
close: close:
xclose(c1); xclose(c1);
...@@ -1748,7 +1734,6 @@ static void udp_redir_to_connected(int family, int sock_mapfd, int verd_mapfd, ...@@ -1748,7 +1734,6 @@ static void udp_redir_to_connected(int family, int sock_mapfd, int verd_mapfd,
const char *log_prefix = redir_mode_str(mode); const char *log_prefix = redir_mode_str(mode);
int c0, c1, p0, p1; int c0, c1, p0, p1;
unsigned int pass; unsigned int pass;
int retries = 100;
int err, n; int err, n;
u32 key; u32 key;
char b; char b;
...@@ -1781,17 +1766,11 @@ static void udp_redir_to_connected(int family, int sock_mapfd, int verd_mapfd, ...@@ -1781,17 +1766,11 @@ static void udp_redir_to_connected(int family, int sock_mapfd, int verd_mapfd,
if (pass != 1) if (pass != 1)
FAIL("%s: want pass count 1, have %d", log_prefix, pass); FAIL("%s: want pass count 1, have %d", log_prefix, pass);
again: n = recv_timeout(mode == REDIR_INGRESS ? p0 : c0, &b, 1, 0, IO_TIMEOUT_SEC);
n = read(mode == REDIR_INGRESS ? p0 : c0, &b, 1); if (n < 0)
if (n < 0) { FAIL_ERRNO("%s: recv_timeout", log_prefix);
if (errno == EAGAIN && retries--) {
usleep(1000);
goto again;
}
FAIL_ERRNO("%s: read", log_prefix);
}
if (n == 0) if (n == 0)
FAIL("%s: incomplete read", log_prefix); FAIL("%s: incomplete recv", log_prefix);
close_cli1: close_cli1:
xclose(c1); xclose(c1);
...@@ -1841,7 +1820,6 @@ static void inet_unix_redir_to_connected(int family, int type, int sock_mapfd, ...@@ -1841,7 +1820,6 @@ static void inet_unix_redir_to_connected(int family, int type, int sock_mapfd,
const char *log_prefix = redir_mode_str(mode); const char *log_prefix = redir_mode_str(mode);
int c0, c1, p0, p1; int c0, c1, p0, p1;
unsigned int pass; unsigned int pass;
int retries = 100;
int err, n; int err, n;
int sfd[2]; int sfd[2];
u32 key; u32 key;
...@@ -1876,17 +1854,11 @@ static void inet_unix_redir_to_connected(int family, int type, int sock_mapfd, ...@@ -1876,17 +1854,11 @@ static void inet_unix_redir_to_connected(int family, int type, int sock_mapfd,
if (pass != 1) if (pass != 1)
FAIL("%s: want pass count 1, have %d", log_prefix, pass); FAIL("%s: want pass count 1, have %d", log_prefix, pass);
again: n = recv_timeout(mode == REDIR_INGRESS ? p0 : c0, &b, 1, 0, IO_TIMEOUT_SEC);
n = read(mode == REDIR_INGRESS ? p0 : c0, &b, 1); if (n < 0)
if (n < 0) { FAIL_ERRNO("%s: recv_timeout", log_prefix);
if (errno == EAGAIN && retries--) {
usleep(1000);
goto again;
}
FAIL_ERRNO("%s: read", log_prefix);
}
if (n == 0) if (n == 0)
FAIL("%s: incomplete read", log_prefix); FAIL("%s: incomplete recv", log_prefix);
close_cli1: close_cli1:
xclose(c1); xclose(c1);
...@@ -1932,7 +1904,6 @@ static void unix_inet_redir_to_connected(int family, int type, int sock_mapfd, ...@@ -1932,7 +1904,6 @@ static void unix_inet_redir_to_connected(int family, int type, int sock_mapfd,
int sfd[2]; int sfd[2];
u32 key; u32 key;
char b; char b;
int retries = 100;
zero_verdict_count(verd_mapfd); zero_verdict_count(verd_mapfd);
...@@ -1963,17 +1934,11 @@ static void unix_inet_redir_to_connected(int family, int type, int sock_mapfd, ...@@ -1963,17 +1934,11 @@ static void unix_inet_redir_to_connected(int family, int type, int sock_mapfd,
if (pass != 1) if (pass != 1)
FAIL("%s: want pass count 1, have %d", log_prefix, pass); FAIL("%s: want pass count 1, have %d", log_prefix, pass);
again: n = recv_timeout(mode == REDIR_INGRESS ? p0 : c0, &b, 1, 0, IO_TIMEOUT_SEC);
n = read(mode == REDIR_INGRESS ? p0 : c0, &b, 1); if (n < 0)
if (n < 0) { FAIL_ERRNO("%s: recv_timeout", log_prefix);
if (errno == EAGAIN && retries--) {
usleep(1000);
goto again;
}
FAIL_ERRNO("%s: read", log_prefix);
}
if (n == 0) if (n == 0)
FAIL("%s: incomplete read", log_prefix); FAIL("%s: incomplete recv", log_prefix);
close: close:
xclose(c1); xclose(c1);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment