Commit 4638de5a authored by Kishen Maloor's avatar Kishen Maloor Committed by David S. Miller

mptcp: handle local addrs announced by userspace PMs

This change adds an internal function to store/retrieve local
addrs announced by userspace PM implementations to/from its kernel
context. The function addresses the requirements of three scenarios:
1) ADD_ADDR announcements (which require that a local id be
provided), 2) retrieving the local id associated with an address,
and also where one may need to be assigned, and 3) reissuance of
ADD_ADDRs when there's a successful match of addr/id.

The list of all stored local addr entries is held under the
MPTCP sock structure. Memory for these entries is allocated from
the sock option buffer, so the list of addrs is bounded by optmem_max.
The list if not released via REMOVE_ADDR signals is ultimately
freed when the sock is destructed.
Acked-by: default avatarPaolo Abeni <pabeni@redhat.com>
Signed-off-by: default avatarKishen Maloor <kishen.maloor@intel.com>
Signed-off-by: default avatarMat Martineau <mathew.j.martineau@linux.intel.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent f43f0cd2
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
obj-$(CONFIG_MPTCP) += mptcp.o obj-$(CONFIG_MPTCP) += mptcp.o
mptcp-y := protocol.o subflow.o options.o token.o crypto.o ctrl.o pm.o diag.o \ mptcp-y := protocol.o subflow.o options.o token.o crypto.o ctrl.o pm.o diag.o \
mib.o pm_netlink.o sockopt.o mib.o pm_netlink.o sockopt.o pm_userspace.o
obj-$(CONFIG_SYN_COOKIES) += syncookies.o obj-$(CONFIG_SYN_COOKIES) += syncookies.o
obj-$(CONFIG_INET_MPTCP_DIAG) += mptcp_diag.o obj-$(CONFIG_INET_MPTCP_DIAG) += mptcp_diag.o
......
...@@ -469,6 +469,7 @@ void mptcp_pm_data_init(struct mptcp_sock *msk) ...@@ -469,6 +469,7 @@ void mptcp_pm_data_init(struct mptcp_sock *msk)
{ {
spin_lock_init(&msk->pm.lock); spin_lock_init(&msk->pm.lock);
INIT_LIST_HEAD(&msk->pm.anno_list); INIT_LIST_HEAD(&msk->pm.anno_list);
INIT_LIST_HEAD(&msk->pm.userspace_pm_local_addr_list);
mptcp_pm_data_reset(msk); mptcp_pm_data_reset(msk);
} }
......
...@@ -22,14 +22,6 @@ static struct genl_family mptcp_genl_family; ...@@ -22,14 +22,6 @@ static struct genl_family mptcp_genl_family;
static int pm_nl_pernet_id; static int pm_nl_pernet_id;
struct mptcp_pm_addr_entry {
struct list_head list;
struct mptcp_addr_info addr;
u8 flags;
int ifindex;
struct socket *lsk;
};
struct mptcp_pm_add_entry { struct mptcp_pm_add_entry {
struct list_head list; struct list_head list;
struct mptcp_addr_info addr; struct mptcp_addr_info addr;
...@@ -66,8 +58,8 @@ pm_nl_get_pernet_from_msk(const struct mptcp_sock *msk) ...@@ -66,8 +58,8 @@ pm_nl_get_pernet_from_msk(const struct mptcp_sock *msk)
return pm_nl_get_pernet(sock_net((struct sock *)msk)); return pm_nl_get_pernet(sock_net((struct sock *)msk));
} }
static bool addresses_equal(const struct mptcp_addr_info *a, bool mptcp_addresses_equal(const struct mptcp_addr_info *a,
const struct mptcp_addr_info *b, bool use_port) const struct mptcp_addr_info *b, bool use_port)
{ {
bool addr_equals = false; bool addr_equals = false;
...@@ -131,7 +123,7 @@ static bool lookup_subflow_by_saddr(const struct list_head *list, ...@@ -131,7 +123,7 @@ static bool lookup_subflow_by_saddr(const struct list_head *list,
skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow); skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
local_address(skc, &cur); local_address(skc, &cur);
if (addresses_equal(&cur, saddr, saddr->port)) if (mptcp_addresses_equal(&cur, saddr, saddr->port))
return true; return true;
} }
...@@ -149,7 +141,7 @@ static bool lookup_subflow_by_daddr(const struct list_head *list, ...@@ -149,7 +141,7 @@ static bool lookup_subflow_by_daddr(const struct list_head *list,
skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow); skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
remote_address(skc, &cur); remote_address(skc, &cur);
if (addresses_equal(&cur, daddr, daddr->port)) if (mptcp_addresses_equal(&cur, daddr, daddr->port))
return true; return true;
} }
...@@ -269,7 +261,7 @@ mptcp_lookup_anno_list_by_saddr(const struct mptcp_sock *msk, ...@@ -269,7 +261,7 @@ mptcp_lookup_anno_list_by_saddr(const struct mptcp_sock *msk,
lockdep_assert_held(&msk->pm.lock); lockdep_assert_held(&msk->pm.lock);
list_for_each_entry(entry, &msk->pm.anno_list, list) { list_for_each_entry(entry, &msk->pm.anno_list, list) {
if (addresses_equal(&entry->addr, addr, true)) if (mptcp_addresses_equal(&entry->addr, addr, true))
return entry; return entry;
} }
...@@ -286,7 +278,7 @@ bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk) ...@@ -286,7 +278,7 @@ bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk)
spin_lock_bh(&msk->pm.lock); spin_lock_bh(&msk->pm.lock);
list_for_each_entry(entry, &msk->pm.anno_list, list) { list_for_each_entry(entry, &msk->pm.anno_list, list) {
if (addresses_equal(&entry->addr, &saddr, true)) { if (mptcp_addresses_equal(&entry->addr, &saddr, true)) {
ret = true; ret = true;
goto out; goto out;
} }
...@@ -421,7 +413,7 @@ static bool lookup_address_in_vec(const struct mptcp_addr_info *addrs, unsigned ...@@ -421,7 +413,7 @@ static bool lookup_address_in_vec(const struct mptcp_addr_info *addrs, unsigned
int i; int i;
for (i = 0; i < nr; i++) { for (i = 0; i < nr; i++) {
if (addresses_equal(&addrs[i], addr, addr->port)) if (mptcp_addresses_equal(&addrs[i], addr, addr->port))
return true; return true;
} }
...@@ -457,7 +449,7 @@ static unsigned int fill_remote_addresses_vec(struct mptcp_sock *msk, bool fullm ...@@ -457,7 +449,7 @@ static unsigned int fill_remote_addresses_vec(struct mptcp_sock *msk, bool fullm
mptcp_for_each_subflow(msk, subflow) { mptcp_for_each_subflow(msk, subflow) {
ssk = mptcp_subflow_tcp_sock(subflow); ssk = mptcp_subflow_tcp_sock(subflow);
remote_address((struct sock_common *)ssk, &addrs[i]); remote_address((struct sock_common *)ssk, &addrs[i]);
if (deny_id0 && addresses_equal(&addrs[i], &remote, false)) if (deny_id0 && mptcp_addresses_equal(&addrs[i], &remote, false))
continue; continue;
if (!lookup_address_in_vec(addrs, i, &addrs[i]) && if (!lookup_address_in_vec(addrs, i, &addrs[i]) &&
...@@ -490,7 +482,7 @@ __lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info, ...@@ -490,7 +482,7 @@ __lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info,
struct mptcp_pm_addr_entry *entry; struct mptcp_pm_addr_entry *entry;
list_for_each_entry(entry, &pernet->local_addr_list, list) { list_for_each_entry(entry, &pernet->local_addr_list, list) {
if ((!lookup_by_id && addresses_equal(&entry->addr, info, true)) || if ((!lookup_by_id && mptcp_addresses_equal(&entry->addr, info, true)) ||
(lookup_by_id && entry->addr.id == info->id)) (lookup_by_id && entry->addr.id == info->id))
return entry; return entry;
} }
...@@ -505,7 +497,7 @@ lookup_id_by_addr(const struct pm_nl_pernet *pernet, const struct mptcp_addr_inf ...@@ -505,7 +497,7 @@ lookup_id_by_addr(const struct pm_nl_pernet *pernet, const struct mptcp_addr_inf
rcu_read_lock(); rcu_read_lock();
list_for_each_entry(entry, &pernet->local_addr_list, list) { list_for_each_entry(entry, &pernet->local_addr_list, list) {
if (addresses_equal(&entry->addr, addr, entry->addr.port)) { if (mptcp_addresses_equal(&entry->addr, addr, entry->addr.port)) {
ret = entry->addr.id; ret = entry->addr.id;
break; break;
} }
...@@ -739,7 +731,7 @@ static int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk, ...@@ -739,7 +731,7 @@ static int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
struct mptcp_addr_info local; struct mptcp_addr_info local;
local_address((struct sock_common *)ssk, &local); local_address((struct sock_common *)ssk, &local);
if (!addresses_equal(&local, addr, addr->port)) if (!mptcp_addresses_equal(&local, addr, addr->port))
continue; continue;
if (subflow->backup != bkup) if (subflow->backup != bkup)
...@@ -909,9 +901,9 @@ static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet, ...@@ -909,9 +901,9 @@ static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
* singled addresses * singled addresses
*/ */
list_for_each_entry(cur, &pernet->local_addr_list, list) { list_for_each_entry(cur, &pernet->local_addr_list, list) {
if (addresses_equal(&cur->addr, &entry->addr, if (mptcp_addresses_equal(&cur->addr, &entry->addr,
address_use_port(entry) && address_use_port(entry) &&
address_use_port(cur))) { address_use_port(cur))) {
/* allow replacing the exiting endpoint only if such /* allow replacing the exiting endpoint only if such
* endpoint is an implicit one and the user-space * endpoint is an implicit one and the user-space
* did not provide an endpoint id * did not provide an endpoint id
...@@ -1038,14 +1030,14 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) ...@@ -1038,14 +1030,14 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
*/ */
local_address((struct sock_common *)msk, &msk_local); local_address((struct sock_common *)msk, &msk_local);
local_address((struct sock_common *)skc, &skc_local); local_address((struct sock_common *)skc, &skc_local);
if (addresses_equal(&msk_local, &skc_local, false)) if (mptcp_addresses_equal(&msk_local, &skc_local, false))
return 0; return 0;
pernet = pm_nl_get_pernet_from_msk(msk); pernet = pm_nl_get_pernet_from_msk(msk);
rcu_read_lock(); rcu_read_lock();
list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) { list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
if (addresses_equal(&entry->addr, &skc_local, entry->addr.port)) { if (mptcp_addresses_equal(&entry->addr, &skc_local, entry->addr.port)) {
ret = entry->addr.id; ret = entry->addr.id;
break; break;
} }
...@@ -1416,7 +1408,7 @@ static int mptcp_nl_remove_id_zero_address(struct net *net, ...@@ -1416,7 +1408,7 @@ static int mptcp_nl_remove_id_zero_address(struct net *net,
goto next; goto next;
local_address((struct sock_common *)msk, &msk_local); local_address((struct sock_common *)msk, &msk_local);
if (!addresses_equal(&msk_local, addr, addr->port)) if (!mptcp_addresses_equal(&msk_local, addr, addr->port))
goto next; goto next;
lock_sock(sk); lock_sock(sk);
......
// SPDX-License-Identifier: GPL-2.0
/* Multipath TCP
*
* Copyright (c) 2022, Intel Corporation.
*/
#include "protocol.h"
void mptcp_free_local_addr_list(struct mptcp_sock *msk)
{
struct mptcp_pm_addr_entry *entry, *tmp;
struct sock *sk = (struct sock *)msk;
LIST_HEAD(free_list);
if (!mptcp_pm_is_userspace(msk))
return;
spin_lock_bh(&msk->pm.lock);
list_splice_init(&msk->pm.userspace_pm_local_addr_list, &free_list);
spin_unlock_bh(&msk->pm.lock);
list_for_each_entry_safe(entry, tmp, &free_list, list) {
sock_kfree_s(sk, entry, sizeof(*entry));
}
}
int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk,
struct mptcp_pm_addr_entry *entry)
{
DECLARE_BITMAP(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
struct mptcp_pm_addr_entry *match = NULL;
struct sock *sk = (struct sock *)msk;
struct mptcp_pm_addr_entry *e;
bool addr_match = false;
bool id_match = false;
int ret = -EINVAL;
bitmap_zero(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
spin_lock_bh(&msk->pm.lock);
list_for_each_entry(e, &msk->pm.userspace_pm_local_addr_list, list) {
addr_match = mptcp_addresses_equal(&e->addr, &entry->addr, true);
if (addr_match && entry->addr.id == 0)
entry->addr.id = e->addr.id;
id_match = (e->addr.id == entry->addr.id);
if (addr_match && id_match) {
match = e;
break;
} else if (addr_match || id_match) {
break;
}
__set_bit(e->addr.id, id_bitmap);
}
if (!match && !addr_match && !id_match) {
/* Memory for the entry is allocated from the
* sock option buffer.
*/
e = sock_kmalloc(sk, sizeof(*e), GFP_ATOMIC);
if (!e) {
spin_unlock_bh(&msk->pm.lock);
return -ENOMEM;
}
*e = *entry;
if (!e->addr.id)
e->addr.id = find_next_zero_bit(id_bitmap,
MPTCP_PM_MAX_ADDR_ID + 1,
1);
list_add_tail_rcu(&e->list, &msk->pm.userspace_pm_local_addr_list);
ret = e->addr.id;
} else if (match) {
ret = entry->addr.id;
}
spin_unlock_bh(&msk->pm.lock);
return ret;
}
...@@ -3097,6 +3097,7 @@ void mptcp_destroy_common(struct mptcp_sock *msk) ...@@ -3097,6 +3097,7 @@ void mptcp_destroy_common(struct mptcp_sock *msk)
msk->rmem_fwd_alloc = 0; msk->rmem_fwd_alloc = 0;
mptcp_token_destroy(msk); mptcp_token_destroy(msk);
mptcp_pm_free_anno_list(msk); mptcp_pm_free_anno_list(msk);
mptcp_free_local_addr_list(msk);
} }
static void mptcp_destroy(struct sock *sk) static void mptcp_destroy(struct sock *sk)
......
...@@ -208,6 +208,7 @@ struct mptcp_pm_data { ...@@ -208,6 +208,7 @@ struct mptcp_pm_data {
struct mptcp_addr_info local; struct mptcp_addr_info local;
struct mptcp_addr_info remote; struct mptcp_addr_info remote;
struct list_head anno_list; struct list_head anno_list;
struct list_head userspace_pm_local_addr_list;
spinlock_t lock; /*protects the whole PM data */ spinlock_t lock; /*protects the whole PM data */
...@@ -228,6 +229,14 @@ struct mptcp_pm_data { ...@@ -228,6 +229,14 @@ struct mptcp_pm_data {
struct mptcp_rm_list rm_list_rx; struct mptcp_rm_list rm_list_rx;
}; };
struct mptcp_pm_addr_entry {
struct list_head list;
struct mptcp_addr_info addr;
u8 flags;
int ifindex;
struct socket *lsk;
};
struct mptcp_data_frag { struct mptcp_data_frag {
struct list_head list; struct list_head list;
u64 data_seq; u64 data_seq;
...@@ -601,6 +610,9 @@ void mptcp_subflow_reset(struct sock *ssk); ...@@ -601,6 +610,9 @@ void mptcp_subflow_reset(struct sock *ssk);
void mptcp_sock_graft(struct sock *sk, struct socket *parent); void mptcp_sock_graft(struct sock *sk, struct socket *parent);
struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk); struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk);
bool mptcp_addresses_equal(const struct mptcp_addr_info *a,
const struct mptcp_addr_info *b, bool use_port);
/* called with sk socket lock held */ /* called with sk socket lock held */
int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_addr_info *loc, int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_addr_info *loc,
const struct mptcp_addr_info *remote); const struct mptcp_addr_info *remote);
...@@ -779,6 +791,9 @@ int mptcp_pm_announce_addr(struct mptcp_sock *msk, ...@@ -779,6 +791,9 @@ int mptcp_pm_announce_addr(struct mptcp_sock *msk,
bool echo); bool echo);
int mptcp_pm_remove_addr(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list); int mptcp_pm_remove_addr(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list);
int mptcp_pm_remove_subflow(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list); int mptcp_pm_remove_subflow(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list);
int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk,
struct mptcp_pm_addr_entry *entry);
void mptcp_free_local_addr_list(struct mptcp_sock *msk);
void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk, void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk,
const struct sock *ssk, gfp_t gfp); const struct sock *ssk, gfp_t gfp);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment