Commit 24256415 authored by James Chapman's avatar James Chapman Committed by David S. Miller

l2tp: prevent possible tunnel refcount underflow

When a session is created, it sets a backpointer to its tunnel. When
the session refcount drops to 0, l2tp_session_free drops the tunnel
refcount if session->tunnel is non-NULL. However, session->tunnel is
set in l2tp_session_create, before the tunnel refcount is incremented
by l2tp_session_register, which leaves a small window where
session->tunnel is non-NULL when the tunnel refcount hasn't been
bumped.

Moving the assignment to l2tp_session_register is trivial but
l2tp_session_create calls l2tp_session_set_header_len which uses
session->tunnel to get the tunnel's encap. Add an encap arg to
l2tp_session_set_header_len to avoid using session->tunnel.

If l2tpv3 sessions have colliding IDs, it is possible for
l2tp_v3_session_get to race with l2tp_session_register and fetch a
session which doesn't yet have session->tunnel set. Add a check for
this case.
Signed-off-by: default avatarJames Chapman <jchapman@katalix.com>
Signed-off-by: default avatarTom Parkin <tparkin@katalix.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent c5cbaef9
...@@ -279,7 +279,14 @@ struct l2tp_session *l2tp_v3_session_get(const struct net *net, struct sock *sk, ...@@ -279,7 +279,14 @@ struct l2tp_session *l2tp_v3_session_get(const struct net *net, struct sock *sk,
hash_for_each_possible_rcu(pn->l2tp_v3_session_htable, session, hash_for_each_possible_rcu(pn->l2tp_v3_session_htable, session,
hlist, key) { hlist, key) {
if (session->tunnel->sock == sk && /* session->tunnel may be NULL if another thread is in
* l2tp_session_register and has added an item to
* l2tp_v3_session_htable but hasn't yet added the
* session to its tunnel's session_list.
*/
struct l2tp_tunnel *tunnel = READ_ONCE(session->tunnel);
if (tunnel && tunnel->sock == sk &&
refcount_inc_not_zero(&session->ref_count)) { refcount_inc_not_zero(&session->ref_count)) {
rcu_read_unlock_bh(); rcu_read_unlock_bh();
return session; return session;
...@@ -507,6 +514,7 @@ int l2tp_session_register(struct l2tp_session *session, ...@@ -507,6 +514,7 @@ int l2tp_session_register(struct l2tp_session *session,
} }
l2tp_tunnel_inc_refcount(tunnel); l2tp_tunnel_inc_refcount(tunnel);
WRITE_ONCE(session->tunnel, tunnel);
list_add(&session->list, &tunnel->session_list); list_add(&session->list, &tunnel->session_list);
if (tunnel->version == L2TP_HDR_VER_3) { if (tunnel->version == L2TP_HDR_VER_3) {
...@@ -822,7 +830,8 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, ...@@ -822,7 +830,8 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
if (!session->lns_mode && !session->send_seq) { if (!session->lns_mode && !session->send_seq) {
trace_session_seqnum_lns_enable(session); trace_session_seqnum_lns_enable(session);
session->send_seq = 1; session->send_seq = 1;
l2tp_session_set_header_len(session, tunnel->version); l2tp_session_set_header_len(session, tunnel->version,
tunnel->encap);
} }
} else { } else {
/* No sequence numbers. /* No sequence numbers.
...@@ -843,7 +852,8 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, ...@@ -843,7 +852,8 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
if (!session->lns_mode && session->send_seq) { if (!session->lns_mode && session->send_seq) {
trace_session_seqnum_lns_disable(session); trace_session_seqnum_lns_disable(session);
session->send_seq = 0; session->send_seq = 0;
l2tp_session_set_header_len(session, tunnel->version); l2tp_session_set_header_len(session, tunnel->version,
tunnel->encap);
} else if (session->send_seq) { } else if (session->send_seq) {
pr_debug_ratelimited("%s: recv data has no seq numbers when required. Discarding.\n", pr_debug_ratelimited("%s: recv data has no seq numbers when required. Discarding.\n",
session->name); session->name);
...@@ -1648,7 +1658,8 @@ static void l2tp_session_del_work(struct work_struct *work) ...@@ -1648,7 +1658,8 @@ static void l2tp_session_del_work(struct work_struct *work)
/* We come here whenever a session's send_seq, cookie_len or /* We come here whenever a session's send_seq, cookie_len or
* l2specific_type parameters are set. * l2specific_type parameters are set.
*/ */
void l2tp_session_set_header_len(struct l2tp_session *session, int version) void l2tp_session_set_header_len(struct l2tp_session *session, int version,
enum l2tp_encap_type encap)
{ {
if (version == L2TP_HDR_VER_2) { if (version == L2TP_HDR_VER_2) {
session->hdr_len = 6; session->hdr_len = 6;
...@@ -1657,7 +1668,7 @@ void l2tp_session_set_header_len(struct l2tp_session *session, int version) ...@@ -1657,7 +1668,7 @@ void l2tp_session_set_header_len(struct l2tp_session *session, int version)
} else { } else {
session->hdr_len = 4 + session->cookie_len; session->hdr_len = 4 + session->cookie_len;
session->hdr_len += l2tp_get_l2specific_len(session); session->hdr_len += l2tp_get_l2specific_len(session);
if (session->tunnel->encap == L2TP_ENCAPTYPE_UDP) if (encap == L2TP_ENCAPTYPE_UDP)
session->hdr_len += 4; session->hdr_len += 4;
} }
} }
...@@ -1671,7 +1682,6 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn ...@@ -1671,7 +1682,6 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
session = kzalloc(sizeof(*session) + priv_size, GFP_KERNEL); session = kzalloc(sizeof(*session) + priv_size, GFP_KERNEL);
if (session) { if (session) {
session->magic = L2TP_SESSION_MAGIC; session->magic = L2TP_SESSION_MAGIC;
session->tunnel = tunnel;
session->session_id = session_id; session->session_id = session_id;
session->peer_session_id = peer_session_id; session->peer_session_id = peer_session_id;
...@@ -1710,7 +1720,7 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn ...@@ -1710,7 +1720,7 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
memcpy(&session->peer_cookie[0], &cfg->peer_cookie[0], cfg->peer_cookie_len); memcpy(&session->peer_cookie[0], &cfg->peer_cookie[0], cfg->peer_cookie_len);
} }
l2tp_session_set_header_len(session, tunnel->version); l2tp_session_set_header_len(session, tunnel->version, tunnel->encap);
refcount_set(&session->ref_count, 1); refcount_set(&session->ref_count, 1);
......
...@@ -258,7 +258,8 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, ...@@ -258,7 +258,8 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
int l2tp_udp_encap_recv(struct sock *sk, struct sk_buff *skb); int l2tp_udp_encap_recv(struct sock *sk, struct sk_buff *skb);
/* Transmit path helpers for sending packets over the tunnel socket. */ /* Transmit path helpers for sending packets over the tunnel socket. */
void l2tp_session_set_header_len(struct l2tp_session *session, int version); void l2tp_session_set_header_len(struct l2tp_session *session, int version,
enum l2tp_encap_type encap);
int l2tp_xmit_skb(struct l2tp_session *session, struct sk_buff *skb); int l2tp_xmit_skb(struct l2tp_session *session, struct sk_buff *skb);
/* Pseudowire management. /* Pseudowire management.
......
...@@ -692,8 +692,10 @@ static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *inf ...@@ -692,8 +692,10 @@ static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *inf
session->recv_seq = nla_get_u8(info->attrs[L2TP_ATTR_RECV_SEQ]); session->recv_seq = nla_get_u8(info->attrs[L2TP_ATTR_RECV_SEQ]);
if (info->attrs[L2TP_ATTR_SEND_SEQ]) { if (info->attrs[L2TP_ATTR_SEND_SEQ]) {
struct l2tp_tunnel *tunnel = session->tunnel;
session->send_seq = nla_get_u8(info->attrs[L2TP_ATTR_SEND_SEQ]); session->send_seq = nla_get_u8(info->attrs[L2TP_ATTR_SEND_SEQ]);
l2tp_session_set_header_len(session, session->tunnel->version); l2tp_session_set_header_len(session, tunnel->version, tunnel->encap);
} }
if (info->attrs[L2TP_ATTR_LNS_MODE]) if (info->attrs[L2TP_ATTR_LNS_MODE])
......
...@@ -1189,7 +1189,8 @@ static int pppol2tp_session_setsockopt(struct sock *sk, ...@@ -1189,7 +1189,8 @@ static int pppol2tp_session_setsockopt(struct sock *sk,
po->chan.hdrlen = val ? PPPOL2TP_L2TP_HDR_SIZE_SEQ : po->chan.hdrlen = val ? PPPOL2TP_L2TP_HDR_SIZE_SEQ :
PPPOL2TP_L2TP_HDR_SIZE_NOSEQ; PPPOL2TP_L2TP_HDR_SIZE_NOSEQ;
} }
l2tp_session_set_header_len(session, session->tunnel->version); l2tp_session_set_header_len(session, session->tunnel->version,
session->tunnel->encap);
break; break;
case PPPOL2TP_SO_LNSMODE: case PPPOL2TP_SO_LNSMODE:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment