Commit 417fa6d1 authored by Martin KaFai Lau's avatar Martin KaFai Lau

Merge branch 'fix sockmap + stream af_unix memleak'

John Fastabend says:

====================
There was a memleak when streaming af_unix sockets were inserted into
multiple sockmap slots and/or maps. This is because each insert would
call a proto update operatino and these must be allowed to be called
multiple times. The streaming af_unix implementation recently added
a refcnt to handle a use after free issue, however it introduced a
memleak when inserted into multiple maps.

This series fixes the memleak, adds a note in the code so we remember
that proto updates need to support this. And then we add three tests
for each of the slightly different iterations of adding sockets into
multiple maps. I kept them as 3 independent test cases here. I have
some slight preference for this they could however be a single test,
but then you don't get to run them independently which was sort of
useful while debugging.
====================
Signed-off-by: default avatarMartin KaFai Lau <martin.lau@kernel.org>
parents b4560055 bdbca46d
...@@ -100,6 +100,11 @@ struct sk_psock { ...@@ -100,6 +100,11 @@ struct sk_psock {
void (*saved_close)(struct sock *sk, long timeout); void (*saved_close)(struct sock *sk, long timeout);
void (*saved_write_space)(struct sock *sk); void (*saved_write_space)(struct sock *sk);
void (*saved_data_ready)(struct sock *sk); void (*saved_data_ready)(struct sock *sk);
/* psock_update_sk_prot may be called with restore=false many times
* so the handler must be safe for this case. It will be called
* exactly once with restore=true when the psock is being destroyed
* and psock refcnt is zero, but before an RCU grace period.
*/
int (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock, int (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock,
bool restore); bool restore);
struct proto *sk_proto; struct proto *sk_proto;
......
...@@ -161,15 +161,30 @@ int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool r ...@@ -161,15 +161,30 @@ int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool r
{ {
struct sock *sk_pair; struct sock *sk_pair;
/* Restore does not decrement the sk_pair reference yet because we must
* keep the a reference to the socket until after an RCU grace period
* and any pending sends have completed.
*/
if (restore) { if (restore) {
sk->sk_write_space = psock->saved_write_space; sk->sk_write_space = psock->saved_write_space;
sock_replace_proto(sk, psock->sk_proto); sock_replace_proto(sk, psock->sk_proto);
return 0; return 0;
} }
/* psock_update_sk_prot can be called multiple times if psock is
* added to multiple maps and/or slots in the same map. There is
* also an edge case where replacing a psock with itself can trigger
* an extra psock_update_sk_prot during the insert process. So it
* must be safe to do multiple calls. Here we need to ensure we don't
* increment the refcnt through sock_hold many times. There will only
* be a single matching destroy operation.
*/
if (!psock->sk_pair) {
sk_pair = unix_peer(sk); sk_pair = unix_peer(sk);
sock_hold(sk_pair); sock_hold(sk_pair);
psock->sk_pair = sk_pair; psock->sk_pair = sk_pair;
}
unix_stream_bpf_check_needs_rebuild(psock->sk_proto); unix_stream_bpf_check_needs_rebuild(psock->sk_proto);
sock_replace_proto(sk, &unix_stream_bpf_prot); sock_replace_proto(sk, &unix_stream_bpf_prot);
return 0; return 0;
......
...@@ -555,6 +555,213 @@ static void test_sockmap_unconnected_unix(void) ...@@ -555,6 +555,213 @@ static void test_sockmap_unconnected_unix(void)
close(dgram); close(dgram);
} }
static void test_sockmap_many_socket(void)
{
struct test_sockmap_pass_prog *skel;
int stream[2], dgram, udp, tcp;
int i, err, map, entry = 0;
skel = test_sockmap_pass_prog__open_and_load();
if (!ASSERT_OK_PTR(skel, "open_and_load"))
return;
map = bpf_map__fd(skel->maps.sock_map_rx);
dgram = xsocket(AF_UNIX, SOCK_DGRAM, 0);
if (dgram < 0) {
test_sockmap_pass_prog__destroy(skel);
return;
}
tcp = connected_socket_v4();
if (!ASSERT_GE(tcp, 0, "connected_socket_v4")) {
close(dgram);
test_sockmap_pass_prog__destroy(skel);
return;
}
udp = xsocket(AF_INET, SOCK_DGRAM | SOCK_NONBLOCK, 0);
if (udp < 0) {
close(dgram);
close(tcp);
test_sockmap_pass_prog__destroy(skel);
return;
}
err = socketpair(AF_UNIX, SOCK_STREAM, 0, stream);
ASSERT_OK(err, "socketpair(af_unix, sock_stream)");
if (err)
goto out;
for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map, &entry, &stream[0], BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(stream)");
}
for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map, &entry, &dgram, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(dgram)");
}
for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map, &entry, &udp, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(udp)");
}
for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map, &entry, &tcp, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(tcp)");
}
for (entry--; entry >= 0; entry--) {
err = bpf_map_delete_elem(map, &entry);
ASSERT_OK(err, "bpf_map_delete_elem(entry)");
}
close(stream[0]);
close(stream[1]);
out:
close(dgram);
close(tcp);
close(udp);
test_sockmap_pass_prog__destroy(skel);
}
static void test_sockmap_many_maps(void)
{
struct test_sockmap_pass_prog *skel;
int stream[2], dgram, udp, tcp;
int i, err, map[2], entry = 0;
skel = test_sockmap_pass_prog__open_and_load();
if (!ASSERT_OK_PTR(skel, "open_and_load"))
return;
map[0] = bpf_map__fd(skel->maps.sock_map_rx);
map[1] = bpf_map__fd(skel->maps.sock_map_tx);
dgram = xsocket(AF_UNIX, SOCK_DGRAM, 0);
if (dgram < 0) {
test_sockmap_pass_prog__destroy(skel);
return;
}
tcp = connected_socket_v4();
if (!ASSERT_GE(tcp, 0, "connected_socket_v4")) {
close(dgram);
test_sockmap_pass_prog__destroy(skel);
return;
}
udp = xsocket(AF_INET, SOCK_DGRAM | SOCK_NONBLOCK, 0);
if (udp < 0) {
close(dgram);
close(tcp);
test_sockmap_pass_prog__destroy(skel);
return;
}
err = socketpair(AF_UNIX, SOCK_STREAM, 0, stream);
ASSERT_OK(err, "socketpair(af_unix, sock_stream)");
if (err)
goto out;
for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map[i], &entry, &stream[0], BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(stream)");
}
for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map[i], &entry, &dgram, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(dgram)");
}
for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map[i], &entry, &udp, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(udp)");
}
for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map[i], &entry, &tcp, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(tcp)");
}
for (entry--; entry >= 0; entry--) {
err = bpf_map_delete_elem(map[1], &entry);
entry--;
ASSERT_OK(err, "bpf_map_delete_elem(entry)");
err = bpf_map_delete_elem(map[0], &entry);
ASSERT_OK(err, "bpf_map_delete_elem(entry)");
}
close(stream[0]);
close(stream[1]);
out:
close(dgram);
close(tcp);
close(udp);
test_sockmap_pass_prog__destroy(skel);
}
static void test_sockmap_same_sock(void)
{
struct test_sockmap_pass_prog *skel;
int stream[2], dgram, udp, tcp;
int i, err, map, zero = 0;
skel = test_sockmap_pass_prog__open_and_load();
if (!ASSERT_OK_PTR(skel, "open_and_load"))
return;
map = bpf_map__fd(skel->maps.sock_map_rx);
dgram = xsocket(AF_UNIX, SOCK_DGRAM, 0);
if (dgram < 0) {
test_sockmap_pass_prog__destroy(skel);
return;
}
tcp = connected_socket_v4();
if (!ASSERT_GE(tcp, 0, "connected_socket_v4")) {
close(dgram);
test_sockmap_pass_prog__destroy(skel);
return;
}
udp = xsocket(AF_INET, SOCK_DGRAM | SOCK_NONBLOCK, 0);
if (udp < 0) {
close(dgram);
close(tcp);
test_sockmap_pass_prog__destroy(skel);
return;
}
err = socketpair(AF_UNIX, SOCK_STREAM, 0, stream);
ASSERT_OK(err, "socketpair(af_unix, sock_stream)");
if (err)
goto out;
for (i = 0; i < 2; i++) {
err = bpf_map_update_elem(map, &zero, &stream[0], BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(stream)");
}
for (i = 0; i < 2; i++) {
err = bpf_map_update_elem(map, &zero, &dgram, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(dgram)");
}
for (i = 0; i < 2; i++) {
err = bpf_map_update_elem(map, &zero, &udp, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(udp)");
}
for (i = 0; i < 2; i++) {
err = bpf_map_update_elem(map, &zero, &tcp, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(tcp)");
}
err = bpf_map_delete_elem(map, &zero);
ASSERT_OK(err, "bpf_map_delete_elem(entry)");
close(stream[0]);
close(stream[1]);
out:
close(dgram);
close(tcp);
close(udp);
test_sockmap_pass_prog__destroy(skel);
}
void test_sockmap_basic(void) void test_sockmap_basic(void)
{ {
if (test__start_subtest("sockmap create_update_free")) if (test__start_subtest("sockmap create_update_free"))
...@@ -597,7 +804,12 @@ void test_sockmap_basic(void) ...@@ -597,7 +804,12 @@ void test_sockmap_basic(void)
test_sockmap_skb_verdict_fionread(false); test_sockmap_skb_verdict_fionread(false);
if (test__start_subtest("sockmap skb_verdict msg_f_peek")) if (test__start_subtest("sockmap skb_verdict msg_f_peek"))
test_sockmap_skb_verdict_peek(); test_sockmap_skb_verdict_peek();
if (test__start_subtest("sockmap unconnected af_unix")) if (test__start_subtest("sockmap unconnected af_unix"))
test_sockmap_unconnected_unix(); test_sockmap_unconnected_unix();
if (test__start_subtest("sockmap one socket to many map entries"))
test_sockmap_many_socket();
if (test__start_subtest("sockmap one socket to many maps"))
test_sockmap_many_maps();
if (test__start_subtest("sockmap same socket replace"))
test_sockmap_same_sock();
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment