Commit 98d7ca37 authored by Alexei Starovoitov's avatar Alexei Starovoitov Committed by Daniel Borkmann

bpf: Track delta between "linked" registers.

Compilers can generate the code
  r1 = r2
  r1 += 0x1
  if r2 < 1000 goto ...
  use knowledge of r2 range in subsequent r1 operations

So remember constant delta between r2 and r1 and update r1 after 'if' condition.

Unfortunately LLVM still uses this pattern for loops with 'can_loop' construct:
for (i = 0; i < 1000 && can_loop; i++)

The "undo" pass was introduced in LLVM
https://reviews.llvm.org/D121937
to prevent this optimization, but it cannot cover all cases.
Instead of fighting middle end optimizer in BPF backend teach the verifier
about this pattern.
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Acked-by: default avatarEduard Zingerman <eddyz87@gmail.com>
Link: https://lore.kernel.org/bpf/20240613013815.953-3-alexei.starovoitov@gmail.com
parent 124e8c2b
...@@ -73,7 +73,10 @@ enum bpf_iter_state { ...@@ -73,7 +73,10 @@ enum bpf_iter_state {
struct bpf_reg_state { struct bpf_reg_state {
/* Ordering of fields matters. See states_equal() */ /* Ordering of fields matters. See states_equal() */
enum bpf_reg_type type; enum bpf_reg_type type;
/* Fixed part of pointer offset, pointer types only */ /*
* Fixed part of pointer offset, pointer types only.
* Or constant delta between "linked" scalars with the same ID.
*/
s32 off; s32 off;
union { union {
/* valid when type == PTR_TO_PACKET */ /* valid when type == PTR_TO_PACKET */
...@@ -167,6 +170,13 @@ struct bpf_reg_state { ...@@ -167,6 +170,13 @@ struct bpf_reg_state {
* Similarly to dynptrs, we use ID to track "belonging" of a reference * Similarly to dynptrs, we use ID to track "belonging" of a reference
* to a specific instance of bpf_iter. * to a specific instance of bpf_iter.
*/ */
/*
* Upper bit of ID is used to remember relationship between "linked"
* registers. Example:
* r1 = r2; both will have r1->id == r2->id == N
* r1 += 10; r1->id == N | BPF_ADD_CONST and r1->off == 10
*/
#define BPF_ADD_CONST (1U << 31)
u32 id; u32 id;
/* PTR_TO_SOCKET and PTR_TO_TCP_SOCK could be a ptr returned /* PTR_TO_SOCKET and PTR_TO_TCP_SOCK could be a ptr returned
* from a pointer-cast helper, bpf_sk_fullsock() and * from a pointer-cast helper, bpf_sk_fullsock() and
......
...@@ -708,7 +708,9 @@ static void print_reg_state(struct bpf_verifier_env *env, ...@@ -708,7 +708,9 @@ static void print_reg_state(struct bpf_verifier_env *env,
verbose(env, "%s", btf_type_name(reg->btf, reg->btf_id)); verbose(env, "%s", btf_type_name(reg->btf, reg->btf_id));
verbose(env, "("); verbose(env, "(");
if (reg->id) if (reg->id)
verbose_a("id=%d", reg->id); verbose_a("id=%d", reg->id & ~BPF_ADD_CONST);
if (reg->id & BPF_ADD_CONST)
verbose(env, "%+d", reg->off);
if (reg->ref_obj_id) if (reg->ref_obj_id)
verbose_a("ref_obj_id=%d", reg->ref_obj_id); verbose_a("ref_obj_id=%d", reg->ref_obj_id);
if (type_is_non_owning_ref(reg->type)) if (type_is_non_owning_ref(reg->type))
......
...@@ -3991,7 +3991,7 @@ static bool idset_contains(struct bpf_idset *s, u32 id) ...@@ -3991,7 +3991,7 @@ static bool idset_contains(struct bpf_idset *s, u32 id)
u32 i; u32 i;
for (i = 0; i < s->count; ++i) for (i = 0; i < s->count; ++i)
if (s->ids[i] == id) if (s->ids[i] == (id & ~BPF_ADD_CONST))
return true; return true;
return false; return false;
...@@ -4001,7 +4001,7 @@ static int idset_push(struct bpf_idset *s, u32 id) ...@@ -4001,7 +4001,7 @@ static int idset_push(struct bpf_idset *s, u32 id)
{ {
if (WARN_ON_ONCE(s->count >= ARRAY_SIZE(s->ids))) if (WARN_ON_ONCE(s->count >= ARRAY_SIZE(s->ids)))
return -EFAULT; return -EFAULT;
s->ids[s->count++] = id; s->ids[s->count++] = id & ~BPF_ADD_CONST;
return 0; return 0;
} }
...@@ -4438,8 +4438,20 @@ static bool __is_pointer_value(bool allow_ptr_leaks, ...@@ -4438,8 +4438,20 @@ static bool __is_pointer_value(bool allow_ptr_leaks,
static void assign_scalar_id_before_mov(struct bpf_verifier_env *env, static void assign_scalar_id_before_mov(struct bpf_verifier_env *env,
struct bpf_reg_state *src_reg) struct bpf_reg_state *src_reg)
{ {
if (src_reg->type == SCALAR_VALUE && !src_reg->id && if (src_reg->type != SCALAR_VALUE)
!tnum_is_const(src_reg->var_off)) return;
if (src_reg->id & BPF_ADD_CONST) {
/*
* The verifier is processing rX = rY insn and
* rY->id has special linked register already.
* Cleared it, since multiple rX += const are not supported.
*/
src_reg->id = 0;
src_reg->off = 0;
}
if (!src_reg->id && !tnum_is_const(src_reg->var_off))
/* Ensure that src_reg has a valid ID that will be copied to /* Ensure that src_reg has a valid ID that will be copied to
* dst_reg and then will be used by find_equal_scalars() to * dst_reg and then will be used by find_equal_scalars() to
* propagate min/max range. * propagate min/max range.
...@@ -14042,6 +14054,7 @@ static int adjust_reg_min_max_vals(struct bpf_verifier_env *env, ...@@ -14042,6 +14054,7 @@ static int adjust_reg_min_max_vals(struct bpf_verifier_env *env,
struct bpf_func_state *state = vstate->frame[vstate->curframe]; struct bpf_func_state *state = vstate->frame[vstate->curframe];
struct bpf_reg_state *regs = state->regs, *dst_reg, *src_reg; struct bpf_reg_state *regs = state->regs, *dst_reg, *src_reg;
struct bpf_reg_state *ptr_reg = NULL, off_reg = {0}; struct bpf_reg_state *ptr_reg = NULL, off_reg = {0};
bool alu32 = (BPF_CLASS(insn->code) != BPF_ALU64);
u8 opcode = BPF_OP(insn->code); u8 opcode = BPF_OP(insn->code);
int err; int err;
...@@ -14064,11 +14077,7 @@ static int adjust_reg_min_max_vals(struct bpf_verifier_env *env, ...@@ -14064,11 +14077,7 @@ static int adjust_reg_min_max_vals(struct bpf_verifier_env *env,
if (dst_reg->type != SCALAR_VALUE) if (dst_reg->type != SCALAR_VALUE)
ptr_reg = dst_reg; ptr_reg = dst_reg;
else
/* Make sure ID is cleared otherwise dst_reg min/max could be
* incorrectly propagated into other registers by find_equal_scalars()
*/
dst_reg->id = 0;
if (BPF_SRC(insn->code) == BPF_X) { if (BPF_SRC(insn->code) == BPF_X) {
src_reg = &regs[insn->src_reg]; src_reg = &regs[insn->src_reg];
if (src_reg->type != SCALAR_VALUE) { if (src_reg->type != SCALAR_VALUE) {
...@@ -14132,7 +14141,43 @@ static int adjust_reg_min_max_vals(struct bpf_verifier_env *env, ...@@ -14132,7 +14141,43 @@ static int adjust_reg_min_max_vals(struct bpf_verifier_env *env,
verbose(env, "verifier internal error: no src_reg\n"); verbose(env, "verifier internal error: no src_reg\n");
return -EINVAL; return -EINVAL;
} }
return adjust_scalar_min_max_vals(env, insn, dst_reg, *src_reg); err = adjust_scalar_min_max_vals(env, insn, dst_reg, *src_reg);
if (err)
return err;
/*
* Compilers can generate the code
* r1 = r2
* r1 += 0x1
* if r2 < 1000 goto ...
* use r1 in memory access
* So remember constant delta between r2 and r1 and update r1 after
* 'if' condition.
*/
if (env->bpf_capable && BPF_OP(insn->code) == BPF_ADD &&
dst_reg->id && is_reg_const(src_reg, alu32)) {
u64 val = reg_const_value(src_reg, alu32);
if ((dst_reg->id & BPF_ADD_CONST) ||
/* prevent overflow in find_equal_scalars() later */
val > (u32)S32_MAX) {
/*
* If the register already went through rX += val
* we cannot accumulate another val into rx->off.
*/
dst_reg->off = 0;
dst_reg->id = 0;
} else {
dst_reg->id |= BPF_ADD_CONST;
dst_reg->off = val;
}
} else {
/*
* Make sure ID is cleared otherwise dst_reg min/max could be
* incorrectly propagated into other registers by find_equal_scalars()
*/
dst_reg->id = 0;
}
return 0;
} }
/* check validity of 32-bit and 64-bit arithmetic operations */ /* check validity of 32-bit and 64-bit arithmetic operations */
...@@ -15104,12 +15149,36 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn, ...@@ -15104,12 +15149,36 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
static void find_equal_scalars(struct bpf_verifier_state *vstate, static void find_equal_scalars(struct bpf_verifier_state *vstate,
struct bpf_reg_state *known_reg) struct bpf_reg_state *known_reg)
{ {
struct bpf_reg_state fake_reg;
struct bpf_func_state *state; struct bpf_func_state *state;
struct bpf_reg_state *reg; struct bpf_reg_state *reg;
bpf_for_each_reg_in_vstate(vstate, state, reg, ({ bpf_for_each_reg_in_vstate(vstate, state, reg, ({
if (reg->type == SCALAR_VALUE && reg->id == known_reg->id) if (reg->type != SCALAR_VALUE || reg == known_reg)
continue;
if ((reg->id & ~BPF_ADD_CONST) != (known_reg->id & ~BPF_ADD_CONST))
continue;
if ((!(reg->id & BPF_ADD_CONST) && !(known_reg->id & BPF_ADD_CONST)) ||
reg->off == known_reg->off) {
copy_register_state(reg, known_reg);
} else {
s32 saved_off = reg->off;
fake_reg.type = SCALAR_VALUE;
__mark_reg_known(&fake_reg, (s32)reg->off - (s32)known_reg->off);
/* reg = known_reg; reg += delta */
copy_register_state(reg, known_reg); copy_register_state(reg, known_reg);
/*
* Must preserve off, id and add_const flag,
* otherwise another find_equal_scalars() will be incorrect.
*/
reg->off = saved_off;
scalar32_min_max_add(reg, &fake_reg);
scalar_min_max_add(reg, &fake_reg);
reg->var_off = tnum_add(reg->var_off, fake_reg.var_off);
}
})); }));
} }
...@@ -16738,6 +16807,10 @@ static bool regsafe(struct bpf_verifier_env *env, struct bpf_reg_state *rold, ...@@ -16738,6 +16807,10 @@ static bool regsafe(struct bpf_verifier_env *env, struct bpf_reg_state *rold,
} }
if (!rold->precise && exact == NOT_EXACT) if (!rold->precise && exact == NOT_EXACT)
return true; return true;
if ((rold->id & BPF_ADD_CONST) != (rcur->id & BPF_ADD_CONST))
return false;
if ((rold->id & BPF_ADD_CONST) && (rold->off != rcur->off))
return false;
/* Why check_ids() for scalar registers? /* Why check_ids() for scalar registers?
* *
* Consider the following BPF code: * Consider the following BPF code:
......
...@@ -39,12 +39,12 @@ ...@@ -39,12 +39,12 @@
.result = VERBOSE_ACCEPT, .result = VERBOSE_ACCEPT,
.errstr = .errstr =
"mark_precise: frame0: last_idx 26 first_idx 20\ "mark_precise: frame0: last_idx 26 first_idx 20\
mark_precise: frame0: regs=r2 stack= before 25\ mark_precise: frame0: regs=r2,r9 stack= before 25\
mark_precise: frame0: regs=r2 stack= before 24\ mark_precise: frame0: regs=r2,r9 stack= before 24\
mark_precise: frame0: regs=r2 stack= before 23\ mark_precise: frame0: regs=r2,r9 stack= before 23\
mark_precise: frame0: regs=r2 stack= before 22\ mark_precise: frame0: regs=r2,r9 stack= before 22\
mark_precise: frame0: regs=r2 stack= before 20\ mark_precise: frame0: regs=r2,r9 stack= before 20\
mark_precise: frame0: parent state regs=r2 stack=:\ mark_precise: frame0: parent state regs=r2,r9 stack=:\
mark_precise: frame0: last_idx 19 first_idx 10\ mark_precise: frame0: last_idx 19 first_idx 10\
mark_precise: frame0: regs=r2,r9 stack= before 19\ mark_precise: frame0: regs=r2,r9 stack= before 19\
mark_precise: frame0: regs=r9 stack= before 18\ mark_precise: frame0: regs=r9 stack= before 18\
...@@ -100,11 +100,11 @@ ...@@ -100,11 +100,11 @@
.errstr = .errstr =
"26: (85) call bpf_probe_read_kernel#113\ "26: (85) call bpf_probe_read_kernel#113\
mark_precise: frame0: last_idx 26 first_idx 22\ mark_precise: frame0: last_idx 26 first_idx 22\
mark_precise: frame0: regs=r2 stack= before 25\ mark_precise: frame0: regs=r2,r9 stack= before 25\
mark_precise: frame0: regs=r2 stack= before 24\ mark_precise: frame0: regs=r2,r9 stack= before 24\
mark_precise: frame0: regs=r2 stack= before 23\ mark_precise: frame0: regs=r2,r9 stack= before 23\
mark_precise: frame0: regs=r2 stack= before 22\ mark_precise: frame0: regs=r2,r9 stack= before 22\
mark_precise: frame0: parent state regs=r2 stack=:\ mark_precise: frame0: parent state regs=r2,r9 stack=:\
mark_precise: frame0: last_idx 20 first_idx 20\ mark_precise: frame0: last_idx 20 first_idx 20\
mark_precise: frame0: regs=r2,r9 stack= before 20\ mark_precise: frame0: regs=r2,r9 stack= before 20\
mark_precise: frame0: parent state regs=r2,r9 stack=:\ mark_precise: frame0: parent state regs=r2,r9 stack=:\
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment