Commit 63a763d0 authored by Guillaume Nault's avatar Guillaume Nault Committed by Ben Hutchings

l2tp: fix race in l2tp_recv_common()

commit 61b9a047 upstream.

Taking a reference on sessions in l2tp_recv_common() is racy; this
has to be done by the callers.

To this end, a new function is required (l2tp_session_get()) to
atomically lookup a session and take a reference on it. Callers then
have to manually drop this reference.

Fixes: fd558d18 ("l2tp: Split pppol2tp patch into separate l2tp and ppp parts")
Signed-off-by: default avatarGuillaume Nault <g.nault@alphalink.fr>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
[bwh: Backported to 3.2:
 - Drop changes to l2tp_ip6.c
 - Add 'pos' parameter to hlist_for_each_entry{,_rcu}() calls
 - Adjust context]
Signed-off-by: default avatarBen Hutchings <ben@decadent.org.uk>
parent 2a14908e
...@@ -225,6 +225,56 @@ struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunn ...@@ -225,6 +225,56 @@ struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunn
} }
EXPORT_SYMBOL_GPL(l2tp_session_find); EXPORT_SYMBOL_GPL(l2tp_session_find);
/* Like l2tp_session_find() but takes a reference on the returned session.
* Optionally calls session->ref() too if do_ref is true.
*/
struct l2tp_session *l2tp_session_get(struct net *net,
struct l2tp_tunnel *tunnel,
u32 session_id, bool do_ref)
{
struct hlist_head *session_list;
struct l2tp_session *session;
struct hlist_node *walk;
if (!tunnel) {
struct l2tp_net *pn = l2tp_pernet(net);
session_list = l2tp_session_id_hash_2(pn, session_id);
rcu_read_lock_bh();
hlist_for_each_entry_rcu(session, walk, session_list, global_hlist) {
if (session->session_id == session_id) {
l2tp_session_inc_refcount(session);
if (do_ref && session->ref)
session->ref(session);
rcu_read_unlock_bh();
return session;
}
}
rcu_read_unlock_bh();
return NULL;
}
session_list = l2tp_session_id_hash(tunnel, session_id);
read_lock_bh(&tunnel->hlist_lock);
hlist_for_each_entry(session, walk, session_list, hlist) {
if (session->session_id == session_id) {
l2tp_session_inc_refcount(session);
if (do_ref && session->ref)
session->ref(session);
read_unlock_bh(&tunnel->hlist_lock);
return session;
}
}
read_unlock_bh(&tunnel->hlist_lock);
return NULL;
}
EXPORT_SYMBOL_GPL(l2tp_session_get);
struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth) struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth)
{ {
int hash; int hash;
...@@ -524,6 +574,9 @@ static inline int l2tp_verify_udp_checksum(struct sock *sk, ...@@ -524,6 +574,9 @@ static inline int l2tp_verify_udp_checksum(struct sock *sk,
* a data (not control) frame before coming here. Fields up to the * a data (not control) frame before coming here. Fields up to the
* session-id have already been parsed and ptr points to the data * session-id have already been parsed and ptr points to the data
* after the session-id. * after the session-id.
*
* session->ref() must have been called prior to l2tp_recv_common().
* session->deref() will be called automatically after skb is processed.
*/ */
void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
unsigned char *ptr, unsigned char *optr, u16 hdrflags, unsigned char *ptr, unsigned char *optr, u16 hdrflags,
...@@ -533,14 +586,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, ...@@ -533,14 +586,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
int offset; int offset;
u32 ns, nr; u32 ns, nr;
/* The ref count is increased since we now hold a pointer to
* the session. Take care to decrement the refcnt when exiting
* this function from now on...
*/
l2tp_session_inc_refcount(session);
if (session->ref)
(*session->ref)(session);
/* Parse and check optional cookie */ /* Parse and check optional cookie */
if (session->peer_cookie_len > 0) { if (session->peer_cookie_len > 0) {
if (memcmp(ptr, &session->peer_cookie[0], session->peer_cookie_len)) { if (memcmp(ptr, &session->peer_cookie[0], session->peer_cookie_len)) {
...@@ -711,8 +756,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, ...@@ -711,8 +756,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
/* Try to dequeue as many skbs from reorder_q as we can. */ /* Try to dequeue as many skbs from reorder_q as we can. */
l2tp_recv_dequeue(session); l2tp_recv_dequeue(session);
l2tp_session_dec_refcount(session);
return; return;
discard: discard:
...@@ -721,8 +764,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, ...@@ -721,8 +764,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
if (session->deref) if (session->deref)
(*session->deref)(session); (*session->deref)(session);
l2tp_session_dec_refcount(session);
} }
EXPORT_SYMBOL(l2tp_recv_common); EXPORT_SYMBOL(l2tp_recv_common);
...@@ -818,8 +859,14 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb, ...@@ -818,8 +859,14 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
} }
/* Find the session context */ /* Find the session context */
session = l2tp_session_find(tunnel->l2tp_net, tunnel, session_id); session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id, true);
if (!session || !session->recv_skb) { if (!session || !session->recv_skb) {
if (session) {
if (session->deref)
session->deref(session);
l2tp_session_dec_refcount(session);
}
/* Not found? Pass to userspace to deal with */ /* Not found? Pass to userspace to deal with */
PRINTK(tunnel->debug, L2TP_MSG_DATA, KERN_INFO, PRINTK(tunnel->debug, L2TP_MSG_DATA, KERN_INFO,
"%s: no session found (%u/%u). Passing up.\n", "%s: no session found (%u/%u). Passing up.\n",
...@@ -828,6 +875,7 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb, ...@@ -828,6 +875,7 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
} }
l2tp_recv_common(session, skb, ptr, optr, hdrflags, length, payload_hook); l2tp_recv_common(session, skb, ptr, optr, hdrflags, length, payload_hook);
l2tp_session_dec_refcount(session);
return 0; return 0;
......
...@@ -222,6 +222,9 @@ static inline struct l2tp_tunnel *l2tp_sock_to_tunnel(struct sock *sk) ...@@ -222,6 +222,9 @@ static inline struct l2tp_tunnel *l2tp_sock_to_tunnel(struct sock *sk)
return tunnel; return tunnel;
} }
struct l2tp_session *l2tp_session_get(struct net *net,
struct l2tp_tunnel *tunnel,
u32 session_id, bool do_ref);
extern struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunnel, u32 session_id); extern struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunnel, u32 session_id);
extern struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth); extern struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth);
extern struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname); extern struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname);
......
...@@ -148,19 +148,19 @@ static int l2tp_ip_recv(struct sk_buff *skb) ...@@ -148,19 +148,19 @@ static int l2tp_ip_recv(struct sk_buff *skb)
} }
/* Ok, this is a data packet. Lookup the session. */ /* Ok, this is a data packet. Lookup the session. */
session = l2tp_session_find(&init_net, NULL, session_id); session = l2tp_session_get(&init_net, NULL, session_id, true);
if (session == NULL) if (!session)
goto discard; goto discard;
tunnel = session->tunnel; tunnel = session->tunnel;
if (tunnel == NULL) if (!tunnel)
goto discard; goto discard_sess;
/* Trace packet contents, if enabled */ /* Trace packet contents, if enabled */
if (tunnel->debug & L2TP_MSG_DATA) { if (tunnel->debug & L2TP_MSG_DATA) {
length = min(32u, skb->len); length = min(32u, skb->len);
if (!pskb_may_pull(skb, length)) if (!pskb_may_pull(skb, length))
goto discard; goto discard_sess;
/* Point to L2TP header */ /* Point to L2TP header */
optr = ptr = skb->data; optr = ptr = skb->data;
...@@ -176,6 +176,7 @@ static int l2tp_ip_recv(struct sk_buff *skb) ...@@ -176,6 +176,7 @@ static int l2tp_ip_recv(struct sk_buff *skb)
} }
l2tp_recv_common(session, skb, ptr, optr, 0, skb->len, tunnel->recv_payload_hook); l2tp_recv_common(session, skb, ptr, optr, 0, skb->len, tunnel->recv_payload_hook);
l2tp_session_dec_refcount(session);
return 0; return 0;
...@@ -211,6 +212,12 @@ static int l2tp_ip_recv(struct sk_buff *skb) ...@@ -211,6 +212,12 @@ static int l2tp_ip_recv(struct sk_buff *skb)
return sk_receive_skb(sk, skb, 1); return sk_receive_skb(sk, skb, 1);
discard_sess:
if (session->deref)
session->deref(session);
l2tp_session_dec_refcount(session);
goto discard;
discard_put: discard_put:
sock_put(sk); sock_put(sk);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment