// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (c) 2019 Facebook */

#include <linux/bpf.h>
#include <linux/bpf_verifier.h>
#include <linux/btf.h>
#include <linux/filter.h>
#include <linux/slab.h>
#include <linux/numa.h>
#include <linux/seq_file.h>
#include <linux/refcount.h>
#include <linux/mutex.h>
#include <linux/btf_ids.h>
#include <linux/rcupdate_wait.h>

struct bpf_struct_ops_value {
	struct bpf_struct_ops_common_value common;
	char data[] ____cacheline_aligned_in_smp;
};

#define MAX_TRAMP_IMAGE_PAGES 8

struct bpf_struct_ops_map {
	struct bpf_map map;
	struct rcu_head rcu;
	const struct bpf_struct_ops_desc *st_ops_desc;
	/* protect map_update */
	struct mutex lock;
	/* links holds all the bpf_links that are populated
	 * to the func ptrs of the kernel's struct
	 * (in kvalue.data).
	 */
	struct bpf_link **links;
	u32 links_cnt;
	u32 image_pages_cnt;
	/* image_pages is an array of pages that hold all the trampolines
	 * that store the func args before calling the bpf_prog.
	 */
	void *image_pages[MAX_TRAMP_IMAGE_PAGES];
	/* The owner module's btf. */
	struct btf *btf;
	/* uvalue->data stores the kernel struct
	 * (e.g. tcp_congestion_ops) that is more useful
	 * to userspace than the kvalue.  For example,
	 * the bpf_prog's id is stored instead of the kernel
	 * address of a func ptr.
	 */
	struct bpf_struct_ops_value *uvalue;
	/* kvalue.data stores the actual kernel's struct
	 * (e.g. tcp_congestion_ops) that will be
	 * registered to the kernel subsystem.
	 */
	struct bpf_struct_ops_value kvalue;
};

struct bpf_struct_ops_link {
	struct bpf_link link;
	struct bpf_map __rcu *map;
};

static DEFINE_MUTEX(update_mutex);

#define VALUE_PREFIX "bpf_struct_ops_"
#define VALUE_PREFIX_LEN (sizeof(VALUE_PREFIX) - 1)

const struct bpf_verifier_ops bpf_struct_ops_verifier_ops = {
};

const struct bpf_prog_ops bpf_struct_ops_prog_ops = {
#ifdef CONFIG_NET
	.test_run = bpf_struct_ops_test_run,
#endif
};

BTF_ID_LIST(st_ops_ids)
BTF_ID(struct, module)
BTF_ID(struct, bpf_struct_ops_common_value)

enum {
	IDX_MODULE_ID,
	IDX_ST_OPS_COMMON_VALUE_ID,
};

extern struct btf *btf_vmlinux;

static bool is_valid_value_type(struct btf *btf, s32 value_id,
				const struct btf_type *type,
				const char *value_name)
{
	const struct btf_type *common_value_type;
	const struct btf_member *member;
	const struct btf_type *vt, *mt;

	vt = btf_type_by_id(btf, value_id);
	if (btf_vlen(vt) != 2) {
		pr_warn("The number of %s's members should be 2, but we get %d\n",
			value_name, btf_vlen(vt));
		return false;
	}
	member = btf_type_member(vt);
	mt = btf_type_by_id(btf, member->type);
	common_value_type = btf_type_by_id(btf_vmlinux,
					   st_ops_ids[IDX_ST_OPS_COMMON_VALUE_ID]);
	if (mt != common_value_type) {
		pr_warn("The first member of %s should be bpf_struct_ops_common_value\n",
			value_name);
		return false;
	}
	member++;
	mt = btf_type_by_id(btf, member->type);
	if (mt != type) {
		pr_warn("The second member of %s should be %s\n",
			value_name, btf_name_by_offset(btf, type->name_off));
		return false;
	}

	return true;
}
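
/* is_valid_value_type() above checks the layout expected for every
 * VALUE_PREFIX ("bpf_struct_ops_") value type.  A minimal sketch, using a
 * hypothetical struct_ops type "foo_ops" (the real value types are provided
 * by the owning subsystem or module):
 *
 *	struct bpf_struct_ops_foo_ops {
 *		struct bpf_struct_ops_common_value common;
 *		struct foo_ops data ____cacheline_aligned_in_smp;
 *	};
 */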

static void *bpf_struct_ops_image_alloc(void)
{
	void *image;
	int err;

	err = bpf_jit_charge_modmem(PAGE_SIZE);
	if (err)
		return ERR_PTR(err);
	image = arch_alloc_bpf_trampoline(PAGE_SIZE);
	if (!image) {
		bpf_jit_uncharge_modmem(PAGE_SIZE);
		return ERR_PTR(-ENOMEM);
	}

	return image;
}

void bpf_struct_ops_image_free(void *image)
{
	if (image) {
		arch_free_bpf_trampoline(image, PAGE_SIZE);
		bpf_jit_uncharge_modmem(PAGE_SIZE);
	}
}

#define MAYBE_NULL_SUFFIX "__nullable"
#define MAX_STUB_NAME 128

/* Return the type info of a stub function, if it exists.
 *
 * The name of a stub function is made up of the name of the struct_ops and
 * the name of the function pointer member, separated by "__". For example,
 * if the struct_ops type is named "foo_ops" and the function pointer
 * member is named "bar", the stub function name would be "foo_ops__bar".
 */
static const struct btf_type *
find_stub_func_proto(const struct btf *btf, const char *st_op_name,
		     const char *member_name)
{
	char stub_func_name[MAX_STUB_NAME];
	const struct btf_type *func_type;
	s32 btf_id;
	int cp;

	cp = snprintf(stub_func_name, MAX_STUB_NAME, "%s__%s",
		      st_op_name, member_name);
	if (cp >= MAX_STUB_NAME) {
		pr_warn("Stub function name too long\n");
		return NULL;
	}
	btf_id = btf_find_by_name_kind(btf, stub_func_name, BTF_KIND_FUNC);
	if (btf_id < 0)
		return NULL;
	func_type = btf_type_by_id(btf, btf_id);
	if (!func_type)
		return NULL;

	return btf_type_by_id(btf, func_type->type); /* FUNC_PROTO */
}
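
/* For illustration only (hypothetical names): given a struct_ops type
 *
 *	struct foo_ops {
 *		int (*bar)(struct foo *f);
 *	};
 *
 * the owning subsystem could declare a stub
 *
 *	int foo_ops__bar(struct foo *f__nullable) { return 0; }
 *
 * find_stub_func_proto() above resolves "foo_ops__bar" by name, and
 * prepare_arg_info() below uses the "__nullable" suffix on the stub's
 * parameter to mark the corresponding prog argument as PTR_MAYBE_NULL.
 */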

/* Prepare argument info for every nullable argument of a member of a
 * struct_ops type.
 *
 * Initialize a struct bpf_struct_ops_arg_info according to type info of
 * the arguments of a stub function. (Check kCFI for more information about
 * stub functions.)
 *
 * Each member in the struct_ops type has a struct bpf_struct_ops_arg_info
 * to provide an array of struct bpf_ctx_arg_aux, which in turn provides
 * the information that is used by the verifier to check the arguments of the
 * BPF struct_ops program assigned to the member. Here, we only care about
 * the arguments that are marked as __nullable.
 *
 * The array of struct bpf_ctx_arg_aux is eventually assigned to
 * prog->aux->ctx_arg_info of BPF struct_ops programs and passed to the
 * verifier. (See check_struct_ops_btf_id())
 *
 * On success, arg_info->info will be the list of struct bpf_ctx_arg_aux. On
 * failure, it is left untouched.
 */
static int prepare_arg_info(struct btf *btf,
			    const char *st_ops_name,
			    const char *member_name,
			    const struct btf_type *func_proto,
			    struct bpf_struct_ops_arg_info *arg_info)
{
	const struct btf_type *stub_func_proto, *pointed_type;
	const struct btf_param *stub_args, *args;
	struct bpf_ctx_arg_aux *info, *info_buf;
	u32 nargs, arg_no, info_cnt = 0;
	u32 arg_btf_id;
	int offset;

	stub_func_proto = find_stub_func_proto(btf, st_ops_name, member_name);
	if (!stub_func_proto)
		return 0;

	/* Check if the number of arguments of the stub function is the same
	 * as the number of arguments of the function pointer.
	 */
	nargs = btf_type_vlen(func_proto);
	if (nargs != btf_type_vlen(stub_func_proto)) {
		pr_warn("the number of arguments of the stub function %s__%s does not match the number of arguments of the member %s of struct %s\n",
			st_ops_name, member_name, member_name, st_ops_name);
		return -EINVAL;
	}

	if (!nargs)
		return 0;

	args = btf_params(func_proto);
	stub_args = btf_params(stub_func_proto);

	info_buf = kcalloc(nargs, sizeof(*info_buf), GFP_KERNEL);
	if (!info_buf)
		return -ENOMEM;

	/* Prepare info for every nullable argument */
	info = info_buf;
	for (arg_no = 0; arg_no < nargs; arg_no++) {
		/* Skip arguments that are not suffixed with
		 * "__nullable".
		 */
		if (!btf_param_match_suffix(btf, &stub_args[arg_no],
					    MAYBE_NULL_SUFFIX))
			continue;

		/* Should be a pointer to struct */
		pointed_type = btf_type_resolve_ptr(btf,
						    args[arg_no].type,
						    &arg_btf_id);
		if (!pointed_type ||
		    !btf_type_is_struct(pointed_type)) {
			pr_warn("stub function %s__%s has %s tagging to an unsupported type\n",
				st_ops_name, member_name, MAYBE_NULL_SUFFIX);
			goto err_out;
		}

		offset = btf_ctx_arg_offset(btf, func_proto, arg_no);
		if (offset < 0) {
			pr_warn("stub function %s__%s has an invalid trampoline ctx offset for arg#%u\n",
				st_ops_name, member_name, arg_no);
			goto err_out;
		}

		if (args[arg_no].type != stub_args[arg_no].type) {
			pr_warn("arg#%u type in stub function %s__%s does not match with its original func_proto\n",
				arg_no, st_ops_name, member_name);
			goto err_out;
		}

		/* Fill the information of the new argument */
		info->reg_type =
			PTR_TRUSTED | PTR_TO_BTF_ID | PTR_MAYBE_NULL;
		info->btf_id = arg_btf_id;
		info->btf = btf;
		info->offset = offset;

		info++;
		info_cnt++;
	}

	if (info_cnt) {
		arg_info->info = info_buf;
		arg_info->cnt = info_cnt;
	} else {
		kfree(info_buf);
	}

	return 0;

err_out:
	kfree(info_buf);

	return -EINVAL;
}

/* Clean up the arg_info in a struct bpf_struct_ops_desc. */
void bpf_struct_ops_desc_release(struct bpf_struct_ops_desc *st_ops_desc)
{
	struct bpf_struct_ops_arg_info *arg_info;
	int i;

	arg_info = st_ops_desc->arg_info;
	for (i = 0; i < btf_type_vlen(st_ops_desc->type); i++)
		kfree(arg_info[i].info);

	kfree(arg_info);
}

int bpf_struct_ops_desc_init(struct bpf_struct_ops_desc *st_ops_desc,
			     struct btf *btf,
			     struct bpf_verifier_log *log)
{
	struct bpf_struct_ops *st_ops = st_ops_desc->st_ops;
	struct bpf_struct_ops_arg_info *arg_info;
	const struct btf_member *member;
	const struct btf_type *t;
	s32 type_id, value_id;
	char value_name[128];
	const char *mname;
	int i, err;

	if (strlen(st_ops->name) + VALUE_PREFIX_LEN >=
	    sizeof(value_name)) {
		pr_warn("struct_ops name %s is too long\n",
			st_ops->name);
		return -EINVAL;
	}
	sprintf(value_name, "%s%s", VALUE_PREFIX, st_ops->name);

	if (!st_ops->cfi_stubs) {
		pr_warn("struct_ops for %s has no cfi_stubs\n", st_ops->name);
		return -EINVAL;
	}

	type_id = btf_find_by_name_kind(btf, st_ops->name,
					BTF_KIND_STRUCT);
	if (type_id < 0) {
		pr_warn("Cannot find struct %s in %s\n",
			st_ops->name, btf_get_name(btf));
		return -EINVAL;
	}
	t = btf_type_by_id(btf, type_id);
	if (btf_type_vlen(t) > BPF_STRUCT_OPS_MAX_NR_MEMBERS) {
		pr_warn("Cannot support #%u members in struct %s\n",
			btf_type_vlen(t), st_ops->name);
		return -EINVAL;
	}

	value_id = btf_find_by_name_kind(btf, value_name,
					 BTF_KIND_STRUCT);
	if (value_id < 0) {
		pr_warn("Cannot find struct %s in %s\n",
			value_name, btf_get_name(btf));
		return -EINVAL;
	}
	if (!is_valid_value_type(btf, value_id, t, value_name))
		return -EINVAL;

	arg_info = kcalloc(btf_type_vlen(t), sizeof(*arg_info),
			   GFP_KERNEL);
	if (!arg_info)
		return -ENOMEM;

	st_ops_desc->arg_info = arg_info;
	st_ops_desc->type = t;
	st_ops_desc->type_id = type_id;
	st_ops_desc->value_id = value_id;
	st_ops_desc->value_type = btf_type_by_id(btf, value_id);

	for_each_member(i, t, member) {
		const struct btf_type *func_proto;

		mname = btf_name_by_offset(btf, member->name_off);
		if (!*mname) {
			pr_warn("anon member in struct %s is not supported\n",
				st_ops->name);
			err = -EOPNOTSUPP;
			goto errout;
		}

		if (__btf_member_bitfield_size(t, member)) {
			pr_warn("bit field member %s in struct %s is not supported\n",
				mname, st_ops->name);
			err = -EOPNOTSUPP;
			goto errout;
		}

		func_proto = btf_type_resolve_func_ptr(btf,
						       member->type,
						       NULL);
		if (!func_proto)
			continue;

		if (btf_distill_func_proto(log, btf,
					   func_proto, mname,
					   &st_ops->func_models[i])) {
			pr_warn("Error in parsing func ptr %s in struct %s\n",
				mname, st_ops->name);
			err = -EINVAL;
			goto errout;
		}

		err = prepare_arg_info(btf, st_ops->name, mname,
				       func_proto,
				       arg_info + i);
		if (err)
			goto errout;
	}

	if (st_ops->init(btf)) {
		pr_warn("Error in init bpf_struct_ops %s\n",
			st_ops->name);
		err = -EINVAL;
		goto errout;
	}

	return 0;

errout:
	bpf_struct_ops_desc_release(st_ops_desc);

	return err;
}

static int bpf_struct_ops_map_get_next_key(struct bpf_map *map, void *key,
					   void *next_key)
{
	if (key && *(u32 *)key == 0)
		return -ENOENT;

	*(u32 *)next_key = 0;
	return 0;
}

int bpf_struct_ops_map_sys_lookup_elem(struct bpf_map *map, void *key,
				       void *value)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
	struct bpf_struct_ops_value *uvalue, *kvalue;
	enum bpf_struct_ops_state state;
	s64 refcnt;

	if (unlikely(*(u32 *)key != 0))
		return -ENOENT;

	kvalue = &st_map->kvalue;
	/* Pair with smp_store_release() during map_update */
	state = smp_load_acquire(&kvalue->common.state);
	if (state == BPF_STRUCT_OPS_STATE_INIT) {
		memset(value, 0, map->value_size);
		return 0;
	}

	/* No lock is needed.  state and refcnt do not need
	 * to be updated together under atomic context.
	 */
	uvalue = value;
	memcpy(uvalue, st_map->uvalue, map->value_size);
	uvalue->common.state = state;

	/* This value offers the user space a general estimate of how
	 * many sockets are still utilizing this struct_ops for TCP
	 * congestion control. The number might not be exact, but it
	 * should sufficiently meet our present goals.
	 */
	refcnt = atomic64_read(&map->refcnt) - atomic64_read(&map->usercnt);
	refcount_set(&uvalue->common.refcnt, max_t(s64, refcnt, 0));

	return 0;
}

static void *bpf_struct_ops_map_lookup_elem(struct bpf_map *map, void *key)
{
	return ERR_PTR(-EINVAL);
}

static void bpf_struct_ops_map_put_progs(struct bpf_struct_ops_map *st_map)
{
	u32 i;

	for (i = 0; i < st_map->links_cnt; i++) {
		if (st_map->links[i]) {
			bpf_link_put(st_map->links[i]);
			st_map->links[i] = NULL;
		}
	}
}

static void bpf_struct_ops_map_free_image(struct bpf_struct_ops_map *st_map)
{
	int i;

	for (i = 0; i < st_map->image_pages_cnt; i++)
		bpf_struct_ops_image_free(st_map->image_pages[i]);
	st_map->image_pages_cnt = 0;
}

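/* Reject a value in which any hole between members, or the trailing padding,
 * carries non-zero bytes, so userspace cannot sneak unchecked data into the
 * kernel-side struct.
 */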
static int check_zero_holes(const struct btf *btf, const struct btf_type *t, void *data)
{
	const struct btf_member *member;
	u32 i, moff, msize, prev_mend = 0;
	const struct btf_type *mtype;

	for_each_member(i, t, member) {
		moff = __btf_member_bit_offset(t, member) / 8;
		if (moff > prev_mend &&
		    memchr_inv(data + prev_mend, 0, moff - prev_mend))
			return -EINVAL;

		mtype = btf_type_by_id(btf, member->type);
		mtype = btf_resolve_size(btf, mtype, &msize);
		if (IS_ERR(mtype))
			return PTR_ERR(mtype);
		prev_mend = moff + msize;
	}

	if (t->size > prev_mend &&
	    memchr_inv(data + prev_mend, 0, t->size - prev_mend))
		return -EINVAL;

	return 0;
}

static void bpf_struct_ops_link_release(struct bpf_link *link)
{
}

static void bpf_struct_ops_link_dealloc(struct bpf_link *link)
{
	struct bpf_tramp_link *tlink = container_of(link, struct bpf_tramp_link, link);

	kfree(tlink);
}

const struct bpf_link_ops bpf_struct_ops_link_lops = {
	.release = bpf_struct_ops_link_release,
	.dealloc = bpf_struct_ops_link_dealloc,
};

int bpf_struct_ops_prepare_trampoline(struct bpf_tramp_links *tlinks,
				      struct bpf_tramp_link *link,
				      const struct btf_func_model *model,
				      void *stub_func,
				      void **_image, u32 *_image_off,
				      bool allow_alloc)
{
	u32 image_off = *_image_off, flags = BPF_TRAMP_F_INDIRECT;
	void *image = *_image;
	int size;

	tlinks[BPF_TRAMP_FENTRY].links[0] = link;
	tlinks[BPF_TRAMP_FENTRY].nr_links = 1;

	if (model->ret_size > 0)
		flags |= BPF_TRAMP_F_RET_FENTRY_RET;

	size = arch_bpf_trampoline_size(model, flags, tlinks, NULL);
	if (size <= 0)
		return size ? : -EFAULT;

	/* Allocate image buffer if necessary */
	if (!image || size > PAGE_SIZE - image_off) {
		if (!allow_alloc)
			return -E2BIG;

		image = bpf_struct_ops_image_alloc();
		if (IS_ERR(image))
			return PTR_ERR(image);
		image_off = 0;
	}

	size = arch_prepare_bpf_trampoline(NULL, image + image_off,
					   image + PAGE_SIZE,
					   model, flags, tlinks, stub_func);
	if (size <= 0) {
		if (image != *_image)
			bpf_struct_ops_image_free(image);
		return size ? : -EFAULT;
	}

	*_image = image;
	*_image_off = image_off + size;
	return 0;
}
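
/* Note on the helper above: each call packs one BPF_TRAMP_F_INDIRECT
 * trampoline into the current image page and only allocates a fresh page
 * when the new trampoline does not fit and @allow_alloc permits it.  The
 * @stub_func is handed down to the arch code (e.g. to derive the kCFI
 * type hash for indirect calls into the trampoline).
 */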

static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key,
					   void *value, u64 flags)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
	const struct bpf_struct_ops_desc *st_ops_desc = st_map->st_ops_desc;
	const struct bpf_struct_ops *st_ops = st_ops_desc->st_ops;
	struct bpf_struct_ops_value *uvalue, *kvalue;
	const struct btf_type *module_type;
	const struct btf_member *member;
	const struct btf_type *t = st_ops_desc->type;
	struct bpf_tramp_links *tlinks;
	void *udata, *kdata;
	int prog_fd, err;
	u32 i, trampoline_start, image_off = 0;
	void *cur_image = NULL, *image = NULL;

	if (flags)
		return -EINVAL;

	if (*(u32 *)key != 0)
		return -E2BIG;

	err = check_zero_holes(st_map->btf, st_ops_desc->value_type, value);
	if (err)
		return err;

	uvalue = value;
	err = check_zero_holes(st_map->btf, t, uvalue->data);
	if (err)
		return err;

	if (uvalue->common.state || refcount_read(&uvalue->common.refcnt))
		return -EINVAL;

	tlinks = kcalloc(BPF_TRAMP_MAX, sizeof(*tlinks), GFP_KERNEL);
	if (!tlinks)
		return -ENOMEM;

	uvalue = (struct bpf_struct_ops_value *)st_map->uvalue;
	kvalue = (struct bpf_struct_ops_value *)&st_map->kvalue;

	mutex_lock(&st_map->lock);

	if (kvalue->common.state != BPF_STRUCT_OPS_STATE_INIT) {
		err = -EBUSY;
		goto unlock;
	}

	memcpy(uvalue, value, map->value_size);

	udata = &uvalue->data;
	kdata = &kvalue->data;

	module_type = btf_type_by_id(btf_vmlinux, st_ops_ids[IDX_MODULE_ID]);
	for_each_member(i, t, member) {
		const struct btf_type *mtype, *ptype;
		struct bpf_prog *prog;
		struct bpf_tramp_link *link;
		u32 moff;

		moff = __btf_member_bit_offset(t, member) / 8;
		ptype = btf_type_resolve_ptr(st_map->btf, member->type, NULL);
		if (ptype == module_type) {
			if (*(void **)(udata + moff))
				goto reset_unlock;
			*(void **)(kdata + moff) = BPF_MODULE_OWNER;
			continue;
		}

		err = st_ops->init_member(t, member, kdata, udata);
		if (err < 0)
			goto reset_unlock;

		/* The ->init_member() has handled this member */
		if (err > 0)
			continue;

		/* If st_ops->init_member does not handle it,
		 * we will only handle func ptrs and zero-ed members
		 * here.  Reject everything else.
		 */

		/* All non func ptr members must be 0 */
		if (!ptype || !btf_type_is_func_proto(ptype)) {
			u32 msize;

			mtype = btf_type_by_id(st_map->btf, member->type);
			mtype = btf_resolve_size(st_map->btf, mtype, &msize);
			if (IS_ERR(mtype)) {
				err = PTR_ERR(mtype);
				goto reset_unlock;
			}

			if (memchr_inv(udata + moff, 0, msize)) {
				err = -EINVAL;
				goto reset_unlock;
			}

			continue;
		}

		prog_fd = (int)(*(unsigned long *)(udata + moff));
		/* Similar check as the attr->attach_prog_fd */
		if (!prog_fd)
			continue;

		prog = bpf_prog_get(prog_fd);
		if (IS_ERR(prog)) {
			err = PTR_ERR(prog);
			goto reset_unlock;
		}

		if (prog->type != BPF_PROG_TYPE_STRUCT_OPS ||
		    prog->aux->attach_btf_id != st_ops_desc->type_id ||
		    prog->expected_attach_type != i) {
			bpf_prog_put(prog);
			err = -EINVAL;
			goto reset_unlock;
		}

		link = kzalloc(sizeof(*link), GFP_USER);
		if (!link) {
			bpf_prog_put(prog);
			err = -ENOMEM;
			goto reset_unlock;
		}
		bpf_link_init(&link->link, BPF_LINK_TYPE_STRUCT_OPS,
			      &bpf_struct_ops_link_lops, prog);
		st_map->links[i] = &link->link;

		trampoline_start = image_off;
		err = bpf_struct_ops_prepare_trampoline(tlinks, link,
						&st_ops->func_models[i],
						*(void **)(st_ops->cfi_stubs + moff),
						&image, &image_off,
						st_map->image_pages_cnt < MAX_TRAMP_IMAGE_PAGES);
		if (err)
			goto reset_unlock;

		if (cur_image != image) {
			st_map->image_pages[st_map->image_pages_cnt++] = image;
			cur_image = image;
			trampoline_start = 0;
		}

		*(void **)(kdata + moff) = image + trampoline_start + cfi_get_offset();

		/* put prog_id to udata */
		*(unsigned long *)(udata + moff) = prog->aux->id;
	}

	if (st_ops->validate) {
		err = st_ops->validate(kdata);
		if (err)
			goto reset_unlock;
	}
	for (i = 0; i < st_map->image_pages_cnt; i++) {
		err = arch_protect_bpf_trampoline(st_map->image_pages[i],
						  PAGE_SIZE);
		if (err)
			goto reset_unlock;
	}

	if (st_map->map.map_flags & BPF_F_LINK) {
		err = 0;
		/* Let bpf_link handle registration & unregistration.
		 *
		 * Pair with smp_load_acquire() during lookup_elem().
		 */
		smp_store_release(&kvalue->common.state, BPF_STRUCT_OPS_STATE_READY);
		goto unlock;
	}

	err = st_ops->reg(kdata, NULL);
	if (likely(!err)) {
		/* This refcnt increment on the map here after
		 * 'st_ops->reg()' is secure since the state of the
		 * map must be set to INIT at this moment, and thus
		 * bpf_struct_ops_map_delete_elem() can't unregister
		 * or transition it to TOBEFREE concurrently.
		 */
		bpf_map_inc(map);
		/* Pair with smp_load_acquire() during lookup_elem().
		 * It ensures the above udata updates (e.g. prog->aux->id)
		 * can be seen once BPF_STRUCT_OPS_STATE_INUSE is set.
		 */
		smp_store_release(&kvalue->common.state, BPF_STRUCT_OPS_STATE_INUSE);
		goto unlock;
	}

	/* Error during st_ops->reg(). Can happen if this struct_ops needs to be
	 * verified as a whole, after all init_member() calls. Can also happen if
	 * there was a race in registering the struct_ops (under the same name) to
	 * a sub-system through different struct_ops's maps.
	 */

reset_unlock:
	bpf_struct_ops_map_free_image(st_map);
	bpf_struct_ops_map_put_progs(st_map);
	memset(uvalue, 0, map->value_size);
	memset(kvalue, 0, map->value_size);
unlock:
	kfree(tlinks);
	mutex_unlock(&st_map->lock);
	return err;
}
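
/* A rough sketch of how the update above is normally driven from userspace
 * (libbpf does this when loading a SEC(".struct_ops") map): the single
 * element at key 0 is written with the bpf_struct_ops_<name> value, where
 * each func ptr field holds the fd of the struct_ops prog to attach:
 *
 *	__u32 key = 0;
 *
 *	err = bpf_map_update_elem(map_fd, &key, &value, 0);
 */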

static long bpf_struct_ops_map_delete_elem(struct bpf_map *map, void *key)
{
	enum bpf_struct_ops_state prev_state;
	struct bpf_struct_ops_map *st_map;

	st_map = (struct bpf_struct_ops_map *)map;
	if (st_map->map.map_flags & BPF_F_LINK)
		return -EOPNOTSUPP;

	prev_state = cmpxchg(&st_map->kvalue.common.state,
			     BPF_STRUCT_OPS_STATE_INUSE,
			     BPF_STRUCT_OPS_STATE_TOBEFREE);
	switch (prev_state) {
	case BPF_STRUCT_OPS_STATE_INUSE:
		st_map->st_ops_desc->st_ops->unreg(&st_map->kvalue.data, NULL);
		bpf_map_put(map);
		return 0;
	case BPF_STRUCT_OPS_STATE_TOBEFREE:
		return -EINPROGRESS;
	case BPF_STRUCT_OPS_STATE_INIT:
		return -ENOENT;
	default:
		WARN_ON_ONCE(1);
		/* Should never happen.  Treat it as not found. */
		return -ENOENT;
	}
}

static void bpf_struct_ops_map_seq_show_elem(struct bpf_map *map, void *key,
					     struct seq_file *m)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
	void *value;
	int err;

	value = kmalloc(map->value_size, GFP_USER | __GFP_NOWARN);
	if (!value)
		return;

	err = bpf_struct_ops_map_sys_lookup_elem(map, key, value);
	if (!err) {
		btf_type_seq_show(st_map->btf,
				  map->btf_vmlinux_value_type_id,
				  value, m);
		seq_puts(m, "\n");
	}

	kfree(value);
}

static void __bpf_struct_ops_map_free(struct bpf_map *map)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;

	if (st_map->links)
		bpf_struct_ops_map_put_progs(st_map);
	bpf_map_area_free(st_map->links);
	bpf_struct_ops_map_free_image(st_map);
	bpf_map_area_free(st_map->uvalue);
	bpf_map_area_free(st_map);
}

static void bpf_struct_ops_map_free(struct bpf_map *map)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;

	/* st_ops->owner was acquired during map_alloc to implicitly hold
	 * the btf's refcnt. The acquire was only done when btf_is_module()
	 * is true; st_map->btf cannot be NULL here.
	 */
	if (btf_is_module(st_map->btf))
		module_put(st_map->st_ops_desc->st_ops->owner);

	/* The struct_ops's function may switch to another struct_ops.
	 *
	 * For example, bpf_tcp_cc_x->init() may switch to
	 * another tcp_cc_y by calling
	 * setsockopt(TCP_CONGESTION, "tcp_cc_y").
	 * During the switch,  bpf_struct_ops_put(tcp_cc_x) is called
	 * and its refcount may reach 0 which then free its
	 * trampoline image while tcp_cc_x is still running.
	 *
	 * A vanilla rcu gp is to wait for all bpf-tcp-cc prog
	 * to finish. bpf-tcp-cc prog is non sleepable.
	 * A rcu_tasks gp is to wait for the last few insn
	 * in the trampoline image to finish before releasing
	 * the trampoline image.
	 */
	synchronize_rcu_mult(call_rcu, call_rcu_tasks);

	__bpf_struct_ops_map_free(map);
}

static int bpf_struct_ops_map_alloc_check(union bpf_attr *attr)
{
	if (attr->key_size != sizeof(unsigned int) || attr->max_entries != 1 ||
	    (attr->map_flags & ~(BPF_F_LINK | BPF_F_VTYPE_BTF_OBJ_FD)) ||
	    !attr->btf_vmlinux_value_type_id)
		return -EINVAL;
	return 0;
}

static struct bpf_map *bpf_struct_ops_map_alloc(union bpf_attr *attr)
{
	const struct bpf_struct_ops_desc *st_ops_desc;
	size_t st_map_size;
	struct bpf_struct_ops_map *st_map;
	const struct btf_type *t, *vt;
	struct module *mod = NULL;
	struct bpf_map *map;
	struct btf *btf;
	int ret;

	if (attr->map_flags & BPF_F_VTYPE_BTF_OBJ_FD) {
		/* The map holds btf for its whole lifetime. */
		btf = btf_get_by_fd(attr->value_type_btf_obj_fd);
		if (IS_ERR(btf))
			return ERR_CAST(btf);
		if (!btf_is_module(btf)) {
			btf_put(btf);
			return ERR_PTR(-EINVAL);
		}

		mod = btf_try_get_module(btf);
		/* mod holds a refcnt to btf. We don't need an extra refcnt
		 * here.
		 */
		btf_put(btf);
		if (!mod)
			return ERR_PTR(-EINVAL);
	} else {
		btf = bpf_get_btf_vmlinux();
		if (IS_ERR(btf))
			return ERR_CAST(btf);
		if (!btf)
			return ERR_PTR(-ENOTSUPP);
	}

	st_ops_desc = bpf_struct_ops_find_value(btf, attr->btf_vmlinux_value_type_id);
	if (!st_ops_desc) {
		ret = -ENOTSUPP;
		goto errout;
	}

	vt = st_ops_desc->value_type;
	if (attr->value_size != vt->size) {
		ret = -EINVAL;
		goto errout;
	}

	t = st_ops_desc->type;

	st_map_size = sizeof(*st_map) +
		/* kvalue stores the
		 * struct bpf_struct_ops_tcp_congestion_ops
		 */
		(vt->size - sizeof(struct bpf_struct_ops_value));

	st_map = bpf_map_area_alloc(st_map_size, NUMA_NO_NODE);
	if (!st_map) {
		ret = -ENOMEM;
		goto errout;
	}

	st_map->st_ops_desc = st_ops_desc;
	map = &st_map->map;

	st_map->uvalue = bpf_map_area_alloc(vt->size, NUMA_NO_NODE);
	st_map->links_cnt = btf_type_vlen(t);
	st_map->links =
		bpf_map_area_alloc(st_map->links_cnt * sizeof(struct bpf_link *),
				   NUMA_NO_NODE);
	if (!st_map->uvalue || !st_map->links) {
		ret = -ENOMEM;
		goto errout_free;
	}
	st_map->btf = btf;

	mutex_init(&st_map->lock);
	bpf_map_init_from_attr(map, attr);

	return map;

errout_free:
	__bpf_struct_ops_map_free(map);
errout:
	module_put(mod);

	return ERR_PTR(ret);
}

static u64 bpf_struct_ops_map_mem_usage(const struct bpf_map *map)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
	const struct bpf_struct_ops_desc *st_ops_desc = st_map->st_ops_desc;
	const struct btf_type *vt = st_ops_desc->value_type;
	u64 usage;

	usage = sizeof(*st_map) +
			vt->size - sizeof(struct bpf_struct_ops_value);
	usage += vt->size;
	usage += btf_type_vlen(vt) * sizeof(struct bpf_link *);
	usage += PAGE_SIZE;
	return usage;
}

BTF_ID_LIST_SINGLE(bpf_struct_ops_map_btf_ids, struct, bpf_struct_ops_map)
const struct bpf_map_ops bpf_struct_ops_map_ops = {
	.map_alloc_check = bpf_struct_ops_map_alloc_check,
	.map_alloc = bpf_struct_ops_map_alloc,
	.map_free = bpf_struct_ops_map_free,
	.map_get_next_key = bpf_struct_ops_map_get_next_key,
	.map_lookup_elem = bpf_struct_ops_map_lookup_elem,
	.map_delete_elem = bpf_struct_ops_map_delete_elem,
	.map_update_elem = bpf_struct_ops_map_update_elem,
	.map_seq_show_elem = bpf_struct_ops_map_seq_show_elem,
	.map_mem_usage = bpf_struct_ops_map_mem_usage,
	.map_btf_id = &bpf_struct_ops_map_btf_ids[0],
};

/* "const void *" because some subsystem is
 * passing a const (e.g. const struct tcp_congestion_ops *)
 */
bool bpf_struct_ops_get(const void *kdata)
{
	struct bpf_struct_ops_value *kvalue;
	struct bpf_struct_ops_map *st_map;
	struct bpf_map *map;

	kvalue = container_of(kdata, struct bpf_struct_ops_value, data);
	st_map = container_of(kvalue, struct bpf_struct_ops_map, kvalue);

	map = __bpf_map_inc_not_zero(&st_map->map, false);
	return !IS_ERR(map);
}

void bpf_struct_ops_put(const void *kdata)
{
	struct bpf_struct_ops_value *kvalue;
	struct bpf_struct_ops_map *st_map;

	kvalue = container_of(kdata, struct bpf_struct_ops_value, data);
	st_map = container_of(kvalue, struct bpf_struct_ops_map, kvalue);

	bpf_map_put(&st_map->map);
}

static bool bpf_struct_ops_valid_to_reg(struct bpf_map *map)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;

	return map->map_type == BPF_MAP_TYPE_STRUCT_OPS &&
		map->map_flags & BPF_F_LINK &&
		/* Pair with smp_store_release() during map_update */
		smp_load_acquire(&st_map->kvalue.common.state) == BPF_STRUCT_OPS_STATE_READY;
}

static void bpf_struct_ops_map_link_dealloc(struct bpf_link *link)
{
	struct bpf_struct_ops_link *st_link;
	struct bpf_struct_ops_map *st_map;

	st_link = container_of(link, struct bpf_struct_ops_link, link);
	st_map = (struct bpf_struct_ops_map *)
		rcu_dereference_protected(st_link->map, true);
	if (st_map) {
		st_map->st_ops_desc->st_ops->unreg(&st_map->kvalue.data, link);
		bpf_map_put(&st_map->map);
	}
	kfree(st_link);
}

static void bpf_struct_ops_map_link_show_fdinfo(const struct bpf_link *link,
					    struct seq_file *seq)
{
	struct bpf_struct_ops_link *st_link;
	struct bpf_map *map;

	st_link = container_of(link, struct bpf_struct_ops_link, link);
	rcu_read_lock();
	map = rcu_dereference(st_link->map);
	if (map)
		seq_printf(seq, "map_id:\t%d\n", map->id);
	rcu_read_unlock();
}

static int bpf_struct_ops_map_link_fill_link_info(const struct bpf_link *link,
					       struct bpf_link_info *info)
{
	struct bpf_struct_ops_link *st_link;
	struct bpf_map *map;

	st_link = container_of(link, struct bpf_struct_ops_link, link);
	rcu_read_lock();
	map = rcu_dereference(st_link->map);
	if (map)
		info->struct_ops.map_id = map->id;
	rcu_read_unlock();
	return 0;
}

static int bpf_struct_ops_map_link_update(struct bpf_link *link, struct bpf_map *new_map,
					  struct bpf_map *expected_old_map)
{
	struct bpf_struct_ops_map *st_map, *old_st_map;
	struct bpf_map *old_map;
	struct bpf_struct_ops_link *st_link;
	int err;

	st_link = container_of(link, struct bpf_struct_ops_link, link);
	st_map = container_of(new_map, struct bpf_struct_ops_map, map);

	if (!bpf_struct_ops_valid_to_reg(new_map))
		return -EINVAL;

	if (!st_map->st_ops_desc->st_ops->update)
		return -EOPNOTSUPP;

	mutex_lock(&update_mutex);

	old_map = rcu_dereference_protected(st_link->map, lockdep_is_held(&update_mutex));
	if (!old_map) {
		err = -ENOLINK;
		goto err_out;
	}
	if (expected_old_map && old_map != expected_old_map) {
		err = -EPERM;
		goto err_out;
	}

	old_st_map = container_of(old_map, struct bpf_struct_ops_map, map);
	/* The new and old struct_ops must be the same type. */
	if (st_map->st_ops_desc != old_st_map->st_ops_desc) {
		err = -EINVAL;
		goto err_out;
	}

	err = st_map->st_ops_desc->st_ops->update(st_map->kvalue.data, old_st_map->kvalue.data, link);
	if (err)
		goto err_out;

	bpf_map_inc(new_map);
	rcu_assign_pointer(st_link->map, new_map);
	bpf_map_put(old_map);

err_out:
	mutex_unlock(&update_mutex);

	return err;
}

static int bpf_struct_ops_map_link_detach(struct bpf_link *link)
{
	struct bpf_struct_ops_link *st_link = container_of(link, struct bpf_struct_ops_link, link);
	struct bpf_struct_ops_map *st_map;
	struct bpf_map *map;

	mutex_lock(&update_mutex);

	map = rcu_dereference_protected(st_link->map, lockdep_is_held(&update_mutex));
	if (!map) {
		mutex_unlock(&update_mutex);
		return 0;
	}
	st_map = container_of(map, struct bpf_struct_ops_map, map);

	st_map->st_ops_desc->st_ops->unreg(&st_map->kvalue.data, link);

	RCU_INIT_POINTER(st_link->map, NULL);
	/* Pair with bpf_map_get() in bpf_struct_ops_link_create() or
	 * bpf_map_inc() in bpf_struct_ops_map_link_update().
	 */
	bpf_map_put(&st_map->map);

	mutex_unlock(&update_mutex);

	return 0;
}

static const struct bpf_link_ops bpf_struct_ops_map_lops = {
	.dealloc = bpf_struct_ops_map_link_dealloc,
	.detach = bpf_struct_ops_map_link_detach,
	.show_fdinfo = bpf_struct_ops_map_link_show_fdinfo,
	.fill_link_info = bpf_struct_ops_map_link_fill_link_info,
	.update_map = bpf_struct_ops_map_link_update,
};

int bpf_struct_ops_link_create(union bpf_attr *attr)
{
	struct bpf_struct_ops_link *link = NULL;
	struct bpf_link_primer link_primer;
	struct bpf_struct_ops_map *st_map;
	struct bpf_map *map;
	int err;

	map = bpf_map_get(attr->link_create.map_fd);
	if (IS_ERR(map))
		return PTR_ERR(map);

	st_map = (struct bpf_struct_ops_map *)map;

	if (!bpf_struct_ops_valid_to_reg(map)) {
		err = -EINVAL;
		goto err_out;
	}

	link = kzalloc(sizeof(*link), GFP_USER);
	if (!link) {
		err = -ENOMEM;
		goto err_out;
	}
	bpf_link_init(&link->link, BPF_LINK_TYPE_STRUCT_OPS, &bpf_struct_ops_map_lops, NULL);

	err = bpf_link_prime(&link->link, &link_primer);
	if (err)
		goto err_out;

	/* Hold the update_mutex such that the subsystem cannot
	 * do link->ops->detach() before the link is fully initialized.
	 */
	mutex_lock(&update_mutex);
	err = st_map->st_ops_desc->st_ops->reg(st_map->kvalue.data, &link->link);
	if (err) {
		mutex_unlock(&update_mutex);
		bpf_link_cleanup(&link_primer);
		link = NULL;
		goto err_out;
	}
	RCU_INIT_POINTER(link->map, map);
	mutex_unlock(&update_mutex);

	return bpf_link_settle(&link_primer);

err_out:
	bpf_map_put(map);
	kfree(link);
	return err;
}
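
/* A rough userspace sketch of reaching bpf_struct_ops_link_create() above:
 * for a map created with BPF_F_LINK (e.g. a SEC(".struct_ops.link") map in
 * libbpf), registration is done through the link rather than in
 * map_update_elem(), typically via
 *
 *	struct bpf_link *link = bpf_map__attach_struct_ops(map);
 *
 * which issues the BPF_LINK_CREATE command with link_create.map_fd set.
 */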

void bpf_map_struct_ops_info_fill(struct bpf_map_info *info, struct bpf_map *map)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;

	info->btf_vmlinux_id = btf_obj_id(st_map->btf);
}