Commit 034508c4 authored by David S. Miller

Merge branch 'mptcp-introduce-msk-diag-interface'

Paolo Abeni says:

====================
mptcp: introduce msk diag interface

This series implements the diag interface for the MPTCP sockets.

Since the MPTCP protocol value can't be represented with the
current diag uAPI, the first patch introduces an extended attribute
allowing user-space to specify larger protocol values.

The token APIs are then extended to allow traversing the
whole token container.

Patch 3 carries the actual diag interface implementation, and
later patches bring in some functional self-tests.
====================
Signed-off-by: David S. Miller <davem@davemloft.net>
parents 5ca670e5 df62f2ec
......@@ -65,6 +65,7 @@ enum {
INET_DIAG_REQ_NONE,
INET_DIAG_REQ_BYTECODE,
INET_DIAG_REQ_SK_BPF_STORAGES,
INET_DIAG_REQ_PROTOCOL,
__INET_DIAG_REQ_MAX,
};
......
......@@ -86,4 +86,21 @@ enum {
__MPTCP_PM_CMD_AFTER_LAST
};
#define MPTCP_INFO_FLAG_FALLBACK _BITUL(0)
#define MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED _BITUL(1)
/* MPTCP socket information exported to user space.
 *
 * The per-field notes name the msk member each value is copied from by
 * mptcp_diag_get_info().
 * NOTE(review): the six __u8 members are followed by a __u32, so the compiler
 * presumably inserts 2 bytes of padding there - confirm before extending.
 */
struct mptcp_info {
__u8 mptcpi_subflows; /* msk->pm.subflows */
__u8 mptcpi_add_addr_signal; /* msk->pm.add_addr_signaled */
__u8 mptcpi_add_addr_accepted; /* msk->pm.add_addr_accepted */
__u8 mptcpi_subflows_max; /* msk->pm.subflows_max */
__u8 mptcpi_add_addr_signal_max; /* msk->pm.add_addr_signal_max */
__u8 mptcpi_add_addr_accepted_max; /* msk->pm.add_addr_accept_max */
__u32 mptcpi_flags; /* bitmask of MPTCP_INFO_FLAG_* */
__u32 mptcpi_token; /* msk->token */
__u64 mptcpi_write_seq; /* msk->write_seq */
__u64 mptcpi_snd_una; /* msk->snd_una */
__u64 mptcpi_rcv_nxt; /* msk->ack_seq */
};
#endif /* _UAPI_MPTCP_H */
......@@ -3566,6 +3566,7 @@ int sock_load_diag_module(int family, int protocol)
#ifdef CONFIG_INET
if (family == AF_INET &&
protocol != IPPROTO_RAW &&
protocol < MAX_INET_PROTOS &&
!rcu_access_pointer(inet_protos[protocol]))
return -ENOENT;
#endif
......
......@@ -52,6 +52,11 @@ static DEFINE_MUTEX(inet_diag_table_mutex);
static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
{
if (proto < 0 || proto >= IPPROTO_MAX) {
mutex_lock(&inet_diag_table_mutex);
return ERR_PTR(-ENOENT);
}
if (!inet_diag_table[proto])
sock_load_diag_module(AF_INET, proto);
......@@ -181,6 +186,28 @@ int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
}
EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);
/* Collect the netlink attributes that follow the fixed-size request header.
 *
 * @nlh: netlink request message
 * @hdrlen: length of the family header preceding the attributes
 * @req_nlas: table indexed by attribute type, holding at least
 *            __INET_DIAG_REQ_MAX entries; callers pass it zeroed
 *
 * Attributes of a known type are recorded in @req_nlas, unknown types are
 * skipped. An INET_DIAG_REQ_PROTOCOL attribute whose payload is not exactly
 * a u32 is ignored: it is later read with nla_get_u32(), which would
 * otherwise read past a shorter payload. The request then falls back to
 * the protocol stored in the request header.
 */
static void inet_diag_parse_attrs(const struct nlmsghdr *nlh, int hdrlen,
				  struct nlattr **req_nlas)
{
	struct nlattr *nla;
	int remaining;

	nlmsg_for_each_attr(nla, nlh, hdrlen, remaining) {
		int type = nla_type(nla);

		if (type >= __INET_DIAG_REQ_MAX)
			continue;

		/* never hand a malformed protocol attribute to the readers */
		if (type == INET_DIAG_REQ_PROTOCOL &&
		    nla_len(nla) != sizeof(u32))
			continue;

		req_nlas[type] = nla;
	}
}
/* Return the protocol a diag request refers to.
 *
 * The INET_DIAG_REQ_PROTOCOL attribute, when present in @data, takes
 * precedence over the sdiag_protocol field of the request header.
 */
static int inet_diag_get_protocol(const struct inet_diag_req_v2 *req,
				  const struct inet_diag_dump_data *data)
{
	const struct nlattr *proto_nla;

	proto_nla = data->req_nlas[INET_DIAG_REQ_PROTOCOL];

	return proto_nla ? nla_get_u32(proto_nla) : req->sdiag_protocol;
}
#define MAX_DUMP_ALLOC_SIZE (KMALLOC_MAX_SIZE - SKB_DATA_ALIGN(sizeof(struct skb_shared_info)))
int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
......@@ -198,7 +225,7 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
void *info = NULL;
cb_data = cb->data;
handler = inet_diag_table[req->sdiag_protocol];
handler = inet_diag_table[inet_diag_get_protocol(req, cb_data)];
BUG_ON(!handler);
nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
......@@ -539,20 +566,25 @@ EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);
static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
const struct nlmsghdr *nlh,
int hdrlen,
const struct inet_diag_req_v2 *req)
{
const struct inet_diag_handler *handler;
int err;
struct inet_diag_dump_data dump_data;
int err, protocol;
handler = inet_diag_lock_handler(req->sdiag_protocol);
memset(&dump_data, 0, sizeof(dump_data));
inet_diag_parse_attrs(nlh, hdrlen, dump_data.req_nlas);
protocol = inet_diag_get_protocol(req, &dump_data);
handler = inet_diag_lock_handler(protocol);
if (IS_ERR(handler)) {
err = PTR_ERR(handler);
} else if (cmd == SOCK_DIAG_BY_FAMILY) {
struct inet_diag_dump_data empty_dump_data = {};
struct netlink_callback cb = {
.nlh = nlh,
.skb = in_skb,
.data = &empty_dump_data,
.data = &dump_data,
};
err = handler->dump_one(&cb, req);
} else if (cmd == SOCK_DESTROY && handler->destroy) {
......@@ -1103,13 +1135,16 @@ EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
const struct inet_diag_req_v2 *r)
{
struct inet_diag_dump_data *cb_data = cb->data;
const struct inet_diag_handler *handler;
u32 prev_min_dump_alloc;
int err = 0;
int protocol, err = 0;
protocol = inet_diag_get_protocol(r, cb_data);
again:
prev_min_dump_alloc = cb->min_dump_alloc;
handler = inet_diag_lock_handler(r->sdiag_protocol);
handler = inet_diag_lock_handler(protocol);
if (!IS_ERR(handler))
handler->dump(skb, cb, r);
else
......@@ -1139,19 +1174,13 @@ static int __inet_diag_dump_start(struct netlink_callback *cb, int hdrlen)
struct inet_diag_dump_data *cb_data;
struct sk_buff *skb = cb->skb;
struct nlattr *nla;
int rem, err;
int err;
cb_data = kzalloc(sizeof(*cb_data), GFP_KERNEL);
if (!cb_data)
return -ENOMEM;
nla_for_each_attr(nla, nlmsg_attrdata(nlh, hdrlen),
nlmsg_attrlen(nlh, hdrlen), rem) {
int type = nla_type(nla);
if (type < __INET_DIAG_REQ_MAX)
cb_data->req_nlas[type] = nla;
}
inet_diag_parse_attrs(nlh, hdrlen, cb_data->req_nlas);
nla = cb_data->inet_diag_nla_bc;
if (nla) {
......@@ -1237,7 +1266,8 @@ static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
req.idiag_states = rc->idiag_states;
req.id = rc->id;
return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh, &req);
return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh,
sizeof(struct inet_diag_req), &req);
}
static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
......@@ -1279,7 +1309,8 @@ static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
return netlink_dump_start(net->diag_nlsk, skb, h, &c);
}
return inet_diag_cmd_exact(h->nlmsg_type, skb, h, nlmsg_data(h));
return inet_diag_cmd_exact(h->nlmsg_type, skb, h, hdrlen,
nlmsg_data(h));
}
static
......
......@@ -13,6 +13,10 @@ config MPTCP
if MPTCP
config INET_MPTCP_DIAG
depends on INET_DIAG
def_tristate INET_DIAG
config MPTCP_IPV6
bool "MPTCP: IPv6 support for Multipath TCP"
select IPV6
......
......@@ -4,6 +4,8 @@ obj-$(CONFIG_MPTCP) += mptcp.o
mptcp-y := protocol.o subflow.o options.o token.o crypto.o ctrl.o pm.o diag.o \
mib.o pm_netlink.o
obj-$(CONFIG_INET_MPTCP_DIAG) += mptcp_diag.o
mptcp_crypto_test-objs := crypto_test.o
mptcp_token_test-objs := token_test.o
obj-$(CONFIG_MPTCP_KUNIT_TESTS) += mptcp_crypto_test.o mptcp_token_test.o
// SPDX-License-Identifier: GPL-2.0
/* MPTCP socket monitoring support
*
* Copyright (c) 2020 Red Hat
*
* Author: Paolo Abeni <pabeni@redhat.com>
*/
#include <linux/kernel.h>
#include <linux/net.h>
#include <linux/string.h>
#include <linux/inet_diag.h>
#include <net/netlink.h>
#include <uapi/linux/mptcp.h>
#include "protocol.h"
/* Dump @sk into @skb as one entry of a multi-part reply.
 *
 * Sockets rejected by the bytecode filter @bc are silently skipped and
 * reported as success (0); otherwise the result of inet_sk_diag_fill()
 * is returned.
 */
static int sk_diag_dump(struct sock *sk, struct sk_buff *skb,
			struct netlink_callback *cb,
			const struct inet_diag_req_v2 *req,
			struct nlattr *bc, bool net_admin)
{
	if (inet_diag_bc_sk(bc, sk))
		return inet_sk_diag_fill(sk, inet_csk(sk), skb, cb, req,
					 NLM_F_MULTI, net_admin);

	return 0;
}
/* Answer an exact-match request for a single MPTCP socket.
 *
 * The socket is looked up by token, taken from req->id.idiag_cookie[0].
 * On success the filled message is unicast to the requester and 0 is
 * returned. Returns -ENOENT when no socket matches in the requester's
 * network namespace, -ENOMEM when the reply cannot be allocated, or the
 * error reported by inet_sk_diag_fill() / netlink_unicast().
 */
static int mptcp_diag_dump_one(struct netlink_callback *cb,
			       const struct inet_diag_req_v2 *req)
{
	struct sk_buff *in_skb = cb->skb;
	struct net *net = sock_net(in_skb->sk);
	struct mptcp_sock *msk;
	struct sk_buff *rep;
	struct sock *sk;
	int err;

	/* a successful lookup returns a reference, dropped at 'out' */
	msk = mptcp_token_get_sock(req->id.idiag_cookie[0]);
	if (!msk)
		return -ENOENT;

	sk = (struct sock *)msk;

	/* mptcp_token_get_sock() takes no namespace argument: make sure
	 * sockets living in another netns are never exposed to the requester
	 */
	if (!net_eq(sock_net(sk), net)) {
		err = -ENOENT;
		goto out;
	}

	rep = nlmsg_new(nla_total_size(sizeof(struct inet_diag_msg)) +
			inet_diag_msg_attrs_size() +
			nla_total_size(sizeof(struct mptcp_info)) +
			nla_total_size(sizeof(struct inet_diag_meminfo)) + 64,
			GFP_KERNEL);
	if (!rep) {
		err = -ENOMEM;
		goto out;
	}

	err = inet_sk_diag_fill(sk, inet_csk(sk), rep, cb, req, 0,
				netlink_net_capable(in_skb, CAP_NET_ADMIN));
	if (err < 0) {
		/* the reply was sized above, running out of room is a bug */
		WARN_ON(err == -EMSGSIZE);
		kfree_skb(rep);
		goto out;
	}

	err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid,
			      MSG_DONTWAIT);
	/* positive values are not errors: report plain success */
	if (err > 0)
		err = 0;

out:
	sock_put(sk);
	return err;
}
/* Dump every MPTCP socket of the requester's netns matching the request.
 *
 * The position inside the token container is kept in cb->args[0] (slot)
 * and cb->args[1] (entry number), so that a dump interrupted by a full
 * @skb resumes from the socket that did not fit.
 */
static void mptcp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
			    const struct inet_diag_req_v2 *r)
{
	bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN);
	struct inet_diag_dump_data *cb_data = cb->data;
	struct nlattr *bc = cb_data->inet_diag_nla_bc;
	struct net *net = sock_net(skb->sk);
	struct mptcp_sock *msk;

	for (;;) {
		struct inet_sock *inet;
		struct sock *sk;
		bool match;
		int ret = 0;

		msk = mptcp_token_iter_next(net, &cb->args[0], &cb->args[1]);
		if (!msk)
			break;

		inet = (struct inet_sock *)msk;
		sk = (struct sock *)msk;

		/* state, family and (optional) port filters of the request */
		match = (r->idiag_states & (1 << sk->sk_state)) &&
			(r->sdiag_family == AF_UNSPEC ||
			 sk->sk_family == r->sdiag_family) &&
			(!r->id.idiag_sport ||
			 r->id.idiag_sport == inet->inet_sport) &&
			(!r->id.idiag_dport ||
			 r->id.idiag_dport == inet->inet_dport);
		if (match)
			ret = sk_diag_dump(sk, skb, cb, r, bc, net_admin);

		/* drop the reference acquired by the iterator */
		sock_put(sk);
		if (ret < 0) {
			/* will retry on the same position */
			cb->args[1]--;
			break;
		}
		cond_resched();
	}
}
static void mptcp_diag_get_info(struct sock *sk, struct inet_diag_msg *r,
void *_info)
{
struct mptcp_sock *msk = mptcp_sk(sk);
struct mptcp_info *info = _info;
u32 flags = 0;
bool slow;
u8 val;
r->idiag_rqueue = sk_rmem_alloc_get(sk);
r->idiag_wqueue = sk_wmem_alloc_get(sk);
if (!info)
return;
slow = lock_sock_fast(sk);
info->mptcpi_subflows = READ_ONCE(msk->pm.subflows);
info->mptcpi_add_addr_signal = READ_ONCE(msk->pm.add_addr_signaled);
info->mptcpi_add_addr_accepted = READ_ONCE(msk->pm.add_addr_accepted);
info->mptcpi_subflows_max = READ_ONCE(msk->pm.subflows_max);
val = READ_ONCE(msk->pm.add_addr_signal_max);
info->mptcpi_add_addr_signal_max = val;
val = READ_ONCE(msk->pm.add_addr_accept_max);
info->mptcpi_add_addr_accepted_max = val;
if (test_bit(MPTCP_FALLBACK_DONE, &msk->flags))
flags |= MPTCP_INFO_FLAG_FALLBACK;
if (READ_ONCE(msk->can_ack))
flags |= MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED;
info->mptcpi_flags = flags;
info->mptcpi_token = READ_ONCE(msk->token);
info->mptcpi_write_seq = READ_ONCE(msk->write_seq);
info->mptcpi_snd_una = atomic64_read(&msk->snd_una);
info->mptcpi_rcv_nxt = READ_ONCE(msk->ack_seq);
unlock_sock_fast(sk, slow);
}
/* inet_diag handler for IPPROTO_MPTCP sockets */
static const struct inet_diag_handler mptcp_diag_handler = {
	.idiag_type	 = IPPROTO_MPTCP,
	.idiag_info_size = sizeof(struct mptcp_info),
	.idiag_get_info	 = mptcp_diag_get_info,
	.dump_one	 = mptcp_diag_dump_one,
	.dump		 = mptcp_diag_dump,
};
/* Module init: register the MPTCP handler with inet_diag and return the
 * result of the registration.
 */
static int __init mptcp_diag_init(void)
{
return inet_diag_register(&mptcp_diag_handler);
}
/* Module exit: unregister the MPTCP handler from inet_diag. */
static void __exit mptcp_diag_exit(void)
{
inet_diag_unregister(&mptcp_diag_handler);
}
module_init(mptcp_diag_init);
module_exit(mptcp_diag_exit);
MODULE_LICENSE("GPL");
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2-262 /* AF_INET - IPPROTO_MPTCP */);
......@@ -391,6 +391,8 @@ int mptcp_token_new_connect(struct sock *sk);
void mptcp_token_accept(struct mptcp_subflow_request_sock *r,
struct mptcp_sock *msk);
struct mptcp_sock *mptcp_token_get_sock(u32 token);
struct mptcp_sock *mptcp_token_iter_next(const struct net *net, long *s_slot,
long *s_num);
void mptcp_token_destroy(struct mptcp_sock *msk);
void mptcp_crypto_key_sha(u64 key, u32 *token, u64 *idsn);
......
......@@ -238,6 +238,66 @@ struct mptcp_sock *mptcp_token_get_sock(u32 token)
rcu_read_unlock();
return msk;
}
EXPORT_SYMBOL_GPL(mptcp_token_get_sock);
/**
 * mptcp_token_iter_next - iterate over the token container from given pos
 * @net: namespace to be iterated
 * @s_slot: start slot number
 * @s_num: number of entries already visited inside the start slot
 *
 * This function returns the first mptcp connection structure found inside the
 * token container starting from the specified position, or NULL.
 *
 * On successful iteration, the iterator is moved to the next position and
 * a reference to the returned socket is acquired; the caller must release
 * it with sock_put().
 */
struct mptcp_sock *mptcp_token_iter_next(const struct net *net, long *s_slot,
					 long *s_num)
{
	struct mptcp_sock *ret = NULL;
	struct hlist_nulls_node *pos;
	/* num must be initialized: when *s_slot is already past the last
	 * slot the loop body never runs and num is stored into *s_num below
	 */
	int slot, num = 0;

	for (slot = *s_slot; slot <= token_mask; *s_num = 0, slot++) {
		struct token_bucket *bucket = &token_hash[slot];
		struct sock *sk;

		num = 0;

		if (hlist_nulls_empty(&bucket->msk_chain))
			continue;

		rcu_read_lock();
		sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
			/* every chain entry is counted, whatever its netns */
			++num;
			if (!net_eq(sock_net(sk), net))
				continue;

			if (num <= *s_num)
				continue;

			if (!refcount_inc_not_zero(&sk->sk_refcnt))
				continue;

			/* check the netns again once the reference is held */
			if (!net_eq(sock_net(sk), net)) {
				sock_put(sk);
				continue;
			}

			ret = mptcp_sk(sk);
			rcu_read_unlock();
			goto out;
		}
		rcu_read_unlock();
	}

out:
	*s_slot = slot;
	*s_num = num;
	return ret;
}
EXPORT_SYMBOL_GPL(mptcp_token_iter_next);
/**
* mptcp_token_destroy_request - remove mptcp connection/token
......@@ -312,7 +372,6 @@ void __init mptcp_token_init(void)
EXPORT_SYMBOL_GPL(mptcp_token_new_request);
EXPORT_SYMBOL_GPL(mptcp_token_new_connect);
EXPORT_SYMBOL_GPL(mptcp_token_accept);
EXPORT_SYMBOL_GPL(mptcp_token_get_sock);
EXPORT_SYMBOL_GPL(mptcp_token_destroy_request);
EXPORT_SYMBOL_GPL(mptcp_token_destroy);
#endif
......@@ -5,7 +5,7 @@ KSFT_KHDR_INSTALL := 1
CFLAGS = -Wall -Wl,--no-as-needed -O2 -g -I$(top_srcdir)/usr/include
TEST_PROGS := mptcp_connect.sh pm_netlink.sh mptcp_join.sh
TEST_PROGS := mptcp_connect.sh pm_netlink.sh mptcp_join.sh diag.sh
TEST_GEN_FILES = mptcp_connect pm_nl_ctl
......
#!/bin/bash
# SPDX-License-Identifier: GPL-2.0
# The netns name embeds the current time (in hex) plus a random suffix.
# 'sec' must be set before it is expanded: with an empty value printf
# would always emit "0".
sec=$(date +%s)
rndh=$(printf %x $sec)-$(mktemp -u XXXXXX)
ns="ns1-$rndh"
# exit code used when a required tool is missing
ksft_skip=4
# number of the next check; 'ret' records the number of the last failed one
test_cnt=1
ret=0
# background mptcp_connect processes to be signalled/killed
pids=()
# Send SIGUSR1 to every recorded background process that is still alive,
# then forget all the recorded pids.
flush_pids()
{
# mptcp_connect in join mode will sleep a bit before completing,
# give it some time
sleep 1.1
for pid in ${pids[@]}; do
[ -d /proc/$pid ] && kill -SIGUSR1 $pid >/dev/null 2>&1
done
pids=()
}
# EXIT trap handler: delete the test netns and SIGKILL any recorded
# background process that is still alive.
cleanup()
{
ip netns del $ns
for pid in ${pids[@]}; do
[ -d /proc/$pid ] && kill -9 $pid >/dev/null 2>&1
done
}
# Skip the whole test when the 'ip' tool cannot be executed.
ip -Version > /dev/null 2>&1
if [ $? -ne 0 ];then
echo "SKIP: Could not run test without ip tool"
exit $ksft_skip
fi
# Skip as well when the help text of 'ss' does not mention MPTCP.
ss -h | grep -q MPTCP
if [ $? -ne 0 ];then
echo "SKIP: ss tool does not support MPTCP"
exit $ksft_skip
fi
# Run 'ss' inside the test netns, pipe its output through a filter command
# and compare the filter output with the expected value.
#  $1: filter command (word-split on purpose, e.g. "grep -c token:")
#  $2: expected filter output
#  $3...: description printed before the result
# On mismatch 'ret' is set to the number of the current check; 'test_cnt'
# is always incremented.
__chk_nr()
{
	local condition="$1"
	local expected=$2
	local msg nr

	shift 2
	msg=$*
	nr=$(ss -inmHMN $ns | $condition)

	printf "%-50s" "$msg"
	# both operands are quoted: an empty value must count as a mismatch
	# instead of a '[' syntax error, which would fall into the "ok" branch
	if [ "$nr" != "$expected" ]; then
		echo "[ fail ] expected $expected found $nr"
		ret=$test_cnt
	else
		echo "[ ok ]"
	fi

	test_cnt=$((test_cnt+1))
}
# chk_msk_nr <expected> <description>: check the number of 'ss' output
# lines containing "token:".
chk_msk_nr()
{
__chk_nr "grep -c token:" $*
}
# chk_msk_fallback_nr <expected> <description>: check the number of 'ss'
# output lines containing "fallback".
chk_msk_fallback_nr()
{
__chk_nr "grep -c fallback" $*
}
# chk_msk_remote_key_nr <expected> <description>: check the number of 'ss'
# output lines containing "remote_key".
chk_msk_remote_key_nr()
{
__chk_nr "grep -c remote_key" $*
}
# Create the test netns; it is removed by the EXIT trap.
trap cleanup EXIT
ip netns add $ns
ip -n $ns link set dev lo up
# A listener alone: no "token:" line is expected in the 'ss' output.
echo "a" | ip netns exec $ns ./mptcp_connect -p 10000 -l 0.0.0.0 -t 100 >/dev/null &
sleep 0.1
pids[0]=$!
chk_msk_nr 0 "no msk on netns creation"
# Connect to the listener: two sockets are expected, both reporting
# remote_key and none reporting fallback.
echo "b" | ip netns exec $ns ./mptcp_connect -p 10000 127.0.0.1 -j -t 100 >/dev/null &
sleep 0.1
pids[1]=$!
chk_msk_nr 2 "after MPC handshake "
chk_msk_remote_key_nr 2 "....chk remote_key"
chk_msk_fallback_nr 0 "....chk no fallback"
flush_pids
# The listener is started with '-s TCP', the connector is not: exactly one
# socket reporting fallback is expected.
echo "a" | ip netns exec $ns ./mptcp_connect -p 10001 -s TCP -l 0.0.0.0 -t 100 >/dev/null &
pids[0]=$!
sleep 0.1
echo "b" | ip netns exec $ns ./mptcp_connect -p 10001 127.0.0.1 -j -t 100 >/dev/null &
pids[1]=$!
sleep 0.1
chk_msk_fallback_nr 1 "check fallback"
flush_pids
# Start NR_CLIENTS listener/connector pairs, each on its own port and told
# to wait 10 seconds before closing ('-w 10'); two sockets per pair are
# expected in the dump.
NR_CLIENTS=100
for I in `seq 1 $NR_CLIENTS`; do
echo "a" | ip netns exec $ns ./mptcp_connect -p $((I+10001)) -l 0.0.0.0 -t 100 -w 10 >/dev/null &
pids[$((I*2))]=$!
done
sleep 0.1
for I in `seq 1 $NR_CLIENTS`; do
echo "b" | ip netns exec $ns ./mptcp_connect -p $((I+10001)) 127.0.0.1 -t 100 -w 10 >/dev/null &
pids[$((I*2 + 1))]=$!
done
sleep 1.5
chk_msk_nr $((NR_CLIENTS*2)) "many msk socket present"
flush_pids
# 0 on success, otherwise the number of the last failed check
exit $ret
......@@ -11,6 +11,7 @@
#include <stdio.h>
#include <stdlib.h>
#include <strings.h>
#include <signal.h>
#include <unistd.h>
#include <sys/poll.h>
......@@ -36,6 +37,7 @@ extern int optind;
static int poll_timeout = 10 * 1000;
static bool listen_mode;
static bool quit;
enum cfg_mode {
CFG_MODE_POLL,
......@@ -52,11 +54,12 @@ static int pf = AF_INET;
static int cfg_sndbuf;
static int cfg_rcvbuf;
static bool cfg_join;
static int cfg_wait;
static void die_usage(void)
{
fprintf(stderr, "Usage: mptcp_connect [-6] [-u] [-s MPTCP|TCP] [-p port] [-m mode]"
"[-l] connect_address\n");
"[-l] [-w sec] connect_address\n");
fprintf(stderr, "\t-6 use ipv6\n");
fprintf(stderr, "\t-t num -- set poll timeout to num\n");
fprintf(stderr, "\t-S num -- set SO_SNDBUF to num\n");
......@@ -65,9 +68,15 @@ static void die_usage(void)
fprintf(stderr, "\t-m [MPTCP|TCP] -- use tcp or mptcp sockets\n");
fprintf(stderr, "\t-s [mmap|poll] -- use poll (default) or mmap\n");
fprintf(stderr, "\t-u -- check mptcp ulp\n");
fprintf(stderr, "\t-w num -- wait num sec before closing the socket\n");
exit(1);
}
/* Signal handler, installed for SIGUSR1 in main(): it only records the
 * request by setting the 'quit' flag; the signal number is not used.
 */
static void handle_signal(int nr)
{
quit = true;
}
static const char *getxinfo_strerr(int err)
{
if (err == EAI_SYSTEM)
......@@ -418,8 +427,8 @@ static int copyfd_io_poll(int infd, int peerfd, int outfd)
}
/* leave some time for late join/announce */
if (cfg_join)
usleep(400000);
if (cfg_wait)
usleep(cfg_wait);
close(peerfd);
return 0;
......@@ -812,11 +821,12 @@ static void parse_opts(int argc, char **argv)
{
int c;
while ((c = getopt(argc, argv, "6jlp:s:hut:m:S:R:")) != -1) {
while ((c = getopt(argc, argv, "6jlp:s:hut:m:S:R:w:")) != -1) {
switch (c) {
case 'j':
cfg_join = true;
cfg_mode = CFG_MODE_POLL;
cfg_wait = 400000;
break;
case 'l':
listen_mode = true;
......@@ -850,6 +860,9 @@ static void parse_opts(int argc, char **argv)
case 'R':
cfg_rcvbuf = parse_int(optarg);
break;
case 'w':
cfg_wait = atoi(optarg)*1000000;
break;
}
}
......@@ -865,6 +878,7 @@ int main(int argc, char *argv[])
{
init_rng();
signal(SIGUSR1, handle_signal);
parse_opts(argc, argv);
if (tcpulp_audit)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment