Commit 4d3a453b authored by Ilya Leoshkevich's avatar Ilya Leoshkevich Committed by Daniel Borkmann

s390/bpf: Support BPF_PROBE_MEM32

BPF_PROBE_MEM32 is a new mode for LDX, ST and STX instructions. The JIT
is supposed to add the start address of the kernel arena mapping to the
%dst register, and use a probing variant of the respective memory
access.

Reuse the existing probing infrastructure for that. Put the arena
address into the literal pool, load it into %r1 and use that as an
index register. Do not clear any registers in ex_handler_bpf() for
failing ST and STX instructions.
Signed-off-by: default avatarIlya Leoshkevich <iii@linux.ibm.com>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Link: https://lore.kernel.org/bpf/20240701234304.14336-7-iii@linux.ibm.com
parent a1c04bcc
...@@ -53,6 +53,7 @@ struct bpf_jit { ...@@ -53,6 +53,7 @@ struct bpf_jit {
int excnt; /* Number of exception table entries */ int excnt; /* Number of exception table entries */
int prologue_plt_ret; /* Return address for prologue hotpatch PLT */ int prologue_plt_ret; /* Return address for prologue hotpatch PLT */
int prologue_plt; /* Start of prologue hotpatch PLT */ int prologue_plt; /* Start of prologue hotpatch PLT */
int kern_arena; /* Pool offset of kernel arena address */
}; };
#define SEEN_MEM BIT(0) /* use mem[] for temporary storage */ #define SEEN_MEM BIT(0) /* use mem[] for temporary storage */
...@@ -670,6 +671,7 @@ static void bpf_jit_epilogue(struct bpf_jit *jit, u32 stack_depth) ...@@ -670,6 +671,7 @@ static void bpf_jit_epilogue(struct bpf_jit *jit, u32 stack_depth)
bool ex_handler_bpf(const struct exception_table_entry *x, struct pt_regs *regs) bool ex_handler_bpf(const struct exception_table_entry *x, struct pt_regs *regs)
{ {
regs->psw.addr = extable_fixup(x); regs->psw.addr = extable_fixup(x);
if (x->data != -1)
regs->gprs[x->data] = 0; regs->gprs[x->data] = 0;
return true; return true;
} }
...@@ -681,6 +683,7 @@ struct bpf_jit_probe { ...@@ -681,6 +683,7 @@ struct bpf_jit_probe {
int prg; /* JITed instruction offset */ int prg; /* JITed instruction offset */
int nop_prg; /* JITed nop offset */ int nop_prg; /* JITed nop offset */
int reg; /* Register to clear on exception */ int reg; /* Register to clear on exception */
int arena_reg; /* Register to use for arena addressing */
}; };
static void bpf_jit_probe_init(struct bpf_jit_probe *probe) static void bpf_jit_probe_init(struct bpf_jit_probe *probe)
...@@ -688,6 +691,7 @@ static void bpf_jit_probe_init(struct bpf_jit_probe *probe) ...@@ -688,6 +691,7 @@ static void bpf_jit_probe_init(struct bpf_jit_probe *probe)
probe->prg = -1; probe->prg = -1;
probe->nop_prg = -1; probe->nop_prg = -1;
probe->reg = -1; probe->reg = -1;
probe->arena_reg = REG_0;
} }
/* /*
...@@ -708,13 +712,31 @@ static void bpf_jit_probe_load_pre(struct bpf_jit *jit, struct bpf_insn *insn, ...@@ -708,13 +712,31 @@ static void bpf_jit_probe_load_pre(struct bpf_jit *jit, struct bpf_insn *insn,
struct bpf_jit_probe *probe) struct bpf_jit_probe *probe)
{ {
if (BPF_MODE(insn->code) != BPF_PROBE_MEM && if (BPF_MODE(insn->code) != BPF_PROBE_MEM &&
BPF_MODE(insn->code) != BPF_PROBE_MEMSX) BPF_MODE(insn->code) != BPF_PROBE_MEMSX &&
BPF_MODE(insn->code) != BPF_PROBE_MEM32)
return; return;
if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) {
/* lgrl %r1,kern_arena */
EMIT6_PCREL_RILB(0xc4080000, REG_W1, jit->kern_arena);
probe->arena_reg = REG_W1;
}
probe->prg = jit->prg; probe->prg = jit->prg;
probe->reg = reg2hex[insn->dst_reg]; probe->reg = reg2hex[insn->dst_reg];
} }
static void bpf_jit_probe_store_pre(struct bpf_jit *jit, struct bpf_insn *insn,
struct bpf_jit_probe *probe)
{
if (BPF_MODE(insn->code) != BPF_PROBE_MEM32)
return;
/* lgrl %r1,kern_arena */
EMIT6_PCREL_RILB(0xc4080000, REG_W1, jit->kern_arena);
probe->arena_reg = REG_W1;
probe->prg = jit->prg;
}
static int bpf_jit_probe_post(struct bpf_jit *jit, struct bpf_prog *fp, static int bpf_jit_probe_post(struct bpf_jit *jit, struct bpf_prog *fp,
struct bpf_jit_probe *probe) struct bpf_jit_probe *probe)
{ {
...@@ -1384,51 +1406,99 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, ...@@ -1384,51 +1406,99 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
* BPF_ST(X) * BPF_ST(X)
*/ */
case BPF_STX | BPF_MEM | BPF_B: /* *(u8 *)(dst + off) = src_reg */ case BPF_STX | BPF_MEM | BPF_B: /* *(u8 *)(dst + off) = src_reg */
/* stcy %src,off(%dst) */ case BPF_STX | BPF_PROBE_MEM32 | BPF_B:
EMIT6_DISP_LH(0xe3000000, 0x0072, src_reg, dst_reg, REG_0, off); bpf_jit_probe_store_pre(jit, insn, &probe);
/* stcy %src,off(%dst,%arena) */
EMIT6_DISP_LH(0xe3000000, 0x0072, src_reg, dst_reg,
probe.arena_reg, off);
err = bpf_jit_probe_post(jit, fp, &probe);
if (err < 0)
return err;
jit->seen |= SEEN_MEM; jit->seen |= SEEN_MEM;
break; break;
case BPF_STX | BPF_MEM | BPF_H: /* (u16 *)(dst + off) = src */ case BPF_STX | BPF_MEM | BPF_H: /* (u16 *)(dst + off) = src */
/* sthy %src,off(%dst) */ case BPF_STX | BPF_PROBE_MEM32 | BPF_H:
EMIT6_DISP_LH(0xe3000000, 0x0070, src_reg, dst_reg, REG_0, off); bpf_jit_probe_store_pre(jit, insn, &probe);
/* sthy %src,off(%dst,%arena) */
EMIT6_DISP_LH(0xe3000000, 0x0070, src_reg, dst_reg,
probe.arena_reg, off);
err = bpf_jit_probe_post(jit, fp, &probe);
if (err < 0)
return err;
jit->seen |= SEEN_MEM; jit->seen |= SEEN_MEM;
break; break;
case BPF_STX | BPF_MEM | BPF_W: /* *(u32 *)(dst + off) = src */ case BPF_STX | BPF_MEM | BPF_W: /* *(u32 *)(dst + off) = src */
/* sty %src,off(%dst) */ case BPF_STX | BPF_PROBE_MEM32 | BPF_W:
EMIT6_DISP_LH(0xe3000000, 0x0050, src_reg, dst_reg, REG_0, off); bpf_jit_probe_store_pre(jit, insn, &probe);
/* sty %src,off(%dst,%arena) */
EMIT6_DISP_LH(0xe3000000, 0x0050, src_reg, dst_reg,
probe.arena_reg, off);
err = bpf_jit_probe_post(jit, fp, &probe);
if (err < 0)
return err;
jit->seen |= SEEN_MEM; jit->seen |= SEEN_MEM;
break; break;
case BPF_STX | BPF_MEM | BPF_DW: /* (u64 *)(dst + off) = src */ case BPF_STX | BPF_MEM | BPF_DW: /* (u64 *)(dst + off) = src */
/* stg %src,off(%dst) */ case BPF_STX | BPF_PROBE_MEM32 | BPF_DW:
EMIT6_DISP_LH(0xe3000000, 0x0024, src_reg, dst_reg, REG_0, off); bpf_jit_probe_store_pre(jit, insn, &probe);
/* stg %src,off(%dst,%arena) */
EMIT6_DISP_LH(0xe3000000, 0x0024, src_reg, dst_reg,
probe.arena_reg, off);
err = bpf_jit_probe_post(jit, fp, &probe);
if (err < 0)
return err;
jit->seen |= SEEN_MEM; jit->seen |= SEEN_MEM;
break; break;
case BPF_ST | BPF_MEM | BPF_B: /* *(u8 *)(dst + off) = imm */ case BPF_ST | BPF_MEM | BPF_B: /* *(u8 *)(dst + off) = imm */
case BPF_ST | BPF_PROBE_MEM32 | BPF_B:
/* lhi %w0,imm */ /* lhi %w0,imm */
EMIT4_IMM(0xa7080000, REG_W0, (u8) imm); EMIT4_IMM(0xa7080000, REG_W0, (u8) imm);
/* stcy %w0,off(dst) */ bpf_jit_probe_store_pre(jit, insn, &probe);
EMIT6_DISP_LH(0xe3000000, 0x0072, REG_W0, dst_reg, REG_0, off); /* stcy %w0,off(%dst,%arena) */
EMIT6_DISP_LH(0xe3000000, 0x0072, REG_W0, dst_reg,
probe.arena_reg, off);
err = bpf_jit_probe_post(jit, fp, &probe);
if (err < 0)
return err;
jit->seen |= SEEN_MEM; jit->seen |= SEEN_MEM;
break; break;
case BPF_ST | BPF_MEM | BPF_H: /* (u16 *)(dst + off) = imm */ case BPF_ST | BPF_MEM | BPF_H: /* (u16 *)(dst + off) = imm */
case BPF_ST | BPF_PROBE_MEM32 | BPF_H:
/* lhi %w0,imm */ /* lhi %w0,imm */
EMIT4_IMM(0xa7080000, REG_W0, (u16) imm); EMIT4_IMM(0xa7080000, REG_W0, (u16) imm);
/* sthy %w0,off(dst) */ bpf_jit_probe_store_pre(jit, insn, &probe);
EMIT6_DISP_LH(0xe3000000, 0x0070, REG_W0, dst_reg, REG_0, off); /* sthy %w0,off(%dst,%arena) */
EMIT6_DISP_LH(0xe3000000, 0x0070, REG_W0, dst_reg,
probe.arena_reg, off);
err = bpf_jit_probe_post(jit, fp, &probe);
if (err < 0)
return err;
jit->seen |= SEEN_MEM; jit->seen |= SEEN_MEM;
break; break;
case BPF_ST | BPF_MEM | BPF_W: /* *(u32 *)(dst + off) = imm */ case BPF_ST | BPF_MEM | BPF_W: /* *(u32 *)(dst + off) = imm */
case BPF_ST | BPF_PROBE_MEM32 | BPF_W:
/* llilf %w0,imm */ /* llilf %w0,imm */
EMIT6_IMM(0xc00f0000, REG_W0, (u32) imm); EMIT6_IMM(0xc00f0000, REG_W0, (u32) imm);
/* sty %w0,off(%dst) */ bpf_jit_probe_store_pre(jit, insn, &probe);
EMIT6_DISP_LH(0xe3000000, 0x0050, REG_W0, dst_reg, REG_0, off); /* sty %w0,off(%dst,%arena) */
EMIT6_DISP_LH(0xe3000000, 0x0050, REG_W0, dst_reg,
probe.arena_reg, off);
err = bpf_jit_probe_post(jit, fp, &probe);
if (err < 0)
return err;
jit->seen |= SEEN_MEM; jit->seen |= SEEN_MEM;
break; break;
case BPF_ST | BPF_MEM | BPF_DW: /* *(u64 *)(dst + off) = imm */ case BPF_ST | BPF_MEM | BPF_DW: /* *(u64 *)(dst + off) = imm */
case BPF_ST | BPF_PROBE_MEM32 | BPF_DW:
/* lgfi %w0,imm */ /* lgfi %w0,imm */
EMIT6_IMM(0xc0010000, REG_W0, imm); EMIT6_IMM(0xc0010000, REG_W0, imm);
/* stg %w0,off(%dst) */ bpf_jit_probe_store_pre(jit, insn, &probe);
EMIT6_DISP_LH(0xe3000000, 0x0024, REG_W0, dst_reg, REG_0, off); /* stg %w0,off(%dst,%arena) */
EMIT6_DISP_LH(0xe3000000, 0x0024, REG_W0, dst_reg,
probe.arena_reg, off);
err = bpf_jit_probe_post(jit, fp, &probe);
if (err < 0)
return err;
jit->seen |= SEEN_MEM; jit->seen |= SEEN_MEM;
break; break;
/* /*
...@@ -1506,9 +1576,11 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, ...@@ -1506,9 +1576,11 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
*/ */
case BPF_LDX | BPF_MEM | BPF_B: /* dst = *(u8 *)(ul) (src + off) */ case BPF_LDX | BPF_MEM | BPF_B: /* dst = *(u8 *)(ul) (src + off) */
case BPF_LDX | BPF_PROBE_MEM | BPF_B: case BPF_LDX | BPF_PROBE_MEM | BPF_B:
case BPF_LDX | BPF_PROBE_MEM32 | BPF_B:
bpf_jit_probe_load_pre(jit, insn, &probe); bpf_jit_probe_load_pre(jit, insn, &probe);
/* llgc %dst,0(off,%src) */ /* llgc %dst,off(%src,%arena) */
EMIT6_DISP_LH(0xe3000000, 0x0090, dst_reg, src_reg, REG_0, off); EMIT6_DISP_LH(0xe3000000, 0x0090, dst_reg, src_reg,
probe.arena_reg, off);
err = bpf_jit_probe_post(jit, fp, &probe); err = bpf_jit_probe_post(jit, fp, &probe);
if (err < 0) if (err < 0)
return err; return err;
...@@ -1519,7 +1591,7 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, ...@@ -1519,7 +1591,7 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
case BPF_LDX | BPF_MEMSX | BPF_B: /* dst = *(s8 *)(ul) (src + off) */ case BPF_LDX | BPF_MEMSX | BPF_B: /* dst = *(s8 *)(ul) (src + off) */
case BPF_LDX | BPF_PROBE_MEMSX | BPF_B: case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
bpf_jit_probe_load_pre(jit, insn, &probe); bpf_jit_probe_load_pre(jit, insn, &probe);
/* lgb %dst,0(off,%src) */ /* lgb %dst,off(%src) */
EMIT6_DISP_LH(0xe3000000, 0x0077, dst_reg, src_reg, REG_0, off); EMIT6_DISP_LH(0xe3000000, 0x0077, dst_reg, src_reg, REG_0, off);
err = bpf_jit_probe_post(jit, fp, &probe); err = bpf_jit_probe_post(jit, fp, &probe);
if (err < 0) if (err < 0)
...@@ -1528,9 +1600,11 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, ...@@ -1528,9 +1600,11 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
break; break;
case BPF_LDX | BPF_MEM | BPF_H: /* dst = *(u16 *)(ul) (src + off) */ case BPF_LDX | BPF_MEM | BPF_H: /* dst = *(u16 *)(ul) (src + off) */
case BPF_LDX | BPF_PROBE_MEM | BPF_H: case BPF_LDX | BPF_PROBE_MEM | BPF_H:
case BPF_LDX | BPF_PROBE_MEM32 | BPF_H:
bpf_jit_probe_load_pre(jit, insn, &probe); bpf_jit_probe_load_pre(jit, insn, &probe);
/* llgh %dst,0(off,%src) */ /* llgh %dst,off(%src,%arena) */
EMIT6_DISP_LH(0xe3000000, 0x0091, dst_reg, src_reg, REG_0, off); EMIT6_DISP_LH(0xe3000000, 0x0091, dst_reg, src_reg,
probe.arena_reg, off);
err = bpf_jit_probe_post(jit, fp, &probe); err = bpf_jit_probe_post(jit, fp, &probe);
if (err < 0) if (err < 0)
return err; return err;
...@@ -1541,7 +1615,7 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, ...@@ -1541,7 +1615,7 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
case BPF_LDX | BPF_MEMSX | BPF_H: /* dst = *(s16 *)(ul) (src + off) */ case BPF_LDX | BPF_MEMSX | BPF_H: /* dst = *(s16 *)(ul) (src + off) */
case BPF_LDX | BPF_PROBE_MEMSX | BPF_H: case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
bpf_jit_probe_load_pre(jit, insn, &probe); bpf_jit_probe_load_pre(jit, insn, &probe);
/* lgh %dst,0(off,%src) */ /* lgh %dst,off(%src) */
EMIT6_DISP_LH(0xe3000000, 0x0015, dst_reg, src_reg, REG_0, off); EMIT6_DISP_LH(0xe3000000, 0x0015, dst_reg, src_reg, REG_0, off);
err = bpf_jit_probe_post(jit, fp, &probe); err = bpf_jit_probe_post(jit, fp, &probe);
if (err < 0) if (err < 0)
...@@ -1550,10 +1624,12 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, ...@@ -1550,10 +1624,12 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
break; break;
case BPF_LDX | BPF_MEM | BPF_W: /* dst = *(u32 *)(ul) (src + off) */ case BPF_LDX | BPF_MEM | BPF_W: /* dst = *(u32 *)(ul) (src + off) */
case BPF_LDX | BPF_PROBE_MEM | BPF_W: case BPF_LDX | BPF_PROBE_MEM | BPF_W:
case BPF_LDX | BPF_PROBE_MEM32 | BPF_W:
bpf_jit_probe_load_pre(jit, insn, &probe); bpf_jit_probe_load_pre(jit, insn, &probe);
/* llgf %dst,off(%src) */ /* llgf %dst,off(%src) */
jit->seen |= SEEN_MEM; jit->seen |= SEEN_MEM;
EMIT6_DISP_LH(0xe3000000, 0x0016, dst_reg, src_reg, REG_0, off); EMIT6_DISP_LH(0xe3000000, 0x0016, dst_reg, src_reg,
probe.arena_reg, off);
err = bpf_jit_probe_post(jit, fp, &probe); err = bpf_jit_probe_post(jit, fp, &probe);
if (err < 0) if (err < 0)
return err; return err;
...@@ -1572,10 +1648,12 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, ...@@ -1572,10 +1648,12 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
break; break;
case BPF_LDX | BPF_MEM | BPF_DW: /* dst = *(u64 *)(ul) (src + off) */ case BPF_LDX | BPF_MEM | BPF_DW: /* dst = *(u64 *)(ul) (src + off) */
case BPF_LDX | BPF_PROBE_MEM | BPF_DW: case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW:
bpf_jit_probe_load_pre(jit, insn, &probe); bpf_jit_probe_load_pre(jit, insn, &probe);
/* lg %dst,0(off,%src) */ /* lg %dst,off(%src,%arena) */
jit->seen |= SEEN_MEM; jit->seen |= SEEN_MEM;
EMIT6_DISP_LH(0xe3000000, 0x0004, dst_reg, src_reg, REG_0, off); EMIT6_DISP_LH(0xe3000000, 0x0004, dst_reg, src_reg,
probe.arena_reg, off);
err = bpf_jit_probe_post(jit, fp, &probe); err = bpf_jit_probe_post(jit, fp, &probe);
if (err < 0) if (err < 0)
return err; return err;
...@@ -1988,12 +2066,17 @@ static int bpf_jit_prog(struct bpf_jit *jit, struct bpf_prog *fp, ...@@ -1988,12 +2066,17 @@ static int bpf_jit_prog(struct bpf_jit *jit, struct bpf_prog *fp,
bool extra_pass, u32 stack_depth) bool extra_pass, u32 stack_depth)
{ {
int i, insn_count, lit32_size, lit64_size; int i, insn_count, lit32_size, lit64_size;
u64 kern_arena;
jit->lit32 = jit->lit32_start; jit->lit32 = jit->lit32_start;
jit->lit64 = jit->lit64_start; jit->lit64 = jit->lit64_start;
jit->prg = 0; jit->prg = 0;
jit->excnt = 0; jit->excnt = 0;
kern_arena = bpf_arena_get_kern_vm_start(fp->aux->arena);
if (kern_arena)
jit->kern_arena = _EMIT_CONST_U64(kern_arena);
bpf_jit_prologue(jit, fp, stack_depth); bpf_jit_prologue(jit, fp, stack_depth);
if (bpf_set_addr(jit, 0) < 0) if (bpf_set_addr(jit, 0) < 0)
return -1; return -1;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment