Commit b301c9b7 authored by Guillaume Nault's avatar Guillaume Nault Committed by Ben Hutchings

l2tp: take a reference on sessions used in genetlink handlers

commit 2777e2ab upstream.

Callers of l2tp_nl_session_find() need to hold a reference on the
returned session since there's no guarantee that it isn't going to
disappear from under them.

Relying on the fact that no l2tp netlink message may be processed
concurrently isn't enough: sessions can be deleted by other means
(e.g. by closing the PPPOL2TP socket of a ppp pseudowire).

l2tp_nl_cmd_session_delete() is a bit special: it runs a callback
function that may require a previous call to session->ref(). In
particular, for ppp pseudowires, the callback is l2tp_session_delete(),
which then calls pppol2tp_session_close() and dereferences the PPPOL2TP
socket. The socket might already be gone at the moment
l2tp_session_delete() calls session->ref(), so we need to take a
reference during the session lookup. So we need to pass the do_ref
variable down to l2tp_session_get() and l2tp_session_get_by_ifname().

Since all callers have to be updated, l2tp_session_find_by_ifname() and
l2tp_nl_session_find() are renamed to reflect their new behaviour.

Fixes: 309795f4 ("l2tp: Add netlink control API for L2TP")
Signed-off-by: default avatarGuillaume Nault <g.nault@alphalink.fr>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
[bwh: Backported to 3.16: adjust context]
Signed-off-by: default avatarBen Hutchings <ben@decadent.org.uk>
parent 877dc0d7
...@@ -351,7 +351,8 @@ EXPORT_SYMBOL_GPL(l2tp_session_find_nth); ...@@ -351,7 +351,8 @@ EXPORT_SYMBOL_GPL(l2tp_session_find_nth);
/* Lookup a session by interface name. /* Lookup a session by interface name.
* This is very inefficient but is only used by management interfaces. * This is very inefficient but is only used by management interfaces.
*/ */
struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname) struct l2tp_session *l2tp_session_get_by_ifname(struct net *net, char *ifname,
bool do_ref)
{ {
struct l2tp_net *pn = l2tp_pernet(net); struct l2tp_net *pn = l2tp_pernet(net);
int hash; int hash;
...@@ -361,7 +362,11 @@ struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname) ...@@ -361,7 +362,11 @@ struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname)
for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) { for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) {
hlist_for_each_entry_rcu(session, &pn->l2tp_session_hlist[hash], global_hlist) { hlist_for_each_entry_rcu(session, &pn->l2tp_session_hlist[hash], global_hlist) {
if (!strcmp(session->ifname, ifname)) { if (!strcmp(session->ifname, ifname)) {
l2tp_session_inc_refcount(session);
if (do_ref && session->ref)
session->ref(session);
rcu_read_unlock_bh(); rcu_read_unlock_bh();
return session; return session;
} }
} }
...@@ -371,7 +376,7 @@ struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname) ...@@ -371,7 +376,7 @@ struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname)
return NULL; return NULL;
} }
EXPORT_SYMBOL_GPL(l2tp_session_find_by_ifname); EXPORT_SYMBOL_GPL(l2tp_session_get_by_ifname);
static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel, static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
struct l2tp_session *session) struct l2tp_session *session)
......
...@@ -247,7 +247,8 @@ struct l2tp_session *l2tp_session_find(struct net *net, ...@@ -247,7 +247,8 @@ struct l2tp_session *l2tp_session_find(struct net *net,
struct l2tp_tunnel *tunnel, struct l2tp_tunnel *tunnel,
u32 session_id); u32 session_id);
struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth); struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth);
struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname); struct l2tp_session *l2tp_session_get_by_ifname(struct net *net, char *ifname,
bool do_ref);
struct l2tp_tunnel *l2tp_tunnel_find(struct net *net, u32 tunnel_id); struct l2tp_tunnel *l2tp_tunnel_find(struct net *net, u32 tunnel_id);
struct l2tp_tunnel *l2tp_tunnel_find_nth(struct net *net, int nth); struct l2tp_tunnel *l2tp_tunnel_find_nth(struct net *net, int nth);
......
...@@ -43,7 +43,8 @@ static struct genl_family l2tp_nl_family = { ...@@ -43,7 +43,8 @@ static struct genl_family l2tp_nl_family = {
/* Accessed under genl lock */ /* Accessed under genl lock */
static const struct l2tp_nl_cmd_ops *l2tp_nl_cmd_ops[__L2TP_PWTYPE_MAX]; static const struct l2tp_nl_cmd_ops *l2tp_nl_cmd_ops[__L2TP_PWTYPE_MAX];
static struct l2tp_session *l2tp_nl_session_find(struct genl_info *info) static struct l2tp_session *l2tp_nl_session_get(struct genl_info *info,
bool do_ref)
{ {
u32 tunnel_id; u32 tunnel_id;
u32 session_id; u32 session_id;
...@@ -54,14 +55,15 @@ static struct l2tp_session *l2tp_nl_session_find(struct genl_info *info) ...@@ -54,14 +55,15 @@ static struct l2tp_session *l2tp_nl_session_find(struct genl_info *info)
if (info->attrs[L2TP_ATTR_IFNAME]) { if (info->attrs[L2TP_ATTR_IFNAME]) {
ifname = nla_data(info->attrs[L2TP_ATTR_IFNAME]); ifname = nla_data(info->attrs[L2TP_ATTR_IFNAME]);
session = l2tp_session_find_by_ifname(net, ifname); session = l2tp_session_get_by_ifname(net, ifname, do_ref);
} else if ((info->attrs[L2TP_ATTR_SESSION_ID]) && } else if ((info->attrs[L2TP_ATTR_SESSION_ID]) &&
(info->attrs[L2TP_ATTR_CONN_ID])) { (info->attrs[L2TP_ATTR_CONN_ID])) {
tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]); tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]); session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]);
tunnel = l2tp_tunnel_find(net, tunnel_id); tunnel = l2tp_tunnel_find(net, tunnel_id);
if (tunnel) if (tunnel)
session = l2tp_session_find(net, tunnel, session_id); session = l2tp_session_get(net, tunnel, session_id,
do_ref);
} }
return session; return session;
...@@ -549,7 +551,7 @@ static int l2tp_nl_cmd_session_delete(struct sk_buff *skb, struct genl_info *inf ...@@ -549,7 +551,7 @@ static int l2tp_nl_cmd_session_delete(struct sk_buff *skb, struct genl_info *inf
struct l2tp_session *session; struct l2tp_session *session;
u16 pw_type; u16 pw_type;
session = l2tp_nl_session_find(info); session = l2tp_nl_session_get(info, true);
if (session == NULL) { if (session == NULL) {
ret = -ENODEV; ret = -ENODEV;
goto out; goto out;
...@@ -560,6 +562,10 @@ static int l2tp_nl_cmd_session_delete(struct sk_buff *skb, struct genl_info *inf ...@@ -560,6 +562,10 @@ static int l2tp_nl_cmd_session_delete(struct sk_buff *skb, struct genl_info *inf
if (l2tp_nl_cmd_ops[pw_type] && l2tp_nl_cmd_ops[pw_type]->session_delete) if (l2tp_nl_cmd_ops[pw_type] && l2tp_nl_cmd_ops[pw_type]->session_delete)
ret = (*l2tp_nl_cmd_ops[pw_type]->session_delete)(session); ret = (*l2tp_nl_cmd_ops[pw_type]->session_delete)(session);
if (session->deref)
session->deref(session);
l2tp_session_dec_refcount(session);
out: out:
return ret; return ret;
} }
...@@ -569,7 +575,7 @@ static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *inf ...@@ -569,7 +575,7 @@ static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *inf
int ret = 0; int ret = 0;
struct l2tp_session *session; struct l2tp_session *session;
session = l2tp_nl_session_find(info); session = l2tp_nl_session_get(info, false);
if (session == NULL) { if (session == NULL) {
ret = -ENODEV; ret = -ENODEV;
goto out; goto out;
...@@ -601,6 +607,8 @@ static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *inf ...@@ -601,6 +607,8 @@ static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *inf
if (info->attrs[L2TP_ATTR_MRU]) if (info->attrs[L2TP_ATTR_MRU])
session->mru = nla_get_u16(info->attrs[L2TP_ATTR_MRU]); session->mru = nla_get_u16(info->attrs[L2TP_ATTR_MRU]);
l2tp_session_dec_refcount(session);
out: out:
return ret; return ret;
} }
...@@ -686,29 +694,34 @@ static int l2tp_nl_cmd_session_get(struct sk_buff *skb, struct genl_info *info) ...@@ -686,29 +694,34 @@ static int l2tp_nl_cmd_session_get(struct sk_buff *skb, struct genl_info *info)
struct sk_buff *msg; struct sk_buff *msg;
int ret; int ret;
session = l2tp_nl_session_find(info); session = l2tp_nl_session_get(info, false);
if (session == NULL) { if (session == NULL) {
ret = -ENODEV; ret = -ENODEV;
goto out; goto err;
} }
msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
if (!msg) { if (!msg) {
ret = -ENOMEM; ret = -ENOMEM;
goto out; goto err_ref;
} }
ret = l2tp_nl_session_send(msg, info->snd_portid, info->snd_seq, ret = l2tp_nl_session_send(msg, info->snd_portid, info->snd_seq,
0, session); 0, session);
if (ret < 0) if (ret < 0)
goto err_out; goto err_ref_msg;
return genlmsg_unicast(genl_info_net(info), msg, info->snd_portid); ret = genlmsg_unicast(genl_info_net(info), msg, info->snd_portid);
err_out: l2tp_session_dec_refcount(session);
nlmsg_free(msg);
out: return ret;
err_ref_msg:
nlmsg_free(msg);
err_ref:
l2tp_session_dec_refcount(session);
err:
return ret; return ret;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment