Commit 2bcd3502 authored by David S. Miller's avatar David S. Miller

Merge branch 'wg-fixes'

Jason A. Donenfeld says:

====================
wireguard fixes for 5.8-rc3

This series contains two fixes, one cosmetic and one quite important:

1) Avoid the `if ((x = f()) == y)` pattern, from Frank
   Werner-Krippendorf.

2) Mitigate a potential memory leak by creating circular netns
   references, while also making the netns semantics a bit more
   robust.

Patch (2) has a "Fixes:" line and should be backported to stable.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents f7fb92ac 900575aa
...@@ -45,17 +45,18 @@ static int wg_open(struct net_device *dev) ...@@ -45,17 +45,18 @@ static int wg_open(struct net_device *dev)
if (dev_v6) if (dev_v6)
dev_v6->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_NONE; dev_v6->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_NONE;
mutex_lock(&wg->device_update_lock);
ret = wg_socket_init(wg, wg->incoming_port); ret = wg_socket_init(wg, wg->incoming_port);
if (ret < 0) if (ret < 0)
return ret; goto out;
mutex_lock(&wg->device_update_lock);
list_for_each_entry(peer, &wg->peer_list, peer_list) { list_for_each_entry(peer, &wg->peer_list, peer_list) {
wg_packet_send_staged_packets(peer); wg_packet_send_staged_packets(peer);
if (peer->persistent_keepalive_interval) if (peer->persistent_keepalive_interval)
wg_packet_send_keepalive(peer); wg_packet_send_keepalive(peer);
} }
out:
mutex_unlock(&wg->device_update_lock); mutex_unlock(&wg->device_update_lock);
return 0; return ret;
} }
#ifdef CONFIG_PM_SLEEP #ifdef CONFIG_PM_SLEEP
...@@ -225,6 +226,7 @@ static void wg_destruct(struct net_device *dev) ...@@ -225,6 +226,7 @@ static void wg_destruct(struct net_device *dev)
list_del(&wg->device_list); list_del(&wg->device_list);
rtnl_unlock(); rtnl_unlock();
mutex_lock(&wg->device_update_lock); mutex_lock(&wg->device_update_lock);
rcu_assign_pointer(wg->creating_net, NULL);
wg->incoming_port = 0; wg->incoming_port = 0;
wg_socket_reinit(wg, NULL, NULL); wg_socket_reinit(wg, NULL, NULL);
/* The final references are cleared in the below calls to destroy_workqueue. */ /* The final references are cleared in the below calls to destroy_workqueue. */
...@@ -240,13 +242,11 @@ static void wg_destruct(struct net_device *dev) ...@@ -240,13 +242,11 @@ static void wg_destruct(struct net_device *dev)
skb_queue_purge(&wg->incoming_handshakes); skb_queue_purge(&wg->incoming_handshakes);
free_percpu(dev->tstats); free_percpu(dev->tstats);
free_percpu(wg->incoming_handshakes_worker); free_percpu(wg->incoming_handshakes_worker);
if (wg->have_creating_net_ref)
put_net(wg->creating_net);
kvfree(wg->index_hashtable); kvfree(wg->index_hashtable);
kvfree(wg->peer_hashtable); kvfree(wg->peer_hashtable);
mutex_unlock(&wg->device_update_lock); mutex_unlock(&wg->device_update_lock);
pr_debug("%s: Interface deleted\n", dev->name); pr_debug("%s: Interface destroyed\n", dev->name);
free_netdev(dev); free_netdev(dev);
} }
...@@ -292,7 +292,7 @@ static int wg_newlink(struct net *src_net, struct net_device *dev, ...@@ -292,7 +292,7 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
struct wg_device *wg = netdev_priv(dev); struct wg_device *wg = netdev_priv(dev);
int ret = -ENOMEM; int ret = -ENOMEM;
wg->creating_net = src_net; rcu_assign_pointer(wg->creating_net, src_net);
init_rwsem(&wg->static_identity.lock); init_rwsem(&wg->static_identity.lock);
mutex_init(&wg->socket_update_lock); mutex_init(&wg->socket_update_lock);
mutex_init(&wg->device_update_lock); mutex_init(&wg->device_update_lock);
...@@ -393,30 +393,26 @@ static struct rtnl_link_ops link_ops __read_mostly = { ...@@ -393,30 +393,26 @@ static struct rtnl_link_ops link_ops __read_mostly = {
.newlink = wg_newlink, .newlink = wg_newlink,
}; };
static int wg_netdevice_notification(struct notifier_block *nb, static void wg_netns_pre_exit(struct net *net)
unsigned long action, void *data)
{ {
struct net_device *dev = ((struct netdev_notifier_info *)data)->dev; struct wg_device *wg;
struct wg_device *wg = netdev_priv(dev);
ASSERT_RTNL();
if (action != NETDEV_REGISTER || dev->netdev_ops != &netdev_ops)
return 0;
if (dev_net(dev) == wg->creating_net && wg->have_creating_net_ref) { rtnl_lock();
put_net(wg->creating_net); list_for_each_entry(wg, &device_list, device_list) {
wg->have_creating_net_ref = false; if (rcu_access_pointer(wg->creating_net) == net) {
} else if (dev_net(dev) != wg->creating_net && pr_debug("%s: Creating namespace exiting\n", wg->dev->name);
!wg->have_creating_net_ref) { netif_carrier_off(wg->dev);
wg->have_creating_net_ref = true; mutex_lock(&wg->device_update_lock);
get_net(wg->creating_net); rcu_assign_pointer(wg->creating_net, NULL);
wg_socket_reinit(wg, NULL, NULL);
mutex_unlock(&wg->device_update_lock);
}
} }
return 0; rtnl_unlock();
} }
static struct notifier_block netdevice_notifier = { static struct pernet_operations pernet_ops = {
.notifier_call = wg_netdevice_notification .pre_exit = wg_netns_pre_exit
}; };
int __init wg_device_init(void) int __init wg_device_init(void)
...@@ -429,18 +425,18 @@ int __init wg_device_init(void) ...@@ -429,18 +425,18 @@ int __init wg_device_init(void)
return ret; return ret;
#endif #endif
ret = register_netdevice_notifier(&netdevice_notifier); ret = register_pernet_device(&pernet_ops);
if (ret) if (ret)
goto error_pm; goto error_pm;
ret = rtnl_link_register(&link_ops); ret = rtnl_link_register(&link_ops);
if (ret) if (ret)
goto error_netdevice; goto error_pernet;
return 0; return 0;
error_netdevice: error_pernet:
unregister_netdevice_notifier(&netdevice_notifier); unregister_pernet_device(&pernet_ops);
error_pm: error_pm:
#ifdef CONFIG_PM_SLEEP #ifdef CONFIG_PM_SLEEP
unregister_pm_notifier(&pm_notifier); unregister_pm_notifier(&pm_notifier);
...@@ -451,7 +447,7 @@ int __init wg_device_init(void) ...@@ -451,7 +447,7 @@ int __init wg_device_init(void)
void wg_device_uninit(void) void wg_device_uninit(void)
{ {
rtnl_link_unregister(&link_ops); rtnl_link_unregister(&link_ops);
unregister_netdevice_notifier(&netdevice_notifier); unregister_pernet_device(&pernet_ops);
#ifdef CONFIG_PM_SLEEP #ifdef CONFIG_PM_SLEEP
unregister_pm_notifier(&pm_notifier); unregister_pm_notifier(&pm_notifier);
#endif #endif
......
...@@ -40,7 +40,7 @@ struct wg_device { ...@@ -40,7 +40,7 @@ struct wg_device {
struct net_device *dev; struct net_device *dev;
struct crypt_queue encrypt_queue, decrypt_queue; struct crypt_queue encrypt_queue, decrypt_queue;
struct sock __rcu *sock4, *sock6; struct sock __rcu *sock4, *sock6;
struct net *creating_net; struct net __rcu *creating_net;
struct noise_static_identity static_identity; struct noise_static_identity static_identity;
struct workqueue_struct *handshake_receive_wq, *handshake_send_wq; struct workqueue_struct *handshake_receive_wq, *handshake_send_wq;
struct workqueue_struct *packet_crypt_wq; struct workqueue_struct *packet_crypt_wq;
...@@ -56,7 +56,6 @@ struct wg_device { ...@@ -56,7 +56,6 @@ struct wg_device {
unsigned int num_peers, device_update_gen; unsigned int num_peers, device_update_gen;
u32 fwmark; u32 fwmark;
u16 incoming_port; u16 incoming_port;
bool have_creating_net_ref;
}; };
int wg_device_init(void); int wg_device_init(void);
......
...@@ -511,11 +511,15 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info) ...@@ -511,11 +511,15 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
if (flags & ~__WGDEVICE_F_ALL) if (flags & ~__WGDEVICE_F_ALL)
goto out; goto out;
ret = -EPERM; if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
if ((info->attrs[WGDEVICE_A_LISTEN_PORT] || struct net *net;
info->attrs[WGDEVICE_A_FWMARK]) && rcu_read_lock();
!ns_capable(wg->creating_net->user_ns, CAP_NET_ADMIN)) net = rcu_dereference(wg->creating_net);
goto out; ret = !net || !ns_capable(net->user_ns, CAP_NET_ADMIN) ? -EPERM : 0;
rcu_read_unlock();
if (ret)
goto out;
}
++wg->device_update_gen; ++wg->device_update_gen;
......
...@@ -617,8 +617,8 @@ wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src, ...@@ -617,8 +617,8 @@ wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
memcpy(handshake->hash, hash, NOISE_HASH_LEN); memcpy(handshake->hash, hash, NOISE_HASH_LEN);
memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN); memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
handshake->remote_index = src->sender_index; handshake->remote_index = src->sender_index;
if ((s64)(handshake->last_initiation_consumption - initiation_consumption = ktime_get_coarse_boottime_ns();
(initiation_consumption = ktime_get_coarse_boottime_ns())) < 0) if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
handshake->last_initiation_consumption = initiation_consumption; handshake->last_initiation_consumption = initiation_consumption;
handshake->state = HANDSHAKE_CONSUMED_INITIATION; handshake->state = HANDSHAKE_CONSUMED_INITIATION;
up_write(&handshake->lock); up_write(&handshake->lock);
......
...@@ -347,6 +347,7 @@ static void set_sock_opts(struct socket *sock) ...@@ -347,6 +347,7 @@ static void set_sock_opts(struct socket *sock)
int wg_socket_init(struct wg_device *wg, u16 port) int wg_socket_init(struct wg_device *wg, u16 port)
{ {
struct net *net;
int ret; int ret;
struct udp_tunnel_sock_cfg cfg = { struct udp_tunnel_sock_cfg cfg = {
.sk_user_data = wg, .sk_user_data = wg,
...@@ -371,37 +372,47 @@ int wg_socket_init(struct wg_device *wg, u16 port) ...@@ -371,37 +372,47 @@ int wg_socket_init(struct wg_device *wg, u16 port)
}; };
#endif #endif
rcu_read_lock();
net = rcu_dereference(wg->creating_net);
net = net ? maybe_get_net(net) : NULL;
rcu_read_unlock();
if (unlikely(!net))
return -ENONET;
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
retry: retry:
#endif #endif
ret = udp_sock_create(wg->creating_net, &port4, &new4); ret = udp_sock_create(net, &port4, &new4);
if (ret < 0) { if (ret < 0) {
pr_err("%s: Could not create IPv4 socket\n", wg->dev->name); pr_err("%s: Could not create IPv4 socket\n", wg->dev->name);
return ret; goto out;
} }
set_sock_opts(new4); set_sock_opts(new4);
setup_udp_tunnel_sock(wg->creating_net, new4, &cfg); setup_udp_tunnel_sock(net, new4, &cfg);
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
if (ipv6_mod_enabled()) { if (ipv6_mod_enabled()) {
port6.local_udp_port = inet_sk(new4->sk)->inet_sport; port6.local_udp_port = inet_sk(new4->sk)->inet_sport;
ret = udp_sock_create(wg->creating_net, &port6, &new6); ret = udp_sock_create(net, &port6, &new6);
if (ret < 0) { if (ret < 0) {
udp_tunnel_sock_release(new4); udp_tunnel_sock_release(new4);
if (ret == -EADDRINUSE && !port && retries++ < 100) if (ret == -EADDRINUSE && !port && retries++ < 100)
goto retry; goto retry;
pr_err("%s: Could not create IPv6 socket\n", pr_err("%s: Could not create IPv6 socket\n",
wg->dev->name); wg->dev->name);
return ret; goto out;
} }
set_sock_opts(new6); set_sock_opts(new6);
setup_udp_tunnel_sock(wg->creating_net, new6, &cfg); setup_udp_tunnel_sock(net, new6, &cfg);
} }
#endif #endif
wg_socket_reinit(wg, new4->sk, new6 ? new6->sk : NULL); wg_socket_reinit(wg, new4->sk, new6 ? new6->sk : NULL);
return 0; ret = 0;
out:
put_net(net);
return ret;
} }
void wg_socket_reinit(struct wg_device *wg, struct sock *new4, void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
......
...@@ -587,9 +587,20 @@ ip0 link set wg0 up ...@@ -587,9 +587,20 @@ ip0 link set wg0 up
kill $ncat_pid kill $ncat_pid
ip0 link del wg0 ip0 link del wg0
# Ensure there aren't circular reference loops
ip1 link add wg1 type wireguard
ip2 link add wg2 type wireguard
ip1 link set wg1 netns $netns2
ip2 link set wg2 netns $netns1
pp ip netns delete $netns1
pp ip netns delete $netns2
pp ip netns add $netns1
pp ip netns add $netns2
sleep 2 # Wait for cleanup and grace periods
declare -A objects declare -A objects
while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
[[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue [[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ ?[0-9]*)\ .*(created|destroyed).* ]] || continue
objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}" objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
done < /dev/kmsg done < /dev/kmsg
alldeleted=1 alldeleted=1
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment