Commit aa5e17e1 authored by James Chapman's avatar James Chapman Committed by David S. Miller

l2tp: store l2tpv3 sessions in per-net IDR

L2TPv3 sessions are currently held in one of two fixed-size hash
lists: either a per-net hashlist (IP-encap), or a per-tunnel hashlist
(UDP-encap), keyed by the L2TPv3 32-bit session_id.

In order to lookup L2TPv3 sessions in UDP-encap tunnels efficiently
without finding the tunnel first via sk_user_data, UDP sessions are
now kept in a per-net session list, keyed by session ID. Convert the
existing per-net hashlist to use an IDR for better performance when
there are many sessions and have L2TPv3 UDP sessions use the same IDR.

Although the L2TPv3 RFC states that the session ID alone identifies
the session, our implementation has allowed the same session ID to be
used in different L2TP UDP tunnels. To retain support for this, a new
per-net session hashtable is used, keyed by the sock and session
ID. If on creating a new session, a session already exists with that
ID in the IDR, the colliding sessions are added to the new hashtable
and the existing IDR entry is flagged. When looking up sessions, the
approach is to first check the IDR and if no unflagged match is found,
check the new hashtable. The sock is made available to session getters
where session ID collisions are to be considered. In this way, the new
hashtable is used only for session ID collisions so can be kept small.

For managing session removal, we need a list of colliding sessions
matching a given ID in order to update or remove the IDR entry of the
ID. This is necessary to detect session ID collisions when future
sessions are created. The list head is allocated on first collision
of a given ID and refcounted.
Signed-off-by: default avatarJames Chapman <jchapman@katalix.com>
Reviewed-by: default avatarTom Parkin <tparkin@katalix.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent a744e2d0
...@@ -107,11 +107,17 @@ struct l2tp_net { ...@@ -107,11 +107,17 @@ struct l2tp_net {
/* Lock for write access to l2tp_tunnel_idr */ /* Lock for write access to l2tp_tunnel_idr */
spinlock_t l2tp_tunnel_idr_lock; spinlock_t l2tp_tunnel_idr_lock;
struct idr l2tp_tunnel_idr; struct idr l2tp_tunnel_idr;
struct hlist_head l2tp_session_hlist[L2TP_HASH_SIZE_2]; /* Lock for write access to l2tp_v3_session_idr/htable */
/* Lock for write access to l2tp_session_hlist */ spinlock_t l2tp_session_idr_lock;
spinlock_t l2tp_session_hlist_lock; struct idr l2tp_v3_session_idr;
struct hlist_head l2tp_v3_session_htable[16];
}; };
static inline unsigned long l2tp_v3_session_hashkey(struct sock *sk, u32 session_id)
{
return ((unsigned long)sk) + session_id;
}
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
static bool l2tp_sk_is_v6(struct sock *sk) static bool l2tp_sk_is_v6(struct sock *sk)
{ {
...@@ -125,17 +131,6 @@ static inline struct l2tp_net *l2tp_pernet(const struct net *net) ...@@ -125,17 +131,6 @@ static inline struct l2tp_net *l2tp_pernet(const struct net *net)
return net_generic(net, l2tp_net_id); return net_generic(net, l2tp_net_id);
} }
/* Session hash global list for L2TPv3.
* The session_id SHOULD be random according to RFC3931, but several
* L2TP implementations use incrementing session_ids. So we do a real
* hash on the session_id, rather than a simple bitmask.
*/
static inline struct hlist_head *
l2tp_session_id_hash_2(struct l2tp_net *pn, u32 session_id)
{
return &pn->l2tp_session_hlist[hash_32(session_id, L2TP_HASH_BITS_2)];
}
/* Session hash list. /* Session hash list.
* The session_id SHOULD be random according to RFC2661, but several * The session_id SHOULD be random according to RFC2661, but several
* L2TP implementations (Cisco and Microsoft) use incrementing * L2TP implementations (Cisco and Microsoft) use incrementing
...@@ -262,26 +257,40 @@ struct l2tp_session *l2tp_tunnel_get_session(struct l2tp_tunnel *tunnel, ...@@ -262,26 +257,40 @@ struct l2tp_session *l2tp_tunnel_get_session(struct l2tp_tunnel *tunnel,
} }
EXPORT_SYMBOL_GPL(l2tp_tunnel_get_session); EXPORT_SYMBOL_GPL(l2tp_tunnel_get_session);
struct l2tp_session *l2tp_session_get(const struct net *net, u32 session_id) struct l2tp_session *l2tp_v3_session_get(const struct net *net, struct sock *sk, u32 session_id)
{ {
struct hlist_head *session_list; const struct l2tp_net *pn = l2tp_pernet(net);
struct l2tp_session *session; struct l2tp_session *session;
session_list = l2tp_session_id_hash_2(l2tp_pernet(net), session_id);
rcu_read_lock_bh(); rcu_read_lock_bh();
hlist_for_each_entry_rcu(session, session_list, global_hlist) session = idr_find(&pn->l2tp_v3_session_idr, session_id);
if (session->session_id == session_id) { if (session && !hash_hashed(&session->hlist) &&
l2tp_session_inc_refcount(session); refcount_inc_not_zero(&session->ref_count)) {
rcu_read_unlock_bh(); rcu_read_unlock_bh();
return session;
}
/* If we get here and session is non-NULL, the session_id
* collides with one in another tunnel. If sk is non-NULL,
* find the session matching sk.
*/
if (session && sk) {
unsigned long key = l2tp_v3_session_hashkey(sk, session->session_id);
hash_for_each_possible_rcu(pn->l2tp_v3_session_htable, session,
hlist, key) {
if (session->tunnel->sock == sk &&
refcount_inc_not_zero(&session->ref_count)) {
rcu_read_unlock_bh();
return session; return session;
} }
}
}
rcu_read_unlock_bh(); rcu_read_unlock_bh();
return NULL; return NULL;
} }
EXPORT_SYMBOL_GPL(l2tp_session_get); EXPORT_SYMBOL_GPL(l2tp_v3_session_get);
struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth) struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth)
{ {
...@@ -313,12 +322,12 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net, ...@@ -313,12 +322,12 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
const char *ifname) const char *ifname)
{ {
struct l2tp_net *pn = l2tp_pernet(net); struct l2tp_net *pn = l2tp_pernet(net);
int hash; unsigned long session_id, tmp;
struct l2tp_session *session; struct l2tp_session *session;
rcu_read_lock_bh(); rcu_read_lock_bh();
for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) { idr_for_each_entry_ul(&pn->l2tp_v3_session_idr, session, tmp, session_id) {
hlist_for_each_entry_rcu(session, &pn->l2tp_session_hlist[hash], global_hlist) { if (session) {
if (!strcmp(session->ifname, ifname)) { if (!strcmp(session->ifname, ifname)) {
l2tp_session_inc_refcount(session); l2tp_session_inc_refcount(session);
rcu_read_unlock_bh(); rcu_read_unlock_bh();
...@@ -334,13 +343,106 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net, ...@@ -334,13 +343,106 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
} }
EXPORT_SYMBOL_GPL(l2tp_session_get_by_ifname); EXPORT_SYMBOL_GPL(l2tp_session_get_by_ifname);
static void l2tp_session_coll_list_add(struct l2tp_session_coll_list *clist,
struct l2tp_session *session)
{
l2tp_session_inc_refcount(session);
WARN_ON_ONCE(session->coll_list);
session->coll_list = clist;
spin_lock(&clist->lock);
list_add(&session->clist, &clist->list);
spin_unlock(&clist->lock);
}
static int l2tp_session_collision_add(struct l2tp_net *pn,
struct l2tp_session *session1,
struct l2tp_session *session2)
{
struct l2tp_session_coll_list *clist;
lockdep_assert_held(&pn->l2tp_session_idr_lock);
if (!session2)
return -EEXIST;
/* If existing session is in IP-encap tunnel, refuse new session */
if (session2->tunnel->encap == L2TP_ENCAPTYPE_IP)
return -EEXIST;
clist = session2->coll_list;
if (!clist) {
/* First collision. Allocate list to manage the collided sessions
* and add the existing session to the list.
*/
clist = kmalloc(sizeof(*clist), GFP_ATOMIC);
if (!clist)
return -ENOMEM;
spin_lock_init(&clist->lock);
INIT_LIST_HEAD(&clist->list);
refcount_set(&clist->ref_count, 1);
l2tp_session_coll_list_add(clist, session2);
}
/* If existing session isn't already in the session hlist, add it. */
if (!hash_hashed(&session2->hlist))
hash_add(pn->l2tp_v3_session_htable, &session2->hlist,
session2->hlist_key);
/* Add new session to the hlist and collision list */
hash_add(pn->l2tp_v3_session_htable, &session1->hlist,
session1->hlist_key);
refcount_inc(&clist->ref_count);
l2tp_session_coll_list_add(clist, session1);
return 0;
}
static void l2tp_session_collision_del(struct l2tp_net *pn,
struct l2tp_session *session)
{
struct l2tp_session_coll_list *clist = session->coll_list;
unsigned long session_key = session->session_id;
struct l2tp_session *session2;
lockdep_assert_held(&pn->l2tp_session_idr_lock);
hash_del(&session->hlist);
if (clist) {
/* Remove session from its collision list. If there
* are other sessions with the same ID, replace this
* session's IDR entry with that session, otherwise
* remove the IDR entry. If this is the last session,
* the collision list data is freed.
*/
spin_lock(&clist->lock);
list_del_init(&session->clist);
session2 = list_first_entry_or_null(&clist->list, struct l2tp_session, clist);
if (session2) {
void *old = idr_replace(&pn->l2tp_v3_session_idr, session2, session_key);
WARN_ON_ONCE(IS_ERR_VALUE(old));
} else {
void *removed = idr_remove(&pn->l2tp_v3_session_idr, session_key);
WARN_ON_ONCE(removed != session);
}
session->coll_list = NULL;
spin_unlock(&clist->lock);
if (refcount_dec_and_test(&clist->ref_count))
kfree(clist);
l2tp_session_dec_refcount(session);
}
}
int l2tp_session_register(struct l2tp_session *session, int l2tp_session_register(struct l2tp_session *session,
struct l2tp_tunnel *tunnel) struct l2tp_tunnel *tunnel)
{ {
struct l2tp_net *pn = l2tp_pernet(tunnel->l2tp_net);
struct l2tp_session *session_walk; struct l2tp_session *session_walk;
struct hlist_head *g_head;
struct hlist_head *head; struct hlist_head *head;
struct l2tp_net *pn; u32 session_key;
int err; int err;
head = l2tp_session_id_hash(tunnel, session->session_id); head = l2tp_session_id_hash(tunnel, session->session_id);
...@@ -358,39 +460,45 @@ int l2tp_session_register(struct l2tp_session *session, ...@@ -358,39 +460,45 @@ int l2tp_session_register(struct l2tp_session *session,
} }
if (tunnel->version == L2TP_HDR_VER_3) { if (tunnel->version == L2TP_HDR_VER_3) {
pn = l2tp_pernet(tunnel->l2tp_net); session_key = session->session_id;
g_head = l2tp_session_id_hash_2(pn, session->session_id); spin_lock_bh(&pn->l2tp_session_idr_lock);
err = idr_alloc_u32(&pn->l2tp_v3_session_idr, NULL,
spin_lock_bh(&pn->l2tp_session_hlist_lock); &session_key, session_key, GFP_ATOMIC);
/* IP encap expects session IDs to be globally unique, while /* IP encap expects session IDs to be globally unique, while
* UDP encap doesn't. * UDP encap doesn't. This isn't per the RFC, which says that
* sessions are identified only by the session ID, but is to
* support existing userspace which depends on it.
*/ */
hlist_for_each_entry(session_walk, g_head, global_hlist) if (err == -ENOSPC && tunnel->encap == L2TP_ENCAPTYPE_UDP) {
if (session_walk->session_id == session->session_id && struct l2tp_session *session2;
(session_walk->tunnel->encap == L2TP_ENCAPTYPE_IP ||
tunnel->encap == L2TP_ENCAPTYPE_IP)) { session2 = idr_find(&pn->l2tp_v3_session_idr,
session_key);
err = l2tp_session_collision_add(pn, session, session2);
}
spin_unlock_bh(&pn->l2tp_session_idr_lock);
if (err == -ENOSPC)
err = -EEXIST; err = -EEXIST;
goto err_tlock_pnlock;
} }
l2tp_tunnel_inc_refcount(tunnel); if (err)
hlist_add_head_rcu(&session->global_hlist, g_head); goto err_tlock;
spin_unlock_bh(&pn->l2tp_session_hlist_lock);
} else {
l2tp_tunnel_inc_refcount(tunnel); l2tp_tunnel_inc_refcount(tunnel);
}
hlist_add_head_rcu(&session->hlist, head); hlist_add_head_rcu(&session->hlist, head);
spin_unlock_bh(&tunnel->hlist_lock); spin_unlock_bh(&tunnel->hlist_lock);
if (tunnel->version == L2TP_HDR_VER_3) {
spin_lock_bh(&pn->l2tp_session_idr_lock);
idr_replace(&pn->l2tp_v3_session_idr, session, session_key);
spin_unlock_bh(&pn->l2tp_session_idr_lock);
}
trace_register_session(session); trace_register_session(session);
return 0; return 0;
err_tlock_pnlock:
spin_unlock_bh(&pn->l2tp_session_hlist_lock);
err_tlock: err_tlock:
spin_unlock_bh(&tunnel->hlist_lock); spin_unlock_bh(&tunnel->hlist_lock);
...@@ -1218,13 +1326,19 @@ static void l2tp_session_unhash(struct l2tp_session *session) ...@@ -1218,13 +1326,19 @@ static void l2tp_session_unhash(struct l2tp_session *session)
hlist_del_init_rcu(&session->hlist); hlist_del_init_rcu(&session->hlist);
spin_unlock_bh(&tunnel->hlist_lock); spin_unlock_bh(&tunnel->hlist_lock);
/* For L2TPv3 we have a per-net hash: remove from there, too */ /* For L2TPv3 we have a per-net IDR: remove from there, too */
if (tunnel->version != L2TP_HDR_VER_2) { if (tunnel->version == L2TP_HDR_VER_3) {
struct l2tp_net *pn = l2tp_pernet(tunnel->l2tp_net); struct l2tp_net *pn = l2tp_pernet(tunnel->l2tp_net);
struct l2tp_session *removed = session;
spin_lock_bh(&pn->l2tp_session_hlist_lock); spin_lock_bh(&pn->l2tp_session_idr_lock);
hlist_del_init_rcu(&session->global_hlist); if (hash_hashed(&session->hlist))
spin_unlock_bh(&pn->l2tp_session_hlist_lock); l2tp_session_collision_del(pn, session);
else
removed = idr_remove(&pn->l2tp_v3_session_idr,
session->session_id);
WARN_ON_ONCE(removed && removed != session);
spin_unlock_bh(&pn->l2tp_session_idr_lock);
} }
synchronize_rcu(); synchronize_rcu();
...@@ -1649,8 +1763,9 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn ...@@ -1649,8 +1763,9 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
skb_queue_head_init(&session->reorder_q); skb_queue_head_init(&session->reorder_q);
session->hlist_key = l2tp_v3_session_hashkey(tunnel->sock, session->session_id);
INIT_HLIST_NODE(&session->hlist); INIT_HLIST_NODE(&session->hlist);
INIT_HLIST_NODE(&session->global_hlist); INIT_LIST_HEAD(&session->clist);
if (cfg) { if (cfg) {
session->pwtype = cfg->pw_type; session->pwtype = cfg->pw_type;
...@@ -1683,15 +1798,12 @@ EXPORT_SYMBOL_GPL(l2tp_session_create); ...@@ -1683,15 +1798,12 @@ EXPORT_SYMBOL_GPL(l2tp_session_create);
static __net_init int l2tp_init_net(struct net *net) static __net_init int l2tp_init_net(struct net *net)
{ {
struct l2tp_net *pn = net_generic(net, l2tp_net_id); struct l2tp_net *pn = net_generic(net, l2tp_net_id);
int hash;
idr_init(&pn->l2tp_tunnel_idr); idr_init(&pn->l2tp_tunnel_idr);
spin_lock_init(&pn->l2tp_tunnel_idr_lock); spin_lock_init(&pn->l2tp_tunnel_idr_lock);
for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) idr_init(&pn->l2tp_v3_session_idr);
INIT_HLIST_HEAD(&pn->l2tp_session_hlist[hash]); spin_lock_init(&pn->l2tp_session_idr_lock);
spin_lock_init(&pn->l2tp_session_hlist_lock);
return 0; return 0;
} }
...@@ -1701,7 +1813,6 @@ static __net_exit void l2tp_exit_net(struct net *net) ...@@ -1701,7 +1813,6 @@ static __net_exit void l2tp_exit_net(struct net *net)
struct l2tp_net *pn = l2tp_pernet(net); struct l2tp_net *pn = l2tp_pernet(net);
struct l2tp_tunnel *tunnel = NULL; struct l2tp_tunnel *tunnel = NULL;
unsigned long tunnel_id, tmp; unsigned long tunnel_id, tmp;
int hash;
rcu_read_lock_bh(); rcu_read_lock_bh();
idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) { idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) {
...@@ -1714,8 +1825,7 @@ static __net_exit void l2tp_exit_net(struct net *net) ...@@ -1714,8 +1825,7 @@ static __net_exit void l2tp_exit_net(struct net *net)
flush_workqueue(l2tp_wq); flush_workqueue(l2tp_wq);
rcu_barrier(); rcu_barrier();
for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) idr_destroy(&pn->l2tp_v3_session_idr);
WARN_ON_ONCE(!hlist_empty(&pn->l2tp_session_hlist[hash]));
idr_destroy(&pn->l2tp_tunnel_idr); idr_destroy(&pn->l2tp_tunnel_idr);
} }
......
...@@ -23,10 +23,6 @@ ...@@ -23,10 +23,6 @@
#define L2TP_HASH_BITS 4 #define L2TP_HASH_BITS 4
#define L2TP_HASH_SIZE BIT(L2TP_HASH_BITS) #define L2TP_HASH_SIZE BIT(L2TP_HASH_BITS)
/* System-wide session hash table size */
#define L2TP_HASH_BITS_2 8
#define L2TP_HASH_SIZE_2 BIT(L2TP_HASH_BITS_2)
struct sk_buff; struct sk_buff;
struct l2tp_stats { struct l2tp_stats {
...@@ -61,6 +57,12 @@ struct l2tp_session_cfg { ...@@ -61,6 +57,12 @@ struct l2tp_session_cfg {
char *ifname; char *ifname;
}; };
struct l2tp_session_coll_list {
spinlock_t lock; /* for access to list */
struct list_head list;
refcount_t ref_count;
};
/* Represents a session (pseudowire) instance. /* Represents a session (pseudowire) instance.
* Tracks runtime state including cookies, dataplane packet sequencing, and IO statistics. * Tracks runtime state including cookies, dataplane packet sequencing, and IO statistics.
* Is linked into a per-tunnel session hashlist; and in the case of an L2TPv3 session into * Is linked into a per-tunnel session hashlist; and in the case of an L2TPv3 session into
...@@ -88,8 +90,11 @@ struct l2tp_session { ...@@ -88,8 +90,11 @@ struct l2tp_session {
u32 nr_oos; /* NR of last OOS packet */ u32 nr_oos; /* NR of last OOS packet */
int nr_oos_count; /* for OOS recovery */ int nr_oos_count; /* for OOS recovery */
int nr_oos_count_max; int nr_oos_count_max;
struct hlist_node hlist; /* hash list node */
refcount_t ref_count; refcount_t ref_count;
struct hlist_node hlist; /* per-net session hlist */
unsigned long hlist_key; /* key for session hlist */
struct l2tp_session_coll_list *coll_list; /* session collision list */
struct list_head clist; /* for coll_list */
char name[L2TP_SESSION_NAME_MAX]; /* for logging */ char name[L2TP_SESSION_NAME_MAX]; /* for logging */
char ifname[IFNAMSIZ]; char ifname[IFNAMSIZ];
...@@ -102,7 +107,6 @@ struct l2tp_session { ...@@ -102,7 +107,6 @@ struct l2tp_session {
int reorder_skip; /* set if skip to next nr */ int reorder_skip; /* set if skip to next nr */
enum l2tp_pwtype pwtype; enum l2tp_pwtype pwtype;
struct l2tp_stats stats; struct l2tp_stats stats;
struct hlist_node global_hlist; /* global hash list node */
/* Session receive handler for data packets. /* Session receive handler for data packets.
* Each pseudowire implementation should implement this callback in order to * Each pseudowire implementation should implement this callback in order to
...@@ -226,7 +230,7 @@ struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth); ...@@ -226,7 +230,7 @@ struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth);
struct l2tp_session *l2tp_tunnel_get_session(struct l2tp_tunnel *tunnel, struct l2tp_session *l2tp_tunnel_get_session(struct l2tp_tunnel *tunnel,
u32 session_id); u32 session_id);
struct l2tp_session *l2tp_session_get(const struct net *net, u32 session_id); struct l2tp_session *l2tp_v3_session_get(const struct net *net, struct sock *sk, u32 session_id);
struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth); struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth);
struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net, struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
const char *ifname); const char *ifname);
......
...@@ -140,7 +140,7 @@ static int l2tp_ip_recv(struct sk_buff *skb) ...@@ -140,7 +140,7 @@ static int l2tp_ip_recv(struct sk_buff *skb)
} }
/* Ok, this is a data packet. Lookup the session. */ /* Ok, this is a data packet. Lookup the session. */
session = l2tp_session_get(net, session_id); session = l2tp_v3_session_get(net, NULL, session_id);
if (!session) if (!session)
goto discard; goto discard;
......
...@@ -150,7 +150,7 @@ static int l2tp_ip6_recv(struct sk_buff *skb) ...@@ -150,7 +150,7 @@ static int l2tp_ip6_recv(struct sk_buff *skb)
} }
/* Ok, this is a data packet. Lookup the session. */ /* Ok, this is a data packet. Lookup the session. */
session = l2tp_session_get(net, session_id); session = l2tp_v3_session_get(net, NULL, session_id);
if (!session) if (!session)
goto discard; goto discard;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment