Commit 443cb3a3 authored by David S. Miller's avatar David S. Miller

Merge branch 'l2tp-session-creation-fixes'

Guillaume Nault says:

====================
l2tp: session creation fixes

The session creation process has a few issues wrt. concurrent tunnel
deletion.

Patch #1 avoids creating sessions in tunnels that are getting removed.
This prevents races where sessions could try to take tunnel resources
that were already released.

Patch #2 removes some racy l2tp_tunnel_find() calls in session creation
callbacks. Together with path #1 it ensures that sessions can only
access tunnel resources that are guaranteed to remain valid during the
session creation process.

There are other problems with how sessions are created: pseudo-wire
specific data are set after the session is added to the tunnel. So
the session can be used, or deleted, before it has been completely
initialised. Separating session allocation from session registration
would be necessary, but we'd still have circular dependencies
preventing race-free registration. I'll consider this issue in future
series.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 4113f36b f026bc29
...@@ -329,13 +329,21 @@ static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel, ...@@ -329,13 +329,21 @@ static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
struct hlist_head *g_head; struct hlist_head *g_head;
struct hlist_head *head; struct hlist_head *head;
struct l2tp_net *pn; struct l2tp_net *pn;
int err;
head = l2tp_session_id_hash(tunnel, session->session_id); head = l2tp_session_id_hash(tunnel, session->session_id);
write_lock_bh(&tunnel->hlist_lock); write_lock_bh(&tunnel->hlist_lock);
if (!tunnel->acpt_newsess) {
err = -ENODEV;
goto err_tlock;
}
hlist_for_each_entry(session_walk, head, hlist) hlist_for_each_entry(session_walk, head, hlist)
if (session_walk->session_id == session->session_id) if (session_walk->session_id == session->session_id) {
goto exist; err = -EEXIST;
goto err_tlock;
}
if (tunnel->version == L2TP_HDR_VER_3) { if (tunnel->version == L2TP_HDR_VER_3) {
pn = l2tp_pernet(tunnel->l2tp_net); pn = l2tp_pernet(tunnel->l2tp_net);
...@@ -343,12 +351,21 @@ static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel, ...@@ -343,12 +351,21 @@ static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
session->session_id); session->session_id);
spin_lock_bh(&pn->l2tp_session_hlist_lock); spin_lock_bh(&pn->l2tp_session_hlist_lock);
hlist_for_each_entry(session_walk, g_head, global_hlist) hlist_for_each_entry(session_walk, g_head, global_hlist)
if (session_walk->session_id == session->session_id) if (session_walk->session_id == session->session_id) {
goto exist_glob; err = -EEXIST;
goto err_tlock_pnlock;
}
l2tp_tunnel_inc_refcount(tunnel);
sock_hold(tunnel->sock);
hlist_add_head_rcu(&session->global_hlist, g_head); hlist_add_head_rcu(&session->global_hlist, g_head);
spin_unlock_bh(&pn->l2tp_session_hlist_lock); spin_unlock_bh(&pn->l2tp_session_hlist_lock);
} else {
l2tp_tunnel_inc_refcount(tunnel);
sock_hold(tunnel->sock);
} }
hlist_add_head(&session->hlist, head); hlist_add_head(&session->hlist, head);
...@@ -356,12 +373,12 @@ static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel, ...@@ -356,12 +373,12 @@ static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
return 0; return 0;
exist_glob: err_tlock_pnlock:
spin_unlock_bh(&pn->l2tp_session_hlist_lock); spin_unlock_bh(&pn->l2tp_session_hlist_lock);
exist: err_tlock:
write_unlock_bh(&tunnel->hlist_lock); write_unlock_bh(&tunnel->hlist_lock);
return -EEXIST; return err;
} }
/* Lookup a tunnel by id /* Lookup a tunnel by id
...@@ -1251,7 +1268,6 @@ static void l2tp_tunnel_destruct(struct sock *sk) ...@@ -1251,7 +1268,6 @@ static void l2tp_tunnel_destruct(struct sock *sk)
/* Remove hooks into tunnel socket */ /* Remove hooks into tunnel socket */
sk->sk_destruct = tunnel->old_sk_destruct; sk->sk_destruct = tunnel->old_sk_destruct;
sk->sk_user_data = NULL; sk->sk_user_data = NULL;
tunnel->sock = NULL;
/* Remove the tunnel struct from the tunnel list */ /* Remove the tunnel struct from the tunnel list */
pn = l2tp_pernet(tunnel->l2tp_net); pn = l2tp_pernet(tunnel->l2tp_net);
...@@ -1261,6 +1277,8 @@ static void l2tp_tunnel_destruct(struct sock *sk) ...@@ -1261,6 +1277,8 @@ static void l2tp_tunnel_destruct(struct sock *sk)
atomic_dec(&l2tp_tunnel_count); atomic_dec(&l2tp_tunnel_count);
l2tp_tunnel_closeall(tunnel); l2tp_tunnel_closeall(tunnel);
tunnel->sock = NULL;
l2tp_tunnel_dec_refcount(tunnel); l2tp_tunnel_dec_refcount(tunnel);
/* Call the original destructor */ /* Call the original destructor */
...@@ -1285,6 +1303,7 @@ void l2tp_tunnel_closeall(struct l2tp_tunnel *tunnel) ...@@ -1285,6 +1303,7 @@ void l2tp_tunnel_closeall(struct l2tp_tunnel *tunnel)
tunnel->name); tunnel->name);
write_lock_bh(&tunnel->hlist_lock); write_lock_bh(&tunnel->hlist_lock);
tunnel->acpt_newsess = false;
for (hash = 0; hash < L2TP_HASH_SIZE; hash++) { for (hash = 0; hash < L2TP_HASH_SIZE; hash++) {
again: again:
hlist_for_each_safe(walk, tmp, &tunnel->session_hlist[hash]) { hlist_for_each_safe(walk, tmp, &tunnel->session_hlist[hash]) {
...@@ -1581,6 +1600,7 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32 ...@@ -1581,6 +1600,7 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
tunnel->magic = L2TP_TUNNEL_MAGIC; tunnel->magic = L2TP_TUNNEL_MAGIC;
sprintf(&tunnel->name[0], "tunl %u", tunnel_id); sprintf(&tunnel->name[0], "tunl %u", tunnel_id);
rwlock_init(&tunnel->hlist_lock); rwlock_init(&tunnel->hlist_lock);
tunnel->acpt_newsess = true;
/* The net we belong to */ /* The net we belong to */
tunnel->l2tp_net = net; tunnel->l2tp_net = net;
...@@ -1829,11 +1849,6 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn ...@@ -1829,11 +1849,6 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
return ERR_PTR(err); return ERR_PTR(err);
} }
l2tp_tunnel_inc_refcount(tunnel);
/* Ensure tunnel socket isn't deleted */
sock_hold(tunnel->sock);
/* Ignore management session in session count value */ /* Ignore management session in session count value */
if (session->session_id != 0) if (session->session_id != 0)
atomic_inc(&l2tp_session_count); atomic_inc(&l2tp_session_count);
......
...@@ -162,6 +162,10 @@ struct l2tp_tunnel { ...@@ -162,6 +162,10 @@ struct l2tp_tunnel {
int magic; /* Should be L2TP_TUNNEL_MAGIC */ int magic; /* Should be L2TP_TUNNEL_MAGIC */
struct rcu_head rcu; struct rcu_head rcu;
rwlock_t hlist_lock; /* protect session_hlist */ rwlock_t hlist_lock; /* protect session_hlist */
bool acpt_newsess; /* Indicates whether this
* tunnel accepts new sessions.
* Protected by hlist_lock.
*/
struct hlist_head session_hlist[L2TP_HASH_SIZE]; struct hlist_head session_hlist[L2TP_HASH_SIZE];
/* hashed list of sessions, /* hashed list of sessions,
* hashed by id */ * hashed by id */
...@@ -197,7 +201,9 @@ struct l2tp_tunnel { ...@@ -197,7 +201,9 @@ struct l2tp_tunnel {
}; };
struct l2tp_nl_cmd_ops { struct l2tp_nl_cmd_ops {
int (*session_create)(struct net *net, u32 tunnel_id, u32 session_id, u32 peer_session_id, struct l2tp_session_cfg *cfg); int (*session_create)(struct net *net, struct l2tp_tunnel *tunnel,
u32 session_id, u32 peer_session_id,
struct l2tp_session_cfg *cfg);
int (*session_delete)(struct l2tp_session *session); int (*session_delete)(struct l2tp_session *session);
}; };
......
...@@ -262,24 +262,19 @@ static void l2tp_eth_adjust_mtu(struct l2tp_tunnel *tunnel, ...@@ -262,24 +262,19 @@ static void l2tp_eth_adjust_mtu(struct l2tp_tunnel *tunnel,
dev->needed_headroom += session->hdr_len; dev->needed_headroom += session->hdr_len;
} }
static int l2tp_eth_create(struct net *net, u32 tunnel_id, u32 session_id, u32 peer_session_id, struct l2tp_session_cfg *cfg) static int l2tp_eth_create(struct net *net, struct l2tp_tunnel *tunnel,
u32 session_id, u32 peer_session_id,
struct l2tp_session_cfg *cfg)
{ {
unsigned char name_assign_type; unsigned char name_assign_type;
struct net_device *dev; struct net_device *dev;
char name[IFNAMSIZ]; char name[IFNAMSIZ];
struct l2tp_tunnel *tunnel;
struct l2tp_session *session; struct l2tp_session *session;
struct l2tp_eth *priv; struct l2tp_eth *priv;
struct l2tp_eth_sess *spriv; struct l2tp_eth_sess *spriv;
int rc; int rc;
struct l2tp_eth_net *pn; struct l2tp_eth_net *pn;
tunnel = l2tp_tunnel_find(net, tunnel_id);
if (!tunnel) {
rc = -ENODEV;
goto out;
}
if (cfg->ifname) { if (cfg->ifname) {
strlcpy(name, cfg->ifname, IFNAMSIZ); strlcpy(name, cfg->ifname, IFNAMSIZ);
name_assign_type = NET_NAME_USER; name_assign_type = NET_NAME_USER;
......
...@@ -643,10 +643,10 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf ...@@ -643,10 +643,10 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
break; break;
} }
ret = -EPROTONOSUPPORT; ret = l2tp_nl_cmd_ops[cfg.pw_type]->session_create(net, tunnel,
if (l2tp_nl_cmd_ops[cfg.pw_type]->session_create) session_id,
ret = (*l2tp_nl_cmd_ops[cfg.pw_type]->session_create)(net, tunnel_id, peer_session_id,
session_id, peer_session_id, &cfg); &cfg);
if (ret >= 0) { if (ret >= 0) {
session = l2tp_session_get(net, tunnel, session_id, false); session = l2tp_session_get(net, tunnel, session_id, false);
......
...@@ -788,25 +788,20 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, ...@@ -788,25 +788,20 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
#ifdef CONFIG_L2TP_V3 #ifdef CONFIG_L2TP_V3
/* Called when creating sessions via the netlink interface. /* Called when creating sessions via the netlink interface. */
*/ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel,
static int pppol2tp_session_create(struct net *net, u32 tunnel_id, u32 session_id, u32 peer_session_id, struct l2tp_session_cfg *cfg) u32 session_id, u32 peer_session_id,
struct l2tp_session_cfg *cfg)
{ {
int error; int error;
struct l2tp_tunnel *tunnel;
struct l2tp_session *session; struct l2tp_session *session;
struct pppol2tp_session *ps; struct pppol2tp_session *ps;
tunnel = l2tp_tunnel_find(net, tunnel_id);
/* Error if we can't find the tunnel */
error = -ENOENT;
if (tunnel == NULL)
goto out;
/* Error if tunnel socket is not prepped */ /* Error if tunnel socket is not prepped */
if (tunnel->sock == NULL) if (!tunnel->sock) {
error = -ENOENT;
goto out; goto out;
}
/* Default MTU values. */ /* Default MTU values. */
if (cfg->mtu == 0) if (cfg->mtu == 0)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment