Commit 084d0c13 authored by Jakub Kicinski

Merge branch 'net-packet-make-packet_fanout-arr-size-configurable-up-to-64k'

Tanner Love says:

====================
net/packet: make packet_fanout.arr size configurable up to 64K

First patch makes the change; second patch adds unit tests.
====================

Link: https://lore.kernel.org/r/20201106180741.2839668-1-tannerlove.kernel@gmail.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
parents a3ce2b10 1db32acf
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#ifndef __LINUX_IF_PACKET_H #ifndef __LINUX_IF_PACKET_H
#define __LINUX_IF_PACKET_H #define __LINUX_IF_PACKET_H
#include <asm/byteorder.h>
#include <linux/types.h> #include <linux/types.h>
struct sockaddr_pkt { struct sockaddr_pkt {
...@@ -296,6 +297,17 @@ struct packet_mreq { ...@@ -296,6 +297,17 @@ struct packet_mreq {
unsigned char mr_address[8]; unsigned char mr_address[8];
}; };
/* Argument block for the PACKET_FANOUT socket option.
 *
 * id and type_flags together fill the first 32 bits. Their order is
 * swapped by endianness so the pair lines up with the legacy single-int
 * form of the option, in which id is the low 16 bits and type_flags the
 * high 16 bits. max_num_members follows and is only present when the
 * caller passes the full structure.
 */
struct fanout_args {
#if defined(__LITTLE_ENDIAN_BITFIELD)
__u16 id;	/* fanout group id */
__u16 type_flags;	/* low byte: fanout type, high byte: flags */
#else
__u16 type_flags;	/* low byte: fanout type, high byte: flags */
__u16 id;	/* fanout group id */
#endif
__u32 max_num_members;	/* group capacity; 0 means the legacy size of 256 */
};
#define PACKET_MR_MULTICAST 0 #define PACKET_MR_MULTICAST 0
#define PACKET_MR_PROMISC 1 #define PACKET_MR_PROMISC 1
#define PACKET_MR_ALLMULTI 2 #define PACKET_MR_ALLMULTI 2
......
...@@ -1636,13 +1636,15 @@ static bool fanout_find_new_id(struct sock *sk, u16 *new_id) ...@@ -1636,13 +1636,15 @@ static bool fanout_find_new_id(struct sock *sk, u16 *new_id)
return false; return false;
} }
static int fanout_add(struct sock *sk, u16 id, u16 type_flags) static int fanout_add(struct sock *sk, struct fanout_args *args)
{ {
struct packet_rollover *rollover = NULL; struct packet_rollover *rollover = NULL;
struct packet_sock *po = pkt_sk(sk); struct packet_sock *po = pkt_sk(sk);
u16 type_flags = args->type_flags;
struct packet_fanout *f, *match; struct packet_fanout *f, *match;
u8 type = type_flags & 0xff; u8 type = type_flags & 0xff;
u8 flags = type_flags >> 8; u8 flags = type_flags >> 8;
u16 id = args->id;
int err; int err;
switch (type) { switch (type) {
...@@ -1700,11 +1702,21 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags) ...@@ -1700,11 +1702,21 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
} }
} }
err = -EINVAL; err = -EINVAL;
if (match && match->flags != flags) if (match) {
goto out; if (match->flags != flags)
if (!match) { goto out;
if (args->max_num_members &&
args->max_num_members != match->max_num_members)
goto out;
} else {
if (args->max_num_members > PACKET_FANOUT_MAX)
goto out;
if (!args->max_num_members)
/* legacy PACKET_FANOUT_MAX */
args->max_num_members = 256;
err = -ENOMEM; err = -ENOMEM;
match = kzalloc(sizeof(*match), GFP_KERNEL); match = kvzalloc(struct_size(match, arr, args->max_num_members),
GFP_KERNEL);
if (!match) if (!match)
goto out; goto out;
write_pnet(&match->net, sock_net(sk)); write_pnet(&match->net, sock_net(sk));
...@@ -1720,6 +1732,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags) ...@@ -1720,6 +1732,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
match->prot_hook.func = packet_rcv_fanout; match->prot_hook.func = packet_rcv_fanout;
match->prot_hook.af_packet_priv = match; match->prot_hook.af_packet_priv = match;
match->prot_hook.id_match = match_fanout_group; match->prot_hook.id_match = match_fanout_group;
match->max_num_members = args->max_num_members;
list_add(&match->list, &fanout_list); list_add(&match->list, &fanout_list);
} }
err = -EINVAL; err = -EINVAL;
...@@ -1730,7 +1743,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags) ...@@ -1730,7 +1743,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
match->prot_hook.type == po->prot_hook.type && match->prot_hook.type == po->prot_hook.type &&
match->prot_hook.dev == po->prot_hook.dev) { match->prot_hook.dev == po->prot_hook.dev) {
err = -ENOSPC; err = -ENOSPC;
if (refcount_read(&match->sk_ref) < PACKET_FANOUT_MAX) { if (refcount_read(&match->sk_ref) < match->max_num_members) {
__dev_remove_pack(&po->prot_hook); __dev_remove_pack(&po->prot_hook);
po->fanout = match; po->fanout = match;
po->rollover = rollover; po->rollover = rollover;
...@@ -1744,7 +1757,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags) ...@@ -1744,7 +1757,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
if (err && !refcount_read(&match->sk_ref)) { if (err && !refcount_read(&match->sk_ref)) {
list_del(&match->list); list_del(&match->list);
kfree(match); kvfree(match);
} }
out: out:
...@@ -3075,7 +3088,7 @@ static int packet_release(struct socket *sock) ...@@ -3075,7 +3088,7 @@ static int packet_release(struct socket *sock)
kfree(po->rollover); kfree(po->rollover);
if (f) { if (f) {
fanout_release_data(f); fanout_release_data(f);
kfree(f); kvfree(f);
} }
/* /*
* Now the socket is dead. No more input will appear. * Now the socket is dead. No more input will appear.
...@@ -3866,14 +3879,14 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval, ...@@ -3866,14 +3879,14 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
} }
case PACKET_FANOUT: case PACKET_FANOUT:
{ {
int val; struct fanout_args args = { 0 };
if (optlen != sizeof(val)) if (optlen != sizeof(int) && optlen != sizeof(args))
return -EINVAL; return -EINVAL;
if (copy_from_sockptr(&val, optval, sizeof(val))) if (copy_from_sockptr(&args, optval, optlen))
return -EFAULT; return -EFAULT;
return fanout_add(sk, val & 0xffff, val >> 16); return fanout_add(sk, &args);
} }
case PACKET_FANOUT_DATA: case PACKET_FANOUT_DATA:
{ {
......
...@@ -77,11 +77,12 @@ struct packet_ring_buffer { ...@@ -77,11 +77,12 @@ struct packet_ring_buffer {
}; };
extern struct mutex fanout_mutex; extern struct mutex fanout_mutex;
#define PACKET_FANOUT_MAX 256 #define PACKET_FANOUT_MAX (1 << 16)
struct packet_fanout { struct packet_fanout {
possible_net_t net; possible_net_t net;
unsigned int num_members; unsigned int num_members;
u32 max_num_members;
u16 id; u16 id;
u8 type; u8 type;
u8 flags; u8 flags;
...@@ -90,10 +91,10 @@ struct packet_fanout { ...@@ -90,10 +91,10 @@ struct packet_fanout {
struct bpf_prog __rcu *bpf_prog; struct bpf_prog __rcu *bpf_prog;
}; };
struct list_head list; struct list_head list;
struct sock *arr[PACKET_FANOUT_MAX];
spinlock_t lock; spinlock_t lock;
refcount_t sk_ref; refcount_t sk_ref;
struct packet_type prot_hook ____cacheline_aligned_in_smp; struct packet_type prot_hook ____cacheline_aligned_in_smp;
struct sock *arr[];
}; };
struct packet_rollover { struct packet_rollover {
......
...@@ -56,12 +56,15 @@ ...@@ -56,12 +56,15 @@
#define RING_NUM_FRAMES 20 #define RING_NUM_FRAMES 20
static uint32_t cfg_max_num_members;
/* Open a socket in a given fanout mode. /* Open a socket in a given fanout mode.
* @return -1 if mode is bad, a valid socket otherwise */ * @return -1 if mode is bad, a valid socket otherwise */
static int sock_fanout_open(uint16_t typeflags, uint16_t group_id) static int sock_fanout_open(uint16_t typeflags, uint16_t group_id)
{ {
struct sockaddr_ll addr = {0}; struct sockaddr_ll addr = {0};
int fd, val; struct fanout_args args;
int fd, val, err;
fd = socket(PF_PACKET, SOCK_RAW, 0); fd = socket(PF_PACKET, SOCK_RAW, 0);
if (fd < 0) { if (fd < 0) {
...@@ -83,8 +86,18 @@ static int sock_fanout_open(uint16_t typeflags, uint16_t group_id) ...@@ -83,8 +86,18 @@ static int sock_fanout_open(uint16_t typeflags, uint16_t group_id)
exit(1); exit(1);
} }
val = (((int) typeflags) << 16) | group_id; if (cfg_max_num_members) {
if (setsockopt(fd, SOL_PACKET, PACKET_FANOUT, &val, sizeof(val))) { args.id = group_id;
args.type_flags = typeflags;
args.max_num_members = cfg_max_num_members;
err = setsockopt(fd, SOL_PACKET, PACKET_FANOUT, &args,
sizeof(args));
} else {
val = (((int) typeflags) << 16) | group_id;
err = setsockopt(fd, SOL_PACKET, PACKET_FANOUT, &val,
sizeof(val));
}
if (err) {
if (close(fd)) { if (close(fd)) {
perror("close packet"); perror("close packet");
exit(1); exit(1);
...@@ -286,6 +299,56 @@ static void test_control_group(void) ...@@ -286,6 +299,56 @@ static void test_control_group(void)
} }
} }
/* Check how max_num_members is validated when a fanout group is created
 * and when further sockets join it: an over-limit size and a size that
 * disagrees with the existing group must be refused, while a matching
 * size or an unspecified (zero) size must be accepted.
 */
static void test_control_group_max_num_members(void)
{
	int socks[3];
	int rejected;
	int i;

	fprintf(stderr, "test: control multiple sockets, max_num_members\n");

	/* one past PACKET_FANOUT_MAX: the open is expected to fail */
	cfg_max_num_members = (1 << 16) + 1;
	rejected = sock_fanout_open(PACKET_FANOUT_HASH, 0);
	if (rejected != -1) {
		fprintf(stderr, "ERROR: max_num_members > PACKET_FANOUT_MAX\n");
		exit(1);
	}

	/* create the group with a capacity of 256 */
	cfg_max_num_members = 256;
	socks[0] = sock_fanout_open(PACKET_FANOUT_HASH, 0);
	if (socks[0] == -1) {
		fprintf(stderr, "ERROR: failed open\n");
		exit(1);
	}

	/* a capacity that differs from the group's is expected to fail */
	cfg_max_num_members = 257;
	rejected = sock_fanout_open(PACKET_FANOUT_HASH, 0);
	if (rejected != -1) {
		fprintf(stderr, "ERROR: set different max_num_members\n");
		exit(1);
	}

	/* the same capacity is expected to join */
	cfg_max_num_members = 256;
	socks[1] = sock_fanout_open(PACKET_FANOUT_HASH, 0);
	if (socks[1] == -1) {
		fprintf(stderr, "ERROR: failed to join group\n");
		exit(1);
	}

	/* leaving the capacity unspecified is expected to join as well */
	cfg_max_num_members = 0;
	socks[2] = sock_fanout_open(PACKET_FANOUT_HASH, 0);
	if (socks[2] == -1) {
		fprintf(stderr, "ERROR: failed to join group\n");
		exit(1);
	}

	/* close newest first; bail out on the first close() failure */
	for (i = 2; i >= 0; i--) {
		if (close(socks[i])) {
			fprintf(stderr, "ERROR: closing sockets\n");
			exit(1);
		}
	}
}
/* Test creating a unique fanout group ids */ /* Test creating a unique fanout group ids */
static void test_unique_fanout_group_ids(void) static void test_unique_fanout_group_ids(void)
{ {
...@@ -426,8 +489,11 @@ int main(int argc, char **argv) ...@@ -426,8 +489,11 @@ int main(int argc, char **argv)
test_control_single(); test_control_single();
test_control_group(); test_control_group();
test_control_group_max_num_members();
test_unique_fanout_group_ids(); test_unique_fanout_group_ids();
/* PACKET_FANOUT_MAX */
cfg_max_num_members = 1 << 16;
/* find a set of ports that do not collide onto the same socket */ /* find a set of ports that do not collide onto the same socket */
ret = test_datapath(PACKET_FANOUT_HASH, port_off, ret = test_datapath(PACKET_FANOUT_HASH, port_off,
expect_hash[0], expect_hash[1]); expect_hash[0], expect_hash[1]);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment