Commit d1791f0b authored by Marek Vavruša, committed by yonghong-song

lua: codegen tests, tracking inverse form of null guards, IPv6 header builtins (#1991)

* lua: eliding repetitive NULL checks, fixes for read types, luacheck

This fixes verifier failures on some kernel versions that track whether
a possibly-NULL pointer has already been checked (it must not be checked twice).

It also fixes accessing memory from maps with composite keys.
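For example, a map keyed by a struct can now be read and written through a
dissected value. A minimal sketch using the `bpf.map()` signature from this
tree; the key layout itself is illustrative:

```
local ffi = require('ffi')
local bpf = require('bpf')
-- Composite key: two fields packed into a single map key type
local key_t = ffi.typeof('struct { uint32_t saddr; uint32_t daddr; }')
local flows = bpf.map('hash', 256, key_t, ffi.typeof('uint64_t'))
```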

It also fixes a case where a temporary variable is used in a condition,
and the body of the conditional contains only an assignment of an
immediate value to a variable that existed before the condition, e.g.:

```
local x = 1 -- Assign constant to 'x'; it is materialized here
if skb.len > 0 then -- New BB; the condition isn't constant, so it can't be folded
    x = 2 -- Assign constant; this failed to materialize here
end -- End of BB
-- Bug: the value of x read here was always 1
```

`bpf.socket()` now supports an ljsyscall socket as its first parameter, and
compatibility with newer ljsyscall is fixed (a type name has changed).
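A usage sketch (the interface-name form is unchanged; passing an ljsyscall
socket object is what this change enables, and the filter body is illustrative):

```
local S = require('syscall')
local bpf = require('bpf')
local sock = S.socket('packet', 'raw') -- ljsyscall socket object
bpf.socket(sock, function (skb)        -- accepted directly as first parameter
    return skb.len > 0
end)
```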

* reverse BPF_LD ntohl() semantics for <= 32bit loads as well

Loads using the traditional BPF_LD instructions always perform ntohl():

https://www.kernel.org/doc/Documentation/networking/filter.txt

They are, however, still needed to support reads that don't go through `skb`, e.g. NET_OFF.
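On little-endian hosts the compiler now compensates by byte-swapping the value
back to host order right after the load; this is the fragment emitted in
`LD_ABS`/`LD_IND` (see the diff below):

```
if w > 1 and ffi.abi('le') then -- LD_ABS has htonl() semantics, reverse
    emit(BPF.ALU + BPF.END + BPF.TO_BE, dst_reg, 0, 0, w * 8)
end
```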

* proto: add builtins for traversing IPv6 header and extension headers

The IPv6 header has a fixed size (40 bytes), but it may be followed by either
an extension header or the transport protocol, so the caller must check the
`next_header` field to determine which dissector to use for the next step, e.g.:

```
if ip6.next_header == 44 then
   ip6 = ip6.ip6_opt -- Skip fragment header
end
if ip6.next_header == c.ip.proto_tcp then
   local tcp = ip6.tcp -- Finally, TCP
end
```

* reverse ntohl() for indirect BPF_LD as well

* lua: started codegen tests, direct skb support, bugfixes

This starts adding compiler correctness tests for basic expressions,
variable source tracking, and control flow.

Added:
* direct skb->data access (in addition to BPF_LDABS, BPF_LDIND)
* loads and stores from skb->data and map value backed memory
* unified variable source tracking (ptr_to_ctx, ptr_to_map[_or_null], ptr_to_skb, ptr_to_pkt, ptr_to_stack)
* BPF constants introduced between 4.10-4.15
* bpf.dump_string() to dump generated assembly (text) to a string (usage sketched below)

Fixes:
* pointer nil check tracking
* dissectors for map value backed memory
* ljsyscall extensions when the version is too old
* KPRI nil variables used in conditions
* wrongly elided constant materialization on condition-less jumps
* loads/stores from stack memory using variable offset
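
A sketch of the new dump helper (assuming the module's callable compile entry
point; the one-line filter is illustrative):

```
local bpf = require('bpf')
local prog = bpf(function (skb)
    return skb.len
end)
print(bpf.dump_string(prog)) -- returns the listing as a string (bpf.dump prints it)
```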

* lua: track inverse null guards (if x ~= nil)

The verifier prohibits pointer comparisons other than the first NULL check,
so both guard forms must be tracked; otherwise the generated code would check
for NULL twice and the verifier would reject it.
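An illustrative fragment (`map` and the value's `counter` field are placeholders):

```
local val = map[1]                -- value pointer tracked as ptr_to_map_value_or_null
if val ~= nil then                -- inverse guard, now recognized like 'if val == nil'
    val.counter = val.counter + 1 -- no second NULL check is emitted here
end
```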

* lua: support cdata and constants larger than i32, fix shadowed variables

This adds support for numeric cdata constants (either in the program or
retrieved with GGET/UGET). Values larger than i32 are coerced to i64.

This also fixes shadowing of variables, and materialization of the result
of a variable copy at the end of a basic block.
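
An illustrative fragment (`counters` stands for a hypothetical array map with
u64 values):

```
local threshold = 0x1ffffffffULL -- numeric cdata constant wider than i32
local prog = function ()
    if counters[1] > threshold then -- compiles to a 64-bit immediate (LDDW)
        counters[1] = 0
    end
end
```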
parent 9252a8d0
-- Configuration for unit tests
-- See: http://olivinelabs.com/busted/
return {
    default = {
        lpath = "./?.lua;./?/init.lua",
        helper = "./bpf/spec/helper.lua",
        ["auto-insulate"] = false,
    }
}
std = 'luajit'
new_read_globals = {
    'assert',
    'describe',
    'it',
}
new_globals = {
    'math',
}
-- Luacheck < 0.18 doesn't support new_read_globals
for _, v in ipairs(new_read_globals) do
    table.insert(new_globals, v)
end
-- Ignore some pedantic checks
ignore = {
    '4.1/err', -- Shadowing err
    '4.1/.', -- Shadowing one letter variables
}
@@ -55,6 +55,7 @@ local const_expr = {
     JGE = function (a, b) return a >= b end,
     JGT = function (a, b) return a > b end,
 }
 local const_width = {
     [1] = BPF.B, [2] = BPF.H, [4] = BPF.W, [8] = BPF.DW,
 }
@@ -65,19 +66,16 @@ local builtins_strict = {
     [print] = true,
 }
--- Return struct member size/type (requires LuaJIT 2.1+)
--- I am ashamed that there's no easier way around it.
-local function sizeofattr(ct, name)
-    if not ffi.typeinfo then error('LuaJIT 2.1+ is required for ffi.typeinfo') end
-    local cinfo = ffi.typeinfo(ct)
-    while true do
-        cinfo = ffi.typeinfo(cinfo.sib)
-        if not cinfo then return end
-        if cinfo.name == name then break end
-    end
-    local size = math.max(1, ffi.typeinfo(cinfo.sib or ct).size - cinfo.size)
-    -- Guess type name
-    return size, builtins.width_type(size)
-end
+-- Deep copy a table
+local function table_copy(t)
+    local copy = {}
+    for n,v in pairs(t) do
+        if type(v) == 'table' then
+            v = table_copy(v)
+        end
+        copy[n] = v
+    end
+    return copy
+end
 -- Return true if the constant part is a proxy
@@ -117,6 +115,7 @@ end
 local function reg_spill(var)
     local vinfo = V[var]
+    assert(vinfo.reg, 'attempt to spill VAR that doesn\'t have an allocated register')
     vinfo.spill = (var + 1) * ffi.sizeof('uint64_t') -- Index by (variable number) * (register width)
     emit(BPF.MEM + BPF.STX + BPF.DW, 10, vinfo.reg, -vinfo.spill, 0)
     vinfo.reg = nil
@@ -124,6 +123,7 @@ end
 local function reg_fill(var, reg)
     local vinfo = V[var]
+    assert(reg, 'attempt to fill variable to register but not register is allocated')
     assert(vinfo.spill, 'attempt to fill register with a VAR that isn\'t spilled')
     emit(BPF.MEM + BPF.LDX + BPF.DW, reg, 10, -vinfo.spill, 0)
     vinfo.reg = reg
@@ -182,8 +182,14 @@ local function vset(var, reg, const, vtype)
         end
     end
     -- Get precise type for CDATA or attempt to narrow numeric constant
-    if not vtype and type(const) == 'cdata' then vtype = ffi.typeof(const) end
+    if not vtype and type(const) == 'cdata' then
+        vtype = ffi.typeof(const)
+    end
     V[var] = {reg=reg, const=const, type=vtype}
+    -- Track variable source
+    if V[var].const and type(const) == 'table' then
+        V[var].source = V[var].const.source
+    end
 end
 -- Materialize (or register) a variable in a register
@@ -197,7 +203,7 @@ local function vreg(var, reg, reserve, vtype)
     -- Materialize variable shadow copy
     local src = vinfo
     while src.shadow do src = V[src.shadow] end
-    if reserve then
+    if reserve then -- luacheck: ignore
         -- No load to register occurs
     elseif src.reg then
         emit(BPF.ALU64 + BPF.MOV + BPF.X, reg, src.reg, 0, 0)
@@ -237,15 +243,20 @@ local function vcopy(dst, src)
 end
 -- Dereference variable of pointer type
-local function vderef(dst_reg, src_reg, vtype)
+local function vderef(dst_reg, src_reg, vinfo)
     -- Dereference map pointers for primitive types
     -- BPF doesn't allow pointer arithmetics, so use the entry value
+    assert(type(vinfo.const) == 'table' and vinfo.const.__dissector, 'cannot dereference a non-pointer variable')
+    local vtype = vinfo.const.__dissector
     local w = ffi.sizeof(vtype)
     assert(const_width[w], 'NYI: sizeof('..tostring(vtype)..') not 1/2/4/8 bytes')
     if dst_reg ~= src_reg then
         emit(BPF.ALU64 + BPF.MOV + BPF.X, dst_reg, src_reg, 0, 0) -- dst = src
     end
+    -- Optimize the NULL check away if provably not NULL
+    if not vinfo.source or vinfo.source:find('_or_null', 1, true) then
         emit(BPF.JMP + BPF.JEQ + BPF.K, src_reg, 0, 1, 0) -- if (src != NULL)
+    end
     emit(BPF.MEM + BPF.LDX + const_width[w], dst_reg, src_reg, 0, 0) -- dst = *src;
 end
@@ -280,12 +291,44 @@ local function valloc(size, blank)
     return stack_top
 end
+-- Turn variable into scalar in register (or constant)
+local function vscalar(a, w)
+    assert(const_width[w], 'sizeof(scalar variable) must be 1/2/4/8')
+    local src_reg
+    -- If source is a pointer, we must dereference it first
+    if cdef.isptr(V[a].type) then
+        src_reg = vreg(a)
+        local tmp_reg = reg_alloc(stackslots, 1) -- Clone variable in tmp register
+        emit(BPF.ALU64 + BPF.MOV + BPF.X, tmp_reg, src_reg, 0, 0)
+        vderef(tmp_reg, tmp_reg, V[a])
+        src_reg = tmp_reg -- Materialize and dereference it
+    -- Source is a value on stack, we must load it first
+    elseif type(V[a].const) == 'table' and V[a].const.__base > 0 then
+        src_reg = vreg(a)
+        emit(BPF.MEM + BPF.LDX + const_width[w], src_reg, 10, -V[a].const.__base, 0)
+        V[a].type = V[a].const.__dissector
+        V[a].const = nil -- Value is dereferenced
+    -- If source is an imm32 number, avoid register load
+    elseif type(V[a].const) == 'number' and w < 8 then
+        return nil, V[a].const
+    -- Load variable from any other source
+    else
+        src_reg = vreg(a)
+    end
+    return src_reg, nil
+end
+
 -- Emit compensation code at the end of basic block to unify variable set layout on all block exits
 -- 1. we need to free registers by spilling
 -- 2. fill registers to match other exits from this BB
 local function bb_end(Vcomp)
     for i,v in pairs(V) do
         if Vcomp[i] and Vcomp[i].spill and not v.spill then
+            -- Materialize constant or shadowing variable to be able to spill
+            if not v.reg and (v.shadow or cdef.isimmconst(v)) then
+                vreg(i)
+            end
             reg_spill(i)
         end
     end
@@ -293,6 +336,10 @@ local function bb_end(Vcomp)
         if Vcomp[i] and Vcomp[i].reg and not v.reg then
             vreg(i, Vcomp[i].reg)
         end
+        -- Compensate variable metadata change
+        if Vcomp[i] and Vcomp[i].source then
+            V[i].source = Vcomp[i].source
+        end
     end
 end
@@ -334,16 +381,15 @@ local function CMP_REG(a, b, op)
         -- compiler to replace it's absolute offset to LJ bytecode insn with a relative
         -- offset in BPF program code, verifier will accept only programs with valid JMP targets
         local a_reg, b_reg = vreg(a), vreg(b)
-        -- Migrate operands from R0-5 as it will be spilled in compensation code when JMP out of BB
-        if a_reg == 0 then a_reg = vreg(a, 7) end
         emit(BPF.JMP + BPF[op] + BPF.X, a_reg, b_reg, 0xffff, 0)
         code.seen_cmp = code.pc-1
     end
 end
 local function CMP_IMM(a, b, op)
-    if V[a].const and not is_proxy(V[a].const) then -- Fold compile-time expressions
-        code.seen_cmp = const_expr[op](V[a].const, b) and ALWAYS or NEVER
+    local c = V[a].const
+    if c and not is_proxy(c) then -- Fold compile-time expressions
+        code.seen_cmp = const_expr[op](c, b) and ALWAYS or NEVER
     else
         -- Convert imm32 to number
         if type(b) == 'string' then
@@ -362,24 +408,38 @@ local function CMP_IMM(a, b, op)
         -- compiler to replace it's absolute offset to LJ bytecode insn with a relative
         -- offset in BPF program code, verifier will accept only programs with valid JMP targets
         local reg = vreg(a)
-        -- Migrate operands from R0-5 as it will be spilled in compensation code when JMP out of BB
-        if reg == 0 then reg = vreg(a, 7) end
         emit(BPF.JMP + BPF[op] + BPF.K, reg, 0, 0xffff, b)
         code.seen_cmp = code.pc-1
+        -- Remember NULL pointer checks as BPF prohibits pointer comparisons
+        -- and repeated checks wouldn't pass the verifier, only comparisons
+        -- against constants are checked.
+        if op == 'JEQ' and tonumber(b) == 0 and V[a].source then
+            local pos = V[a].source:find('_or_null', 1, true)
+            if pos then
+                code.seen_null_guard = a
+            end
+        -- Inverse NULL pointer check (if a ~= nil)
+        elseif op == 'JNE' and tonumber(b) == 0 and V[a].source then
+            local pos = V[a].source:find('_or_null', 1, true)
+            if pos then
+                code.seen_null_guard = a
+                code.seen_null_guard_inverse = true
+            end
+        end
     end
 end
 local function ALU_IMM(dst, a, b, op)
     -- Fold compile-time expressions
     if V[a].const and not is_proxy(V[a].const) then
-        assert(type(V[a].const) == 'number', 'VAR '..a..' must be numeric')
+        assert(cdef.isimmconst(V[a]), 'VAR '..a..' must be numeric')
         vset(dst, nil, const_expr[op](V[a].const, b))
     -- Now we need to materialize dissected value at DST, and add it
     else
         vcopy(dst, a)
         local dst_reg = vreg(dst)
         if cdef.isptr(V[a].type) then
-            vderef(dst_reg, dst_reg, V[a].const.__dissector)
+            vderef(dst_reg, dst_reg, V[a])
             V[dst].type = V[a].const.__dissector
         else
             V[dst].type = V[a].type
@@ -391,8 +451,8 @@ end
 local function ALU_REG(dst, a, b, op)
     -- Fold compile-time expressions
     if V[a].const and not (is_proxy(V[a].const) or is_proxy(V[b].const)) then
-        assert(type(V[a].const) == 'number', 'VAR '..a..' must be numeric')
-        assert(type(V[b].const) == 'number', 'VAR '..b..' must be numeric')
+        assert(cdef.isimmconst(V[a]), 'VAR '..a..' must be numeric')
+        assert(cdef.isimmconst(V[b]), 'VAR '..b..' must be numeric')
         if type(op) == 'string' then op = const_expr[op] end
         vcopy(dst, a)
         V[dst].const = op(V[a].const, V[b].const)
@@ -402,13 +462,13 @@ local function ALU_REG(dst, a, b, op)
         -- We have to allocate a temporary register for dereferencing to preserve
         -- pointer in source variable that MUST NOT be altered
         reg_alloc(stackslots, 2)
-        vderef(2, src_reg, V[b].const.__dissector)
+        vderef(2, src_reg, V[b])
         src_reg = 2
     end
     vcopy(dst, a) -- DST may alias B, so copy must occur after we materialize B
     local dst_reg = vreg(dst)
     if cdef.isptr(V[a].type) then
-        vderef(dst_reg, dst_reg, V[a].const.__dissector)
+        vderef(dst_reg, dst_reg, V[a])
         V[dst].type = V[a].const.__dissector
     end
     emit(BPF.ALU64 + BPF[op] + BPF.X, dst_reg, src_reg, 0, 0)
@@ -424,10 +484,14 @@ local function ALU_IMM_NV(dst, a, b, op)
     ALU_REG(dst, stackslots+1, b, op)
 end
-local function LD_ABS(dst, off, w)
+local function LD_ABS(dst, w, off)
+    assert(off, 'LD_ABS called without offset')
     if w < 8 then
         local dst_reg = vreg(dst, 0, true, builtins.width_type(w)) -- Reserve R0
         emit(BPF.LD + BPF.ABS + const_width[w], dst_reg, 0, 0, off)
+        if w > 1 and ffi.abi('le') then -- LD_ABS has htonl() semantics, reverse
+            emit(BPF.ALU + BPF.END + BPF.TO_BE, dst_reg, 0, 0, w * 8)
+        end
     elseif w == 8 then
         -- LD_ABS|IND prohibits DW, we need to do two W loads and combine them
         local tmp_reg = vreg(stackslots, 0, true, builtins.width_type(w)) -- Reserve R0
@@ -452,14 +516,15 @@ local function LD_IND(dst, src, w, off)
     local src_reg = vreg(src) -- Must materialize first in case dst == src
     local dst_reg = vreg(dst, 0, true, builtins.width_type(w)) -- Reserve R0
     emit(BPF.LD + BPF.IND + const_width[w], dst_reg, src_reg, 0, off or 0)
+    if w > 1 and ffi.abi('le') then -- LD_ABS has htonl() semantics, reverse
+        emit(BPF.ALU + BPF.END + BPF.TO_BE, dst_reg, 0, 0, w * 8)
+    end
 end
-local function LD_FIELD(a, d, w, imm)
-    if imm then
-        LD_ABS(a, imm, w)
-    else
-        LD_IND(a, d, w)
-    end
+local function LD_MEM(dst, src, w, off)
+    local src_reg = vreg(src) -- Must materialize first in case dst == src
+    local dst_reg = vreg(dst, nil, true, builtins.width_type(w)) -- Reserve R0
+    emit(BPF.MEM + BPF.LDX + const_width[w], dst_reg, src_reg, off or 0, 0)
 end
 -- @note: This is specific now as it expects registers reserved
@@ -473,33 +538,52 @@ local function LD_IMM_X(dst_reg, src_type, imm, w)
     end
 end
+local function BUILTIN(func, ...)
+    local builtin_export = {
+        -- Compiler primitives (work with variable slots, emit instructions)
+        V=V, vreg=vreg, vset=vset, vcopy=vcopy, vderef=vderef, valloc=valloc, emit=emit,
+        reg_alloc=reg_alloc, reg_spill=reg_spill, tmpvar=stackslots, const_width=const_width,
+        -- Extensions and helpers (use with care)
+        LD_IMM_X = LD_IMM_X,
+    }
+    func(builtin_export, ...)
+end
+
 local function LOAD(dst, src, off, vtype)
     local base = V[src].const
-    assert(base.__dissector, "NYI: load() on variable that doesn't have dissector")
+    assert(base and base.__dissector, 'NYI: load() on variable that doesn\'t have dissector')
+    assert(V[src].source, 'NYI: load() on variable with unknown source')
     -- Cast to different type if requested
     vtype = vtype or base.__dissector
     local w = ffi.sizeof(vtype)
-    assert(w <= 4, 'NYI: load() supports 1/2/4 bytes at a time only')
+    assert(const_width[w], 'NYI: load() supports 1/2/4/8 bytes at a time only, wanted ' .. tostring(w))
+    -- Packet access with a dissector (use BPF_LD)
+    if V[src].source:find('ptr_to_pkt', 1, true) then
         if base.off then -- Absolute address to payload
-            LD_ABS(dst, off + base.off, w)
+            LD_ABS(dst, w, off + base.off)
         else -- Indirect address to payload
             LD_IND(dst, src, w, off)
         end
+    -- Direct access to first argument (skb fields, pt regs, ...)
+    elseif V[src].source:find('ptr_to_ctx', 1, true) then
+        LD_MEM(dst, src, w, off)
+    -- Direct skb access with a dissector (use BPF_MEM)
+    elseif V[src].source:find('ptr_to_skb', 1, true) then
+        LD_MEM(dst, src, w, off)
+    -- Pointer to map-backed memory (use BPF_MEM)
+    elseif V[src].source:find('ptr_to_map_value', 1, true) then
+        LD_MEM(dst, src, w, off)
+    -- Indirect read using probe (uprobe or kprobe, uses helper)
+    elseif V[src].source:find('ptr_to_probe', 1, true) then
+        BUILTIN(builtins[builtins.probe_read], nil, dst, src, vtype, off)
+        V[dst].source = V[src].source -- Builtin handles everything
+    else
+        error('NYI: load() on variable from ' .. V[src].source)
+    end
     V[dst].type = vtype
     V[dst].const = nil -- Dissected value is not constant anymore
 end
-local function BUILTIN(func, ...)
-    local builtin_export = {
-        -- Compiler primitives (work with variable slots, emit instructions)
-        V=V, vreg=vreg, vset=vset, vcopy=vcopy, vderef=vderef, valloc=valloc, emit=emit,
-        reg_alloc=reg_alloc, reg_spill=reg_spill, tmpvar=stackslots, const_width=const_width,
-        -- Extensions and helpers (use with care)
-        LD_IMM_X = LD_IMM_X,
-    }
-    func(builtin_export, ...)
-end
 local function CALL(a, b, d)
     assert(b-1 <= 1, 'NYI: CALL with >1 return values')
     -- Perform either compile-time, helper, or builtin
@@ -529,7 +613,12 @@ local function CALL(a, b, d)
         assert(V[a+1].const and V[a+2].const, 'NYI: slice() arguments must be constant')
         local off = V[a+1].const
         local vtype = builtins.width_type(V[a+2].const - off)
+        -- Access to packet via packet (use BPF_LD)
+        if V[a].source and V[a].source:find('ptr_to_', 1, true) then
             LOAD(a, a, off, vtype)
+        else
+            error('NYI: <dissector>.slice(a, b) on non-pointer memory ' .. (V[a].source or 'unknown'))
+        end
     -- Strict builtins cannot be expanded on compile-time
     elseif builtins_strict[func] and builtin then
         args = {a}
@@ -537,6 +626,7 @@ local function CALL(a, b, d)
         BUILTIN(builtin, unpack(args))
     -- Attempt compile-time call expansion (expects all argument compile-time known)
     else
+        assert(const, 'NYI: CALL attempted on constant arguments, but at least one argument is not constant')
         V[a].const = func(unpack(args))
     end
 end
@@ -591,6 +681,7 @@ local function MAP_GET(dst, map_var, key, imm)
     -- Flag as pointer type and associate dissector for map value type
     vreg(dst, 0, true, ffi.typeof('uint8_t *'))
     V[dst].const = {__dissector=map.val_type}
+    V[dst].source = 'ptr_to_map_value_or_null'
     emit(BPF.JMP + BPF.CALL, 0, 0, 0, HELPER.map_lookup_elem)
     V[stackslots].reg = nil -- Free temporary registers
 end
@@ -630,7 +721,7 @@ local function MAP_SET(map_var, key, key_imm, src)
     elseif V[src].reg and pod_type then
         -- Value is a pointer, derefernce it and spill it
         if cdef.isptr(V[src].type) then
-            vderef(3, V[src].reg, V[src].const.__dissector)
+            vderef(3, V[src].reg, V[src])
             emit(BPF.MEM + BPF.STX + w, 10, 3, -sp, 0)
         else
             emit(BPF.MEM + BPF.STX + w, 10, V[src].reg, -sp, 0)
@@ -645,18 +736,22 @@ local function MAP_SET(map_var, key, key_imm, src)
             emit(BPF.JMP + BPF.CALL, 0, 0, 0, HELPER.map_update_elem)
             return
         end
-        vderef(3, V[src].reg, V[src].const.__dissector)
+        vderef(3, V[src].reg, V[src])
         emit(BPF.MEM + BPF.STX + w, 10, 3, -sp, 0)
     else
         sp = V[src].spill
     end
     -- Value is already on stack, write to base-relative address
     elseif base.__base then
-        assert(val_size == ffi.sizeof(V[key].type), 'VAR '..key..' type incompatible with BPF map value type')
+        if val_size ~= ffi.sizeof(V[src].type) then
+            local err = string.format('VAR %d type (%s) incompatible with BPF map value type (%s): expected %d, got %d',
+                src, V[src].type, map.val_type, val_size, ffi.sizeof(V[src].type))
+            error(err)
+        end
         sp = base.__base
     -- Value is constant, materialize it on stack
     else
-        error('VAR '.. key or key_imm ..' is neither const-expr/register/stack/spilled')
+        error('VAR '.. src ..' is neither const-expr/register/stack/spilled')
     end
     emit(BPF.ALU64 + BPF.MOV + BPF.X, 3, 10, 0, 0)
     emit(BPF.ALU64 + BPF.ADD + BPF.K, 3, 0, 0, -sp)
@@ -668,11 +763,27 @@ end
 local BC = {
     -- Constants
     KNUM = function(a, _, c, _) -- KNUM
-        vset(a, nil, c, ffi.typeof('int32_t')) -- TODO: only 32bit immediates are supported now
+        if c < 2147483648 then
+            vset(a, nil, c, ffi.typeof('int32_t'))
+        else
+            vset(a, nil, c, ffi.typeof('uint64_t'))
+        end
     end,
     KSHORT = function(a, _, _, d) -- KSHORT
         vset(a, nil, d, ffi.typeof('int16_t'))
     end,
+    KCDATA = function(a, _, c, _) -- KCDATA
+        -- Coerce numeric types if possible
+        local ct = ffi.typeof(c)
+        if ffi.istype(ct, ffi.typeof('uint64_t')) or ffi.istype(ct, ffi.typeof('int64_t')) then
+            vset(a, nil, c, ct)
+        elseif tonumber(c) ~= nil then
+            -- TODO: this should not be possible
+            vset(a, nil, tonumber(c), ct)
+        else
+            error('NYI: cannot use CDATA constant of type ' .. ct)
+        end
+    end,
     KPRI = function(a, _, _, d) -- KPRI
         -- KNIL is 0, must create a special type to identify it
         local vtype = (d < 1) and ffi.typeof('void') or ffi.typeof('uint8_t')
@@ -737,84 +848,172 @@ local BC = {
             vset(a, nil, env[c])
         else error(string.format("undefined upvalue '%s'", c)) end
     end,
-    TGETB = function (a, b, _, d) -- TGETB (A = B[D])
-        if a ~= b then vset(a) end
-        local base = V[b].const
-        if base.__map then -- BPF map read (constant)
-            MAP_GET(a, b, nil, d)
-        -- Specialise PTR[0] as dereference operator
-        elseif cdef.isptr(V[b].type) and d == 0 then
-            vcopy(a, b)
-            local dst_reg = vreg(a)
-            vderef(dst_reg, dst_reg, V[a].const.__dissector)
-            V[a].type = V[a].const.__dissector
-        else
-            LOAD(a, b, d, ffi.typeof('uint8_t'))
-        end
-    end,
     TSETB = function (a, b, _, d) -- TSETB (B[D] = A)
-        if V[b].const.__map then -- BPF map read (constant)
+        assert(V[b] and type(V[b].const) == 'table', 'NYI: B[D] where B is not Lua table, BPF map, or pointer')
+        local vinfo = V[b].const
+        if vinfo.__map then -- BPF map read (constant)
             return MAP_SET(b, nil, d, a) -- D is literal
-        elseif V[b].const and V[b].const and V[a].const then
-            V[b].const[V[d].const] = V[a].const
-        else error('NYI: B[D] = A, where B is not Lua table or BPF map')
+        elseif vinfo.__dissector then
+            assert(vinfo.__dissector, 'NYI: B[D] where B does not have a known element size')
+            local w = ffi.sizeof(vinfo.__dissector)
+            -- TODO: support vectorized moves larger than register width
+            assert(const_width[w], 'B[C] = A, sizeof(A) must be 1/2/4/8')
+            local src_reg, const = vscalar(a, w)
+            -- If changing map value, write to absolute address + offset
+            if V[b].source and V[b].source:find('ptr_to_map_value', 1, true) then
+                local dst_reg = vreg(b)
+                -- Optimization: immediate values (imm32) can be stored directly
+                if type(const) == 'number' then
+                    emit(BPF.MEM + BPF.ST + const_width[w], dst_reg, 0, d, const)
+                else
+                    emit(BPF.MEM + BPF.STX + const_width[w], dst_reg, src_reg, d, 0)
+                end
+            -- Table is already on stack, write to vinfo-relative address
+            elseif vinfo.__base then
+                -- Optimization: immediate values (imm32) can be stored directly
+                if type(const) == 'number' then
+                    emit(BPF.MEM + BPF.ST + const_width[w], 10, 0, -vinfo.__base + (d * w), const)
+                else
+                    emit(BPF.MEM + BPF.STX + const_width[w], 10, src_reg, -vinfo.__base + (d * w), 0)
+                end
+            else
+                error('NYI: B[D] where B is not Lua table, BPF map, or pointer')
+            end
+        elseif vinfo and vinfo and V[a].const then
+            vinfo[V[d].const] = V[a].const
+        else
+            error('NYI: B[D] where B is not Lua table, BPF map, or pointer')
         end
     end,
     TSETV = function (a, b, _, d) -- TSETV (B[D] = A)
-        if V[b].const.__map then -- BPF map read (constant)
+        assert(V[b] and type(V[b].const) == 'table', 'NYI: B[D] where B is not Lua table, BPF map, or pointer')
+        local vinfo = V[b].const
+        if vinfo.__map then -- BPF map read (constant)
             return MAP_SET(b, d, nil, a) -- D is variable
-        elseif V[b].const and V[d].const and V[a].const then
-            V[b].const[V[d].const] = V[a].const
-        else error('NYI: B[D] = A, where B is not Lua table or BPF map')
+        elseif vinfo.__dissector then
+            assert(vinfo.__dissector, 'NYI: B[D] where B does not have a known element size')
+            local w = ffi.sizeof(vinfo.__dissector)
+            -- TODO: support vectorized moves larger than register width
+            assert(const_width[w], 'B[C] = A, sizeof(A) must be 1/2/4/8')
+            local src_reg, const = vscalar(a, w)
+            -- If changing map value, write to absolute address + offset
+            if V[b].source and V[b].source:find('ptr_to_map_value', 1, true) then
+                -- Calculate variable address from two registers
+                local tmp_var = stackslots + 1
+                vset(tmp_var, nil, d)
+                ALU_REG(tmp_var, tmp_var, b, 'ADD')
+                local dst_reg = vreg(tmp_var)
+                V[tmp_var].reg = nil -- Only temporary allocation
+                -- Optimization: immediate values (imm32) can be stored directly
+                if type(const) == 'number' and w < 8 then
+                    emit(BPF.MEM + BPF.ST + const_width[w], dst_reg, 0, 0, const)
+                else
+                    emit(BPF.MEM + BPF.STX + const_width[w], dst_reg, src_reg, 0, 0)
+                end
+            -- Table is already on stack, write to vinfo-relative address
+            elseif vinfo.__base then
+                -- Calculate variable address from two registers
+                local tmp_var = stackslots + 1
+                vcopy(tmp_var, d) -- Element position
+                if w > 1 then
+                    ALU_IMM(tmp_var, tmp_var, w, 'MUL') -- multiply by element size
+                end
+                local dst_reg = vreg(tmp_var) -- add R10 (stack pointer)
+                emit(BPF.ALU64 + BPF.ADD + BPF.X, dst_reg, 10, 0, 0)
+                V[tmp_var].reg = nil -- Only temporary allocation
+                -- Optimization: immediate values (imm32) can be stored directly
+                if type(const) == 'number' and w < 8 then
+                    emit(BPF.MEM + BPF.ST + const_width[w], dst_reg, 0, -vinfo.__base, const)
+                else
+                    emit(BPF.MEM + BPF.STX + const_width[w], dst_reg, src_reg, -vinfo.__base, 0)
+                end
+            else
+                error('NYI: B[D] where B is not Lua table, BPF map, or pointer')
+            end
+        elseif vinfo and V[d].const and V[a].const then
+            vinfo[V[d].const] = V[a].const
+        else
+            error('NYI: B[D] where B is not Lua table, BPF map, or pointer')
         end
     end,
     TSETS = function (a, b, c, _) -- TSETS (B[C] = A)
-        assert(V[b] and V[b].const, 'NYI: B[D] where B is not Lua table or BPF map')
+        assert(V[b] and V[b].const, 'NYI: B[D] where B is not Lua table, BPF map, or pointer')
         local base = V[b].const
         if base.__dissector then
             local ofs,bpos = ffi.offsetof(base.__dissector, c)
             assert(not bpos, 'NYI: B[C] = A, where C is a bitfield')
-            local w = sizeofattr(base.__dissector, c)
+            local w = builtins.sizeofattr(base.__dissector, c)
             -- TODO: support vectorized moves larger than register width
             assert(const_width[w], 'B[C] = A, sizeof(A) must be 1/2/4/8')
-            local src_reg = vreg(a)
-            -- If source is a pointer, we must dereference it first
-            if cdef.isptr(V[a].type) then
-                local tmp_reg = reg_alloc(stackslots, 1) -- Clone variable in tmp register
-                emit(BPF.ALU64 + BPF.MOV + BPF.X, tmp_reg, src_reg, 0, 0)
-                vderef(tmp_reg, tmp_reg, V[a].const.__dissector)
-                src_reg = tmp_reg -- Materialize and dereference it
-            -- Source is a value on stack, we must load it first
-            elseif V[a].const and V[a].const.__base > 0 then
-                emit(BPF.MEM + BPF.LDX + const_width[w], src_reg, 10, -V[a].const.__base, 0)
-                V[a].type = V[a].const.__dissector
-                V[a].const = nil -- Value is dereferenced
-            end
-            -- If the table is not on stack, it must be checked for NULL
-            if not base.__base then
-                emit(BPF.JMP + BPF.JEQ + BPF.K, V[b].reg, 0, 1, 0) -- if (map[x] != NULL)
-                emit(BPF.MEM + BPF.STX + const_width[w], V[b].reg, src_reg, ofs, 0)
-            else -- Table is already on stack, write to base-relative address
-                emit(BPF.MEM + BPF.STX + const_width[w], 10, src_reg, -base.__base + ofs, 0)
+            local src_reg, const = vscalar(a, w)
+            -- If changing map value, write to absolute address + offset
+            if V[b].source and V[b].source:find('ptr_to_map_value', 1, true) then
+                local dst_reg = vreg(b)
+                -- Optimization: immediate values (imm32) can be stored directly
+                if type(const) == 'number' and w < 8 then
+                    emit(BPF.MEM + BPF.ST + const_width[w], dst_reg, 0, ofs, const)
+                else
+                    emit(BPF.MEM + BPF.STX + const_width[w], dst_reg, src_reg, ofs, 0)
+                end
+            -- Table is already on stack, write to base-relative address
+            elseif base.__base then
+                -- Optimization: immediate values (imm32) can be stored directly
+                if type(const) == 'number' and w < 8 then
+                    emit(BPF.MEM + BPF.ST + const_width[w], 10, 0, -base.__base + ofs, const)
+                else
+                    emit(BPF.MEM + BPF.STX + const_width[w], 10, src_reg, -base.__base + ofs, 0)
+                end
+            else
+                error('NYI: B[C] where B is not Lua table, BPF map, or pointer')
             end
         elseif V[a].const then
             base[c] = V[a].const
-        else error('NYI: B[C] = A, where B is not Lua table or BPF map')
+        else
+            error('NYI: B[C] where B is not Lua table, BPF map, or pointer')
         end
     end,
+    TGETB = function (a, b, _, d) -- TGETB (A = B[D])
+        local base = V[b].const
+        assert(type(base) == 'table', 'NYI: B[C] where C is string and B not Lua table or BPF map')
+        if a ~= b then vset(a) end
+        if base.__map then -- BPF map read (constant)
+            MAP_GET(a, b, nil, d)
+        -- Pointer access with a dissector (traditional uses BPF_LD, direct uses BPF_MEM)
+        elseif V[b].source and V[b].source:find('ptr_to_') then
+            local vtype = base.__dissector and base.__dissector or ffi.typeof('uint8_t')
+            LOAD(a, b, d, vtype)
+        -- Specialise PTR[0] as dereference operator
+        elseif cdef.isptr(V[b].type) and d == 0 then
+            vcopy(a, b)
+            local dst_reg = vreg(a)
+            vderef(dst_reg, dst_reg, V[a])
+            V[a].type = V[a].const.__dissector
+        else
+            error('NYI: A = B[D], where B is not Lua table or packet dissector or pointer dereference')
+        end
+    end,
     TGETV = function (a, b, _, d) -- TGETV (A = B[D])
-        assert(V[b] and V[b].const, 'NYI: B[D] where B is not Lua table or BPF map')
+        local base = V[b].const
+        assert(type(base) == 'table', 'NYI: B[C] where C is string and B not Lua table or BPF map')
         if a ~= b then vset(a) end
-        if V[b].const.__map then -- BPF map read
+        if base.__map then -- BPF map read
             MAP_GET(a, b, d)
-        elseif V[b].const == env.pkt then -- Raw packet, no offset
-            LD_FIELD(a, d, 1, V[d].const)
-        else V[a].const = V[b].const[V[d].const] end
+        -- Pointer access with a dissector (traditional uses BPF_LD, direct uses BPF_MEM)
+        elseif V[b].source and V[b].source:find('ptr_to_') then
+            local vtype = base.__dissector and base.__dissector or ffi.typeof('uint8_t')
+            LOAD(a, b, d, vtype)
+        -- Constant dereference
+        elseif type(V[d].const) == 'number' then
+            V[a].const = base[V[d].const]
+        else
+            error('NYI: A = B[D], where B is not Lua table or packet dissector or pointer dereference')
+        end
     end,
     TGETS = function (a, b, c, _) -- TGETS (A = B[C])
-        assert(V[b] and V[b].const, 'NYI: B[C] where C is string and B not Lua table or BPF map')
         local base = V[b].const
-        if type(base) == 'table' and base.__dissector then
+        assert(type(base) == 'table', 'NYI: B[C] where C is string and B not Lua table or BPF map')
+        if a ~= b then vset(a) end
+        if base.__dissector then
             local ofs,bpos,bsize = ffi.offsetof(base.__dissector, c)
             -- Resolve table key using metatable
             if not ofs and type(base.__dissector[c]) == 'string' then
@@ -824,31 +1023,28 @@ local BC = {
         if not ofs and proto[c] then -- Load new dissector on given offset
             BUILTIN(proto[c], a, b, c)
         else
+            -- Loading register from offset is a little bit tricky as there are
+            -- several data sources and value loading modes with different restrictions
+            -- such as checking pointer values for NULL compared to using stack.
             assert(ofs, tostring(base.__dissector)..'.'..c..' attribute not exists')
             if a ~= b then vset(a) end
             -- Dissected value is probably not constant anymore
             local new_const = nil
-            -- Simple register load, get absolute offset or R-relative
-            local w, atype = sizeofattr(base.__dissector, c)
-            if base.__base == true then -- R-relative addressing
-                local dst_reg = vreg(a, nil, true)
-                assert(const_width[w], 'NYI: sizeof('..tostring(base.__dissector)..'.'..c..') not 1/2/4/8 bytes')
-                emit(BPF.MEM + BPF.LDX + const_width[w], dst_reg, V[b].reg, ofs, 0)
-            elseif not base.source and base.__base and base.__base > 0 then -- [FP+K] addressing
+            local w, atype = builtins.sizeofattr(base.__dissector, c)
+            -- [SP+K] addressing using R10 (stack pointer)
+            -- Doesn't need to be checked for NULL
+            if base.__base and base.__base > 0 then
                 if cdef.isptr(atype) then -- If the member is pointer type, update base pointer with offset
                     new_const = {__base = base.__base-ofs}
                 else
                     local dst_reg = vreg(a, nil, true)
                     emit(BPF.MEM + BPF.LDX + const_width[w], dst_reg, 10, -base.__base+ofs, 0)
                 end
-            elseif base.off then -- Absolute address to payload
-                LD_ABS(a, ofs + base.off, w)
-            elseif base.source == 'probe' then -- Indirect read using probe
-                BUILTIN(builtins[builtins.probe_read], nil, a, b, atype, ofs)
-                V[a].source = V[b].source -- Builtin handles everything
-                return
-            else -- Indirect address to payload
-                LD_IND(a, b, w, ofs)
+            -- Pointer access with a dissector (traditional uses BPF_LD, direct uses BPF_MEM)
+            elseif V[b].source and V[b].source:find('ptr_to_') then
+                LOAD(a, b, ofs, atype)
+            else
+                error('NYI: B[C] where B is not Lua table, BPF map, or pointer')
             end
             -- Bitfield, must be further narrowed with a bitmask/shift
             if bpos then
@@ -868,8 +1064,20 @@ local BC = {
             V[a].type = atype
             V[a].const = new_const
             V[a].source = V[b].source
+            -- Track direct access to skb data
+            -- see https://www.kernel.org/doc/Documentation/networking/filter.txt "Direct packet access"
+            if ffi.istype(base.__dissector, ffi.typeof('struct sk_buff')) then
+                -- Direct access to skb uses skb->data and skb->data_end
+                -- which are encoded as u32, but are actually pointers
+                if c == 'data' or c == 'data_end' then
+                    V[a].const = {__dissector = ffi.typeof('uint8_t')}
+                    V[a].source = 'ptr_to_skb'
+                end
+            end
         end
-        else V[a].const = base[c] end
+        else
+            V[a].const = base[c]
+        end
     end,
     -- Loops and branches
     CALLM = function (a, b, _, d) -- A = A(A+1, ..., A+D+MULTRES)
@@ -882,8 +1090,11 @@ local BC = {
     JMP = function (a, _, c, _) -- JMP
         -- Discard unused slots after jump
         for i, _ in pairs(V) do
-            if i >= a then V[i] = {} end
+            if i >= a and i < stackslots then
+                V[i] = nil
+            end
         end
+        -- Cross basic block boundary if the jump target isn't provably unreachable
         local val = code.fixup[c] or {}
         if code.seen_cmp and code.seen_cmp ~= ALWAYS then
             if code.seen_cmp ~= NEVER then -- Do not emit the jump or fixup
@@ -893,25 +1104,49 @@ local BC = {
                 -- First branch point, emit compensation code
                 local Vcomp = Vstate[c]
                 if not Vcomp then
-                    for i,v in pairs(V) do
-                        if not v.reg and v.const and not is_proxy(v.const) then
-                            vreg(i, 0) -- Load to TMP register (not saved)
-                        end
-                        if v.reg and v.reg <= 5 then
-                            reg_spill(i) -- Spill caller-saved registers
-                        end
-                    end
+                    -- Select scratch register (R0-5) that isn't used as operand
+                    -- in the CMP instruction, as the variable may not be live, after
+                    -- the JMP, but it may be used in the JMP+CMP instruction itself
+                    local tmp_reg = 0
+                    for reg = 0, 5 do
+                        if reg ~= jmpi.dst_reg and reg ~= jmpi.src_reg then
+                            tmp_reg = reg
+                            break
+                        end
+                    end
+                    -- Force materialization of constants at the end of BB
+                    for i, v in pairs(V) do
+                        if not v.reg and cdef.isimmconst(v) then
+                            vreg(i, tmp_reg) -- Load to TMP register (not saved)
+                            reg_spill(i) -- Spill caller-saved registers
+                        end
+                    end
                     -- Record variable state
                     Vstate[c] = V
-                    V = {}
-                    for i,v in pairs(Vstate[c]) do
-                        V[i] = {}
-                        for k,e in pairs(v) do
-                            V[i][k] = e
-                        end
-                    end
+                    Vcomp = V
+                    V = table_copy(V)
                 -- Variable state already set, emit specific compensation code
-                else bb_end(Vcomp) end
+                else
+                    bb_end(Vcomp)
+                end
+                -- Record pointer NULL check from condition
+                -- If the condition checks pointer variable against NULL,
+                -- we can assume it will not be NULL in the fall-through block
+                if code.seen_null_guard then
+                    local var = code.seen_null_guard
+                    -- The null guard can have two forms:
+                    --   if x == nil then goto
+                    --   if x ~= nil then goto
+                    -- First form guarantees that the variable will be non-nil on the following instruction
+                    -- Second form guarantees that the variable will be non-nil at the jump target
+                    local vinfo = code.seen_null_guard_inverse and Vcomp[var] or V[var]
+                    if vinfo.source then
+                        local pos = vinfo.source:find('_or_null', 1, true)
+                        if pos then
+                            vinfo.source = vinfo.source:sub(1, pos - 1)
+                        end
+                    end
+                end
                 -- Reemit CMP insn
                 emit(jmpi.code, jmpi.dst_reg, jmpi.src_reg, jmpi.off, jmpi.imm)
                 -- Fuse JMP into previous CMP opcode, mark JMP target for fixup
@@ -920,22 +1155,53 @@ local BC = {
                 code.fixup[c] = val
             end
             code.seen_cmp = nil
+            code.seen_null_guard = nil
+            code.seen_null_guard_inverse = nil
+        elseif c == code.bc_pc + 1 then -- luacheck: ignore 542
+            -- Eliminate jumps to next immediate instruction
+            -- e.g. 0002 JMP 1 => 0003
         else
+            -- We need to synthesise a condition that's always true, however
+            -- BPF prohibits pointer arithmetic to prevent pointer leaks
+            -- so we have to clear out one register and use it for cmp that's always true
+            local dst_reg = reg_alloc(stackslots)
+            V[stackslots].reg = nil -- Only temporary allocation
+            -- First branch point, emit compensation code
+            local Vcomp = Vstate[c]
+            if not Vcomp then
+                -- Force materialization of constants at the end of BB
+                for i, v in pairs(V) do
+                    if not v.reg and cdef.isimmconst(v) then
+                        vreg(i, dst_reg) -- Load to TMP register (not saved)
+                        reg_spill(i) -- Spill caller-saved registers
+                    end
+                end
+                -- Record variable state
+                Vstate[c] = V
+                V = table_copy(V)
+            -- Variable state already set, emit specific compensation code
+            else
+                bb_end(Vcomp)
+            end
-            emit(BPF.JMP + BPF.JEQ + BPF.X, 6, 6, 0xffff, 0) -- Always true
+            emit(BPF.ALU64 + BPF.MOV + BPF.K, dst_reg, 0, 0, 0)
+            emit(BPF.JMP + BPF.JEQ + BPF.K, dst_reg, 0, 0xffff, 0)
             table.insert(val, code.pc-1) -- Fixup JMP target
             code.reachable = false -- Code following the JMP is not reachable
             code.fixup[c] = val
         end
     end,
     RET1 = function (a, _, _, _) -- RET1
+        -- Free optimisation: spilled variable will not be filled again
+        for i, v in pairs(V) do
+            if i ~= a then v.reg = nil end
+        end
         if V[a].reg ~= 0 then vreg(a, 0) end
-        -- Dereference pointer variables
+        -- Convenience: dereference pointer variables
+        -- e.g. 'return map[k]' will return actual map value, not pointer
         if cdef.isptr(V[a].type) then
-            vderef(0, 0, V[a].const.__dissector)
+            vderef(0, 0, V[a])
         end
         emit(BPF.JMP + BPF.EXIT, 0, 0, 0, 0)
-        -- Free optimisation: spilled variable will not be filled again
-        for _,v in pairs(V) do if v.reg == 0 then v.reg = nil end end
        code.reachable = false
     end,
     RET0 = function (_, _, _, _) -- RET0
@@ -947,12 +1213,19 @@ local BC = {
         return code
     end
 }
+-- Composite instructions
+function BC.CALLT(a, _, _, d) -- Tailcall: return A(A+1, ..., A+D-1)
+    CALL(a, 1, d)
+    BC.RET1(a)
+end
+
 -- Always initialize R6 with R1 context
 emit(BPF.ALU64 + BPF.MOV + BPF.X, 6, 1, 0, 0)
 -- Register R6 as context variable (first argument)
 if params and params > 0 then
     vset(0, 6, param_types[1] or proto.skb)
-    V[0].source = V[0].const.source -- Propagate source annotation from typeinfo
+    assert(V[0].source == V[0].const.source) -- Propagate source annotation from typeinfo
 end
 -- Register tmpvars
 vset(stackslots)
@@ -967,8 +1240,22 @@ return setmetatable(BC, {
 __call = function (t, op, a, b, c, d)
     code.bc_pc = code.bc_pc + 1
     -- Exitting BB straight through, emit compensation code
-    if Vstate[code.bc_pc] and code.reachable then
+    if Vstate[code.bc_pc] then
+        if code.reachable then
+            -- Instruction is reachable from previous line
+            -- so we must make the variable allocation consistent
+            -- with the variable allocation at the jump source
+            -- e.g. 0001 x:R0 = 5
+            --      0002 if rand() then goto 0005
+            --      0003 x:R0 -> x:stack
+            --      0004 y:R0 = 5
+            --      0005 x:? = 10 <-- x was in R0 before jump, and stack after jump
             bb_end(Vstate[code.bc_pc])
+        else
+            -- Instruction isn't reachable from previous line, restore variable layout
+            -- e.g. RET or condition-less JMP on previous line
+            V = table_copy(Vstate[code.bc_pc])
+        end
     end
     -- Perform fixup of jump targets
     -- We need to do this because the number of consumed and emitted
@@ -1004,8 +1291,8 @@ local function dump_mem(cls, ins, _, fuse)
     local dst = cls < 2 and 'R'..ins.dst_reg or string.format('[R%d%+d]', ins.dst_reg, off)
     local src = cls % 2 == 0 and '#'..ins.imm or 'R'..ins.src_reg
     if cls == BPF.LDX then src = string.format('[R%d%+d]', ins.src_reg, off) end
-    if mode == BPF.ABS then src = string.format('[%d]', ins.imm) end
-    if mode == BPF.IND then src = string.format('[R%d%+d]', ins.src_reg, ins.imm) end
+    if mode == BPF.ABS then src = string.format('skb[%d]', ins.imm) end
+    if mode == BPF.IND then src = string.format('skb[R%d%+d]', ins.src_reg, ins.imm) end
     return string.format('%s\t%s\t%s', fuse and '' or name, fuse and '' or dst, src)
 end
@@ -1029,26 +1316,37 @@ local function dump_alu(cls, ins, pc)
     return string.format('%s\t%s\t%s%s', name, 'R'..ins.dst_reg, src, target)
 end
-local function dump(code)
+local function dump_string(code, off, hide_counter)
     if not code then return end
-    print(string.format('-- BPF %s:0-%u', code.insn, code.pc))
     local cls_map = {
         [0] = dump_mem, [1] = dump_mem, [2] = dump_mem, [3] = dump_mem,
         [4] = dump_alu, [5] = dump_alu, [7] = dump_alu,
     }
+    local result = {}
     local fused = false
-    for i = 0, code.pc - 1 do
+    for i = off or 0, code.pc - 1 do
         local ins = code.insn[i]
         local cls = bit.band(ins.code, 0x07)
         local line = cls_map[cls](cls, ins, i, fused)
-        print(string.format('%04u\t%s', i, line))
+        if hide_counter then
+            table.insert(result, line)
+        else
+            table.insert(result, string.format('%04u\t%s', i, line))
+        end
         fused = string.find(line, 'LDDW', 1)
     end
+    return table.concat(result, '\n')
+end
+
+local function dump(code)
+    if not code then return end
+    print(string.format('-- BPF %s:0-%u', code.insn, code.pc))
+    print(dump_string(code))
 end
 local function compile(prog, params)
     -- Create code emitter sandbox, include caller locals
-    local env = { pkt=proto.pkt, BPF=BPF, ffi=ffi }
+    local env = { pkt=proto.pkt, eth=proto.pkt, BPF=BPF, ffi=ffi }
     -- Include upvalues up to 4 nested scopes back
     -- the narrower scope overrides broader scope
     for k = 5, 2, -1 do
@@ -1082,9 +1380,9 @@ local function compile(prog, params)
         print(debug.traceback())
     end
     for _,op,a,b,c,d in bytecode.decoder(prog) do
-        local ok, res, err = xpcall(E,on_err,op,a,b,c,d)
+        local ok, _, err = xpcall(E,on_err,op,a,b,c,d)
         if not ok then
-            return nil, res, err
+            return nil, err
         end
     end
     return E:compile()
@@ -1173,8 +1471,8 @@ local tracepoint_mt = {
     __index = {
         bpf = function (t, prog)
             if type(prog) ~= 'table' then
-                -- Create protocol parser with source=probe
-                prog = compile(prog, {proto.type(t.type, {source='probe'})})
+                -- Create protocol parser with source probe
+                prog = compile(prog, {proto.type(t.type, {source='ptr_to_probe'})})
             end
             -- Load the BPF program
             local prog_fd, err, log = S.bpf_prog_load(S.c.BPF_PROG.TRACEPOINT, prog.insn, prog.pc)
@@ -1236,6 +1534,7 @@ end
 return setmetatable({
     new = create_emitter,
     dump = dump,
+    dump_string = dump_string,
     maps = {},
     map = function (type, max_entries, key_ctype, val_ctype)
         if not key_ctype then key_ctype = ffi.typeof('uint32_t') end
@@ -1271,7 +1570,11 @@ return setmetatable({
         ok, err = sock:bind(S.t.sockaddr_ll({protocol='all', ifindex=iface.index}))
         assert(ok, tostring(err))
     elseif type(sock) == 'number' then
-        sock = assert(S.t.socket(sock))
+        sock = S.t.fd(sock):nogc()
+    elseif ffi.istype(S.t.fd, sock) then -- luacheck: ignore
+        -- No cast required
+    else
+        return nil, 'socket must either be an fd number, an interface name, or an ljsyscall socket'
     end
     -- Load program and attach it to socket
     if type(prog) ~= 'table' then
@@ -44,6 +44,22 @@ local function width_type(w)
 end
 builtins.width_type = width_type
+-- Return struct member size/type (requires LuaJIT 2.1+)
+-- I am ashamed that there's no easier way around it.
+local function sizeofattr(ct, name)
+    if not ffi.typeinfo then error('LuaJIT 2.1+ is required for ffi.typeinfo') end
+    local cinfo = ffi.typeinfo(ct)
+    while true do
+        cinfo = ffi.typeinfo(cinfo.sib)
+        if not cinfo then return end
+        if cinfo.name == name then break end
+    end
+    local size = math.max(1, ffi.typeinfo(cinfo.sib or ct).size - cinfo.size)
+    -- Guess type name
+    return size, builtins.width_type(size)
+end
+builtins.sizeofattr = sizeofattr
+
 -- Byte-order conversions for little endian
 local function ntoh(x, w)
     if w then x = ffi.cast(const_width_type[w/8], x) end
...@@ -76,21 +92,34 @@ if ffi.abi('be') then ...@@ -76,21 +92,34 @@ if ffi.abi('be') then
return w and ffi.cast(const_width_type[w/8], x) or x return w and ffi.cast(const_width_type[w/8], x) or x
end end
hton = ntoh hton = ntoh
builtins[ntoh] = function(a, b, w) return end builtins[ntoh] = function(_, _, _) return end
builtins[hton] = function(a, b, w) return end builtins[hton] = function(_, _, _) return end
end end
-- Other built-ins
local function xadd() error('NYI') end
builtins.xadd = xadd
builtins[xadd] = function (e, ret, a, b, off)
local vinfo = e.V[a].const
assert(vinfo and vinfo.__dissector, 'xadd(a, b[, offset]) called on non-pointer')
local w = ffi.sizeof(vinfo.__dissector)
-- Calculate structure attribute offsets
if e.V[off] and type(e.V[off].const) == 'string' then
local ct, field = vinfo.__dissector, e.V[off].const
off = ffi.offsetof(ct, field)
assert(off, 'xadd(a, b, offset) - offset is not valid in given structure')
w = sizeofattr(ct, field)
end
assert(w == 4 or w == 8, 'NYI: xadd() - 1 and 2 byte atomic increments are not supported')
-- Allocate registers and execute
local src_reg = e.vreg(b)
local dst_reg = e.vreg(a)
-- Set variable for return value and call
e.vset(ret)
e.vreg(ret, 0, true, ffi.typeof('int32_t'))
-- Optimize the NULL check away if provably not NULL
if not e.V[a].source or e.V[a].source:find('_or_null', 1, true) then
e.emit(BPF.JMP + BPF.JEQ + BPF.K, dst_reg, 0, 1, 0) -- if (dst != NULL)
end
e.emit(BPF.XADD + BPF.STX + const_width[w], dst_reg, src_reg, off or 0, 0)
end
@@ -137,11 +166,10 @@ end
builtins[ffi.cast] = function (e, dst, ct, x)
assert(e.V[ct].const, 'ffi.cast(ctype, x) called with bad ctype')
e.vcopy(dst, x)
if e.V[x].const and type(e.V[x].const) == 'table' then
e.V[dst].const.__dissector = ffi.typeof(e.V[ct].const)
end
e.V[dst].type = ffi.typeof(e.V[ct].const)
-- Specific types also encode source of the data
-- This is because BPF has different helpers for reading
-- different data sources, so variables must track origins.
@@ -149,7 +177,7 @@ builtins[ffi.cast] = function (e, dst, ct, x)
-- struct skb - source of the data is socket buffer
-- struct X - source of the data is probe/tracepoint
if ffi.typeof(e.V[ct].const) == ffi.typeof('struct pt_regs') then
e.V[dst].source = 'ptr_to_probe'
end
end
@@ -160,7 +188,14 @@ builtins[ffi.new] = function (e, dst, ct, x)
assert(not x, 'NYI: ffi.new(ctype, ...) - initializer is not supported')
assert(not cdef.isptr(ct, true), 'NYI: ffi.new(ctype, ...) - ctype MUST NOT be a pointer')
e.vset(dst, nil, ct)
e.V[dst].source = 'ptr_to_stack'
e.V[dst].const = {__base = e.valloc(ffi.sizeof(ct), true), __dissector = ct}
-- Set array dissector if created an array
-- e.g. if ct is 'char [2]', then dissector is 'char'
local elem_type = tostring(ct):match('ctype<(.+)%s%[(%d+)%]>')
if elem_type then
e.V[dst].const.__dissector = ffi.typeof(elem_type)
end
end
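--[[ Usage sketch (illustrative, not part of this patch): stack allocation
keeps an element dissector, so indexing compiles to correctly sized accesses:
local buf = ffi.new('uint8_t [16]') -- __base on stack, __dissector is uint8_t
buf[0] = 0xff                       -- 1-byte store relative to __base
]]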
builtins[ffi.copy] = function (e, ret, dst, src)
@@ -169,7 +204,7 @@ builtins[ffi.copy] = function (e, ret, dst, src)
-- Specific types also encode source of the data
-- struct pt_regs - source of the data is probe
-- struct skb - source of the data is socket buffer
if e.V[src].source and e.V[src].source:find('ptr_to_probe', 1, true) then
e.reg_alloc(e.tmpvar, 1)
-- Load stack pointer to dst, since only load to stack memory is supported
-- we have to either use spilled variable or allocated stack memory offset
@@ -221,7 +256,8 @@ builtins[print] = function (e, ret, fmt, a1, a2, a3)
-- TODO: this is materialize step
e.V[fmt].const = {__base=dst}
e.V[fmt].type = ffi.typeof('char ['..len..']')
elseif e.V[fmt].const.__base then -- luacheck: ignore
-- NOP
else error('NYI: print(fmt, ...) - format variable is not literal/stack memory') end
-- Prepare helper call
e.emit(BPF.ALU64 + BPF.MOV + BPF.X, 1, 10, 0, 0)
@@ -270,7 +306,6 @@ end
-- Implements bpf_skb_load_bytes(ctx, off, var, vlen) on skb->data
local function load_bytes(e, dst, off, var)
-- Set R2 = offset
e.vset(e.tmpvar, nil, off)
e.vreg(e.tmpvar, 2, false, ffi.typeof('uint64_t'))
@@ -350,7 +385,7 @@ builtins[math.log2] = function (e, dst, x)
e.vcopy(e.tmpvar, x)
local v = e.vreg(e.tmpvar, 2)
if cdef.isptr(e.V[x].const) then -- No pointer arithmetics, dereference
e.vderef(v, v, {const = {__dissector=ffi.typeof('uint64_t')}})
end
-- Invert value to invert all tests, otherwise we would need and+jnz
e.emit(BPF.ALU64 + BPF.NEG + BPF.K, v, 0, 0, 0) -- v = ~v
@@ -386,9 +421,9 @@ builtins[math.log] = function (e, dst, x)
end
-- Call-type helpers
local function call_helper(e, dst, h, vtype)
e.vset(dst)
e.vreg(dst, 0, true, vtype or ffi.typeof('uint64_t'))
e.emit(BPF.JMP + BPF.CALL, 0, 0, 0, h)
e.V[dst].const = nil -- Target is not a function anymore
end
@@ -408,7 +443,7 @@ builtins.perf_submit = perf_submit
builtins.stack_id = stack_id
builtins.load_bytes = load_bytes
builtins[cpu] = function (e, dst) return call_helper(e, dst, HELPER.get_smp_processor_id) end
builtins[rand] = function (e, dst) return call_helper(e, dst, HELPER.get_prandom_u32, ffi.typeof('uint32_t')) end
builtins[time] = function (e, dst) return call_helper(e, dst, HELPER.ktime_get_ns) end
builtins[pid_tgid] = function (e, dst) return call_helper(e, dst, HELPER.get_current_pid_tgid) end
builtins[uid_gid] = function (e, dst) return call_helper(e, dst, HELPER.get_current_uid_gid) end
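--[[ Usage sketch (illustrative, not part of this patch) of the helper
wrappers inside a compiled program:
local id = cpu()   -- get_smp_processor_id, u64 by default
local r = rand()   -- get_prandom_u32, now typed uint32_t
local ts = time()  -- ktime_get_ns
]]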
@@ -15,7 +15,7 @@ limitations under the License.
]]
local ffi = require('ffi')
local bit = require('bit')
local has_syscall, S = pcall(require, 'syscall')
local M = {}
ffi.cdef [[
@@ -132,23 +132,54 @@ struct bpf_stacktrace {
]]
-- Compatibility: ljsyscall doesn't have support for BPF syscall
if not has_syscall or not S.bpf then
error("ljsyscall doesn't support bpf(), must be updated")
else
local strflag = require('syscall.helpers').strflag
-- Compatibility: ljsyscall<=0.12
if not S.c.BPF_MAP.LRU_HASH then
S.c.BPF_MAP = strflag {
UNSPEC = 0,
HASH = 1,
ARRAY = 2,
PROG_ARRAY = 3,
PERF_EVENT_ARRAY = 4,
PERCPU_HASH = 5,
PERCPU_ARRAY = 6,
STACK_TRACE = 7,
CGROUP_ARRAY = 8,
LRU_HASH = 9,
LRU_PERCPU_HASH = 10,
LPM_TRIE = 11,
ARRAY_OF_MAPS = 12,
HASH_OF_MAPS = 13,
DEVMAP = 14,
SOCKMAP = 15,
CPUMAP = 16,
}
end
if not S.c.BPF_PROG.TRACEPOINT then
S.c.BPF_PROG = strflag {
UNSPEC = 0,
SOCKET_FILTER = 1,
KPROBE = 2,
SCHED_CLS = 3,
SCHED_ACT = 4,
TRACEPOINT = 5,
XDP = 6,
PERF_EVENT = 7,
CGROUP_SKB = 8,
CGROUP_SOCK = 9,
LWT_IN = 10,
LWT_OUT = 11,
LWT_XMIT = 12,
SOCK_OPS = 13,
SK_SKB = 14,
CGROUP_DEVICE = 15,
SK_MSG = 16,
RAW_TRACEPOINT = 17,
CGROUP_SOCK_ADDR = 18,
}
end
end
@@ -180,6 +211,14 @@ function M.isptr(v, noarray)
return ctname
end
-- Return true if variable is a non-nil constant that can be used as immediate value
-- e.g. result of KSHORT and KNUM
function M.isimmconst(v)
return (type(v.const) == 'number' and not ffi.istype(v.type, ffi.typeof('void')))
or type(v.const) == 'cdata' and ffi.istype(v.type, ffi.typeof('uint64_t')) -- Lua numbers are at most 52 bits
or type(v.const) == 'cdata' and ffi.istype(v.type, ffi.typeof('int64_t'))
end
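--[[ Illustrative checks (not part of this patch): immediate-const candidates
are plain Lua numbers and 64-bit cdata, but not table-valued constants:
assert(M.isimmconst({const = 5, type = ffi.typeof('int32_t')}))
assert(M.isimmconst({const = 5ULL, type = ffi.typeof('uint64_t')}))
assert(not M.isimmconst({const = {__map = true}, type = ffi.typeof('void')}))
]]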
function M.osversion()
-- We have no better way to extract current kernel hex-string other
-- than parsing headers, compiling a helper function or reading /proc
@@ -203,10 +242,10 @@ function M.event_reader(reader, event_type)
end
-- Wrap reader in interface that can interpret read event messages
return setmetatable({reader=reader,type=event_type}, {__index = {
block = function(_ --[[self]])
return S.select { readfds = {reader.fd} }
end,
next = function(_ --[[self]], k)
local len, ev = reader:next(k)
-- Filter out only sample frames
while ev and ev.type ~= S.c.PERF_RECORD.SAMPLE do
@@ -37,6 +37,8 @@ local function decode_ins(func, pc)
end
if mc == 13*128 then -- BCMjump
c = pc+d-0x7fff
elseif mc == 14*128 then -- BCMcdata
c = jutil.funck(func, -d-1)
elseif mc == 9*128 then -- BCMint
c = jutil.funck(func, d)
elseif mc == 10*128 then -- BCMstr
@@ -33,6 +33,21 @@ struct sk_buff {
uint32_t cb[5];
uint32_t hash;
uint32_t tc_classid;
uint32_t data;
uint32_t data_end;
uint32_t napi_id;
/* Accessed by BPF_PROG_TYPE_sk_skb types from here to ... */
uint32_t family;
uint32_t remote_ip4; /* Stored in network byte order */
uint32_t local_ip4; /* Stored in network byte order */
uint32_t remote_ip6[4]; /* Stored in network byte order */
uint32_t local_ip6[4]; /* Stored in network byte order */
uint32_t remote_port; /* Stored in network byte order */
uint32_t local_port; /* stored in host byte order */
/* ... here. */
uint32_t data_meta;
};
struct net_off_t {
@@ -185,7 +200,7 @@ else
end
-- Map symbolic registers to architecture ABI
ffi.metatype('struct pt_regs', {
__index = function (_ --[[t]], k)
return assert(parm_to_reg[k], 'no such register: '..k)
end,
})
@@ -223,7 +238,7 @@ local function next_offset(e, var, type, off, mask, shift)
if mask then
e.emit(BPF.ALU + BPF.AND + BPF.K, tmp_reg, 0, 0, mask)
end
if shift and shift ~= 0 then
local op = BPF.LSH
if shift < 0 then
op = BPF.RSH
@@ -264,9 +279,9 @@ M.type = function(typestr, t)
t.__dissector=ffi.typeof(typestr)
return t
end
M.skb = M.type('struct sk_buff', {source='ptr_to_ctx'})
M.pt_regs = M.type('struct pt_regs', {source='ptr_to_probe'})
M.pkt = M.type('struct eth_t', {off=0, source='ptr_to_pkt'}) -- skb needs special accessors
-- M.eth = function (...) return dissector(ffi.typeof('struct eth_t'), ...) end
M.dot1q = function (...) return dissector(ffi.typeof('struct dot1q_t'), ...) end
M.arp = function (...) return dissector(ffi.typeof('struct arp_t'), ...) end
@@ -310,6 +325,28 @@ ffi.metatype(ffi.typeof('struct ip_t'), {
}
})
ffi.metatype(ffi.typeof('struct ip6_t'), {
__index = {
-- Skip fixed IPv6 header length (40 bytes)
-- The caller must check the value of `next_header` to skip any extension headers
icmp6 = function(e, dst) next_skip(e, dst, ffi.sizeof('struct ip6_t'), 0) end,
udp = function(e, dst) next_skip(e, dst, ffi.sizeof('struct ip6_t'), 0) end,
tcp = function(e, dst) next_skip(e, dst, ffi.sizeof('struct ip6_t'), 0) end,
ip6_opt = function(e, dst) next_skip(e, dst, ffi.sizeof('struct ip6_t'), 0) end,
}
})
local ip6_opt_ext_len_off = ffi.offsetof('struct ip6_opt_t', 'ext_len')
ffi.metatype(ffi.typeof('struct ip6_opt_t'), {
__index = {
-- Skip IPv6 extension header length (field `ext_len`)
icmp6 = function(e, dst) next_offset(e, dst, ffi.typeof('uint8_t'), ip6_opt_ext_len_off) end,
udp = function(e, dst) next_offset(e, dst, ffi.typeof('uint8_t'), ip6_opt_ext_len_off) end,
tcp = function(e, dst) next_offset(e, dst, ffi.typeof('uint8_t'), ip6_opt_ext_len_off) end,
ip6_opt = function(e, dst) next_offset(e, dst, ffi.typeof('uint8_t'), ip6_opt_ext_len_off) end,
}
})
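--[[ Usage sketch (illustrative, not part of this patch): check next_header
before picking the dissector, e.g. UDP over plain IPv6 (IPPROTO_UDP = 17),
assuming the eth dissector exposes ip6 the same way it exposes ip:
local ip6 = eth.ip6
if ip6.next_header == 17 then
	local udp = ip6.udp
end
]]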
ffi.metatype(ffi.typeof('struct tcp_t'), {
__index = {
-- Skip TCP header length (stored as number of words)
local ffi = require('ffi')
local S = require('syscall')
-- Normalize whitespace and remove empty lines
local function normalize_code(c)
local res = {}
for line in string.gmatch(c,'[^\r\n]+') do
local op, d, s, t = line:match('(%S+)%s+(%S+)%s+(%S+)%s*([^-]*)')
if op then
t = t and t:match('^%s*(.-)%s*$')
table.insert(res, string.format('%s\t%s %s %s', op, d, s, t))
end
end
return table.concat(res, '\n')
end
-- Compile code and check result
local function compile(t)
local bpf = require('bpf')
-- require('jit.bc').dump(t.input)
local code, err = bpf(t.input)
assert.truthy(code)
assert.falsy(err)
if code then
if t.expect then
local got = normalize_code(bpf.dump_string(code, 1, true))
-- if normalize_code(t.expect) ~= got then print(bpf.dump_string(code, 1)) end
assert.same(normalize_code(t.expect), got)
end
end
end
-- Make a mock map variable
local function makemap(type, max_entries, key_ctype, val_ctype)
if not key_ctype then key_ctype = ffi.typeof('uint32_t') end
if not val_ctype then val_ctype = ffi.typeof('uint32_t') end
if not max_entries then max_entries = 4096 end
return {
__map = true,
max_entries = max_entries,
key = ffi.new(ffi.typeof('$ [1]', key_ctype)),
val = ffi.new(ffi.typeof('$ [1]', val_ctype)),
map_type = S.c.BPF_MAP[type],
key_type = key_ctype,
val_type = val_ctype,
fd = 42,
}
end
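-- e.g. makemap('array', 256) mocks an ARRAY map with u32 key/value and a fake fd of 42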
describe('codegen', function()
-- luacheck: ignore 113 211 212 311 511
describe('constants', function()
it('remove dead constant store', function()
compile {
input = function ()
local proto = 5
end,
expect = [[
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('materialize constant', function()
compile {
input = function ()
return 5
end,
expect = [[
MOV R0 #5
EXIT R0 #0
]]
}
end)
it('materialize constant longer than i32', function()
compile {
input = function ()
return 4294967295
end,
expect = [[
LDDW R0 #4294967295
EXIT R0 #0
]]
}
end)
it('materialize cdata constant', function()
compile {
input = function ()
return 5ULL
end,
expect = [[
LDDW R0 #5 -- composed instruction
EXIT R0 #0
]]
}
end)
it('materialize signed cdata constant', function()
compile {
input = function ()
return 5LL
end,
expect = [[
LDDW R0 #5 -- composed instruction
EXIT R0 #0
]]
}
end)
it('materialize coercible numeric cdata constant', function()
compile {
input = function ()
return 0x00005
end,
expect = [[
MOV R0 #5
EXIT R0 #0
]]
}
end)
it('materialize constant through variable', function()
compile {
input = function ()
local proto = 5
return proto
end,
expect = [[
MOV R0 #5
EXIT R0 #0
]]
}
end)
it('eliminate constant expressions', function()
compile {
input = function ()
return 2 + 3 - 0
end,
expect = [[
MOV R0 #5
EXIT R0 #0
]]
}
end)
it('eliminate constant expressions (if block)', function()
compile {
input = function ()
local proto = 5
if proto == 5 then
proto = 1
end
return proto
end,
expect = [[
MOV R0 #1
EXIT R0 #0
]]
}
end)
it('eliminate negative constant expressions (if block) NYI', function()
-- always negative condition is not fully eliminated
compile {
input = function ()
local proto = 5
if false then
proto = 1
end
return proto
end,
expect = [[
MOV R7 #5
STXDW [R10-8] R7
MOV R7 #0
JEQ R7 #0 => 0005
LDXDW R0 [R10-8]
EXIT R0 #0
]]
}
end)
end)
describe('variables', function()
it('classic packet access (fold constant offset)', function()
compile {
input = function (skb)
return eth.ip.tos -- constant expression will fold
end,
expect = [[
LDB R0 skb[15]
EXIT R0 #0
]]
}
end)
it('classic packet access (load non-constant offset)', function()
compile {
input = function (skb)
return eth.ip.udp.src_port -- need to skip variable-length header
end,
expect = [[
LDB R0 skb[14]
AND R0 #15
LSH R0 #2
ADD R0 #14
STXDW [R10-16] R0 -- NYI: erase dead store
LDH R0 skb[R0+0]
END R0 R0
EXIT R0 #0
]]
}
end)
it('classic packet access (manipulate dissector offset)', function()
compile {
input = function (skb)
local ptr = eth.ip.udp.data + 1
return ptr[0] -- dereference dissector pointer
end,
expect = [[
LDB R0 skb[14]
AND R0 #15
LSH R0 #2
ADD R0 #14 -- NYI: fuse commutative operations in second pass
ADD R0 #8
ADD R0 #1
STXDW [R10-16] R0
LDB R0 skb[R0+0]
EXIT R0 #0
]]
}
end)
it('classic packet access (multi-byte load)', function()
compile {
input = function (skb)
local ptr = eth.ip.udp.data
return ptr(1, 5) -- load 4 bytes
end,
expect = [[
LDB R0 skb[14]
AND R0 #15
LSH R0 #2
ADD R0 #14
ADD R0 #8
MOV R7 R0
STXDW [R10-16] R0 -- NYI: erase dead store
LDW R0 skb[R7+1]
END R0 R0
EXIT R0 #0
]]
}
end)
it('direct skb field access', function()
compile {
input = function (skb)
return skb.len
end,
expect = [[
LDXW R7 [R6+0]
MOV R0 R7
EXIT R0 #0
]]
}
end)
it('direct skb data access (manipulate offset)', function()
compile {
input = function (skb)
local ptr = skb.data + 5
return ptr[0]
end,
expect = [[
LDXW R7 [R6+76]
ADD R7 #5
LDXB R8 [R7+0] -- NYI: transform LD + ADD to LD + offset addressing
MOV R0 R8
EXIT R0 #0
]]
}
end)
it('direct skb data access (offset boundary check)', function()
compile {
input = function (skb)
local ptr = skb.data + 5
if ptr < skb.data_end then
return ptr[0]
end
end,
expect = [[
LDXW R7 [R6+76]
ADD R7 #5
LDXW R8 [R6+80]
JGE R7 R8 => 0008
LDXB R8 [R7+0]
MOV R0 R8
EXIT R0 #0
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('access stack memory (array, const load, const store)', function()
compile {
input = function (skb)
local mem = ffi.new('uint8_t [16]')
mem[0] = 5
end,
expect = [[
MOV R0 #0
STXDW [R10-40] R0
STXDW [R10-48] R0 -- NYI: erase zero-fill on allocation when it's loaded later
STB [R10-48] #5
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('access stack memory (array, const load, packet store)', function()
compile {
input = function (skb)
local mem = ffi.new('uint8_t [7]')
mem[0] = eth.ip.tos
end,
expect = [[
MOV R0 #0
STXDW [R10-40] R0 -- NYI: erase zero-fill on allocation when it's loaded later
LDB R0 skb[15]
STXB [R10-40] R0
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('access stack memory (array, packet load, const store)', function()
compile {
input = function (skb)
local mem = ffi.new('uint8_t [1]')
mem[eth.ip.tos] = 5
end,
expect = [[
MOV R0 #0
STXDW [R10-48] R0 -- NYI: erase zero-fill on allocation when it's loaded later
LDB R0 skb[15]
MOV R7 R0
ADD R7 R10
STB [R7-48] #5
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('access stack memory (array, packet load, packet store)', function()
compile {
input = function (skb)
local mem = ffi.new('uint8_t [7]')
local v = eth.ip.tos
mem[v] = v
end,
expect = [[
MOV R0 #0
STXDW [R10-40] R0 -- NYI: erase zero-fill on allocation when it's loaded later
LDB R0 skb[15]
MOV R7 R0
ADD R7 R10
STXB [R7-40] R0
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('access stack memory (struct, const/packet store)', function()
local kv_t = 'struct { uint64_t a; uint64_t b; }'
compile {
input = function (skb)
local mem = ffi.new(kv_t)
mem.a = 5
mem.b = eth.ip.tos
end,
expect = [[
MOV R0 #0
STXDW [R10-40] R0
STXDW [R10-48] R0 -- NYI: erase zero-fill on allocation when it's loaded later
MOV R7 #5
STXDW [R10-48] R7
LDB R0 skb[15]
STXDW [R10-40] R0
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('access stack memory (struct, const/stack store)', function()
local kv_t = 'struct { uint64_t a; uint64_t b; }'
compile {
input = function (skb)
local m1 = ffi.new(kv_t)
local m2 = ffi.new(kv_t)
m1.a = 5
m2.b = m1.a
end,
expect = [[
MOV R0 #0
STXDW [R10-48] R0
STXDW [R10-56] R0 -- NYI: erase zero-fill on allocation when it's loaded later
MOV R0 #0
STXDW [R10-64] R0
STXDW [R10-72] R0 -- NYI: erase zero-fill on allocation when it's loaded later
MOV R7 #5
STXDW [R10-56] R7
LDXDW R7 [R10-56]
STXDW [R10-64] R7
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('array map (u32, const key load)', function()
local array_map = makemap('array', 256)
compile {
input = function (skb)
return array_map[0]
end,
expect = [[
LDDW R1 #42
STW [R10-28] #0
MOV R2 R10
ADD R2 #4294967268
CALL R0 #1 ; map_lookup_elem
JEQ R0 #0 => 0009
LDXW R0 [R0+0]
EXIT R0 #0
]]
}
end)
it('array map (u32, packet key load)', function()
local array_map = makemap('array', 256)
compile {
input = function (skb)
return array_map[eth.ip.tos]
end,
expect = [[
LDB R0 skb[15]
LDDW R1 #42
STXW [R10-36] R0
MOV R2 R10
ADD R2 #4294967260
STXDW [R10-24] R0 -- NYI: erase dead store
CALL R0 #1 ; map_lookup_elem
JEQ R0 #0 => 0011
LDXW R0 [R0+0]
EXIT R0 #0
]]
}
end)
it('array map (u32, const key store, const value)', function()
local array_map = makemap('array', 256)
compile {
input = function (skb)
array_map[0] = 5
end,
expect = [[
LDDW R1 #42
STW [R10-36] #0
MOV R2 R10
ADD R2 #4294967260
MOV R4 #0
STW [R10-40] #5
MOV R3 R10
ADD R3 #4294967256
CALL R0 #2 ; map_update_elem
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('array map (u32, const key store, packet value)', function()
local array_map = makemap('array', 256)
compile {
input = function (skb)
array_map[0] = eth.ip.tos
end,
expect = [[
LDB R0 skb[15]
STXDW [R10-24] R0
LDDW R1 #42
STW [R10-36] #0
MOV R2 R10
ADD R2 #4294967260
MOV R4 #0
MOV R3 R10
ADD R3 #4294967272
CALL R0 #2 ; map_update_elem
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('array map (u32, const key store, map value)', function()
local array_map = makemap('array', 256)
compile {
input = function (skb)
array_map[0] = array_map[1]
end,
expect = [[
LDDW R1 #42
STW [R10-36] #1
MOV R2 R10
ADD R2 #4294967260
CALL R0 #1 ; map_lookup_elem
STXDW [R10-24] R0
LDDW R1 #42
STW [R10-36] #0
MOV R2 R10
ADD R2 #4294967260
MOV R4 #0
LDXDW R3 [R10-24]
JEQ R3 #0 => 0017
LDXW R3 [R3+0]
STXW [R10-40] R3
MOV R3 R10
ADD R3 #4294967256
CALL R0 #2 ; map_update_elem
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('array map (u32, const key replace, const value)', function()
local array_map = makemap('array', 256)
compile {
input = function (skb)
local val = array_map[0]
if val then
val[0] = val[0] + 1
else
array_map[0] = 5
end
end,
expect = [[
LDDW R1 #42
STW [R10-44] #0
MOV R2 R10
ADD R2 #4294967252
CALL R0 #1 ; map_lookup_elem
JEQ R0 #0 => 0013 -- if (map_value ~= NULL)
LDXW R7 [R0+0]
ADD R7 #1
STXW [R0+0] R7
MOV R7 #0
JEQ R7 #0 => 0025 -- skip false branch
STXDW [R10-16] R0
LDDW R1 #42
STW [R10-44] #0
MOV R2 R10
ADD R2 #4294967252
MOV R4 #0
STW [R10-48] #5
MOV R3 R10
ADD R3 #4294967248
CALL R0 #2 ; map_update_elem
LDXDW R0 [R10-16]
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('array map (u32, const key replace xadd, const value)', function()
local array_map = makemap('array', 256)
compile {
input = function (skb)
local val = array_map[0]
if val then
xadd(val, 1)
else
array_map[0] = 5
end
end,
expect = [[
LDDW R1 #42
STW [R10-52] #0
MOV R2 R10
ADD R2 #4294967244
CALL R0 #1 ; map_lookup_elem
JEQ R0 #0 => 0014 -- if (map_value ~= NULL)
MOV R7 #1
MOV R8 R0
STXDW [R10-16] R0
XADDW [R8+0] R7
MOV R7 #0
JEQ R7 #0 => 0025 -- skip false branch
STXDW [R10-16] R0
LDDW R1 #42
STW [R10-52] #0
MOV R2 R10
ADD R2 #4294967244
MOV R4 #0
STW [R10-56] #5
MOV R3 R10
ADD R3 #4294967240
CALL R0 #2 ; map_update_elem
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('array map (u32, const key replace xadd, const value) inverse nil check', function()
local array_map = makemap('array', 256)
compile {
input = function (skb)
local val = array_map[0]
if not val then
array_map[0] = 5
else
xadd(val, 1)
end
end,
expect = [[
LDDW R1 #42
STW [R10-52] #0
MOV R2 R10
ADD R2 #4294967244
CALL R0 #1 ; map_lookup_elem
JNE R0 #0 => 0021
STXDW [R10-16] R0
LDDW R1 #42
STW [R10-52] #0
MOV R2 R10
ADD R2 #4294967244
MOV R4 #0
STW [R10-56] #5
MOV R3 R10
ADD R3 #4294967240
CALL R0 #2 ; map_update_elem
MOV R7 #0
JEQ R7 #0 => 0025
MOV R7 #1
MOV R8 R0
STXDW [R10-16] R0
XADDW [R8+0] R7
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('array map (struct, stack key load)', function()
local kv_t = 'struct { uint64_t a; uint64_t b; }'
local array_map = makemap('array', 256, ffi.typeof(kv_t), ffi.typeof(kv_t))
compile {
input = function (skb)
local key = ffi.new(kv_t)
key.a = 2
key.b = 3
local val = array_map[key] -- Use composite key from stack memory
if val then
return val.a
end
end,
expect = [[
MOV R0 #0
STXDW [R10-48] R0
STXDW [R10-56] R0 -- NYI: erase zero-fill on allocation when it's loaded later
MOV R7 #2
STXDW [R10-56] R7
MOV R7 #3
STXDW [R10-48] R7
LDDW R1 #42
MOV R2 R10
ADD R2 #4294967240
CALL R0 #1 ; map_lookup_elem
JEQ R0 #0 => 0017
LDXDW R7 [R0+0]
MOV R0 R7
EXIT R0 #0
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('array map (struct, stack key store)', function()
local kv_t = 'struct { uint64_t a; uint64_t b; }'
local array_map = makemap('array', 256, ffi.typeof(kv_t), ffi.typeof(kv_t))
compile {
input = function (skb)
local key = ffi.new(kv_t)
key.a = 2
key.b = 3
array_map[key] = key -- Use composite key from stack memory
end,
expect = [[
MOV R0 #0
STXDW [R10-40] R0
STXDW [R10-48] R0 -- NYI: erase zero-fill on allocation when it's loaded later
MOV R7 #2
STXDW [R10-48] R7
MOV R7 #3
STXDW [R10-40] R7
LDDW R1 #42
MOV R2 R10
ADD R2 #4294967248
MOV R4 #0
MOV R3 R10
ADD R3 #4294967248
CALL R0 #2 ; map_update_elem
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('array map (struct, stack/packet key update, const value)', function()
local kv_t = 'struct { uint64_t a; uint64_t b; }'
local array_map = makemap('array', 256, ffi.typeof(kv_t), ffi.typeof(kv_t))
compile {
input = function (skb)
local key = ffi.new(kv_t)
key.a = eth.ip.tos -- Load key part from dissector
local val = array_map[key]
if val then
val.a = 5
end
end,
expect = [[
MOV R0 #0
STXDW [R10-48] R0
STXDW [R10-56] R0 -- NYI: erase zero-fill on allocation when it's loaded later
LDB R0 skb[15]
STXDW [R10-56] R0
LDDW R1 #42
MOV R2 R10
ADD R2 #4294967240
CALL R0 #1 ; map_lookup_elem
JEQ R0 #0 => 0014
MOV R7 #5
STXDW [R0+0] R7
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('array map (struct, stack/packet key update, map value)', function()
local kv_t = 'struct { uint64_t a; uint64_t b; }'
local array_map = makemap('array', 256, ffi.typeof(kv_t), ffi.typeof(kv_t))
compile {
input = function (skb)
local key = ffi.new(kv_t)
key.a = eth.ip.tos -- Load key part from dissector
local val = array_map[key]
if val then
val.a = val.b
end
end,
expect = [[
MOV R0 #0
STXDW [R10-48] R0
STXDW [R10-56] R0 -- NYI: erase zero-fill on allocation when it's loaded later
LDB R0 skb[15]
STXDW [R10-56] R0
LDDW R1 #42
MOV R2 R10
ADD R2 #4294967240
CALL R0 #1 ; map_lookup_elem
JEQ R0 #0 => 0014
LDXDW R7 [R0+8]
STXDW [R0+0] R7
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('array map (struct, stack/packet key update, stack value)', function()
local kv_t = 'struct { uint64_t a; uint64_t b; }'
local array_map = makemap('array', 256, ffi.typeof(kv_t), ffi.typeof(kv_t))
compile {
input = function (skb)
local key = ffi.new(kv_t)
key.a = eth.ip.tos -- Load key part from dissector
local val = array_map[key]
if val then
val.a = key.b
end
end,
expect = [[
MOV R0 #0
STXDW [R10-48] R0
STXDW [R10-56] R0 -- NYI: erase zero-fill on allocation when it's loaded later
LDB R0 skb[15]
STXDW [R10-56] R0
LDDW R1 #42
MOV R2 R10
ADD R2 #4294967240
CALL R0 #1 ; map_lookup_elem
JEQ R0 #0 => 0014
LDXDW R7 [R10-48]
STXDW [R0+0] R7
MOV R0 #0
EXIT R0 #0
]]
}
end)
it('array map (struct, stack/packet key replace, stack value)', function()
local kv_t = 'struct { uint64_t a; uint64_t b; }'
local array_map = makemap('array', 256, ffi.typeof(kv_t), ffi.typeof(kv_t))
compile {
input = function (skb)
local key = ffi.new(kv_t)
key.a = eth.ip.tos -- Load key part from dissector
local val = array_map[key]
if val then
val.a = key.b
else
array_map[key] = key
end
end,
expect = [[
MOV R0 #0
STXDW [R10-48] R0
STXDW [R10-56] R0
LDB R0 skb[15]
STXDW [R10-56] R0
LDDW R1 #42
MOV R2 R10
ADD R2 #4294967240
CALL R0 #1 ; map_lookup_elem
JEQ R0 #0 => 0016 -- if (map_value ~= NULL)
LDXDW R7 [R10-48]
STXDW [R0+0] R7
MOV R7 #0
JEQ R7 #0 => 0026 -- jump over false branch
STXDW [R10-24] R0
LDDW R1 #42
MOV R2 R10
ADD R2 #4294967240
MOV R4 #0
MOV R3 R10
ADD R3 #4294967240
CALL R0 #2 ; map_update_elem
LDXDW R0 [R10-24]
MOV R0 #0
EXIT R0 #0
]]
}
end)
end)
describe('control flow', function()
it('condition with constant return', function()
compile {
input = function (skb)
local v = eth.ip.tos
if v then
return 1
else
return 0
end
end,
expect = [[
LDB R0 skb[15]
JEQ R0 #0 => 0005
MOV R0 #1
EXIT R0 #0
MOV R0 #0 -- 0005 jump target
EXIT R0 #0
]]
}
end)
it('condition with cdata constant return', function()
local cdata = 2ULL
compile {
input = function (skb)
local v = eth.ip.tos
if v then
return cdata + 1
else
return 0
end
end,
expect = [[
LDB R0 skb[15]
JEQ R0 #0 => 0006
LDDW R0 #3
EXIT R0 #0
MOV R0 #0 -- 0006 jump target
EXIT R0 #0
]]
}
end)
it('condition with constant return (inversed)', function()
compile {
input = function (skb)
local v = eth.ip.tos
if not v then
return 1
else
return 0
end
end,
expect = [[
LDB R0 skb[15]
JNE R0 #0 => 0005
MOV R0 #1
EXIT R0 #0
MOV R0 #0 -- 0005 jump target
EXIT R0 #0
]]
}
end)
it('condition with variable mutation', function()
compile {
input = function (skb)
local v = 0
if eth.ip.tos then
v = 1
end
return v
end,
expect = [[
LDB R0 skb[15]
MOV R1 #0
STXDW [R10-16] R1
JEQ R0 #0 => 0007
MOV R7 #1
STXDW [R10-16] R7
LDXDW R0 [R10-16]
EXIT R0 #0
]]
}
end)
it('condition with nil variable mutation', function()
compile {
input = function (skb)
local v -- nil, will be elided
if eth.ip.tos then
v = 1
else
v = 0
end
return v
end,
expect = [[
LDB R0 skb[15]
JEQ R0 #0 => 0007
MOV R7 #1
STXDW [R10-16] R7
MOV R7 #0
JEQ R7 #0 => 0009
MOV R7 #0
STXDW [R10-16] R7
LDXDW R0 [R10-16]
EXIT R0 #0
]]
}
end)
it('nested condition with variable mutation', function()
compile {
input = function (skb)
local v = 0
local tos = eth.ip.tos
if tos then
if tos > 5 then
v = 5
else
v = 1
end
end
return v
end,
expect = [[
LDB R0 skb[15]
MOV R1 #0
STXDW [R10-16] R1 -- materialize v = 0
JEQ R0 #0 => 0013 -- if not tos
MOV R7 #5
JGE R7 R0 => 0011 -- if 5 > tos
MOV R7 #5
STXDW [R10-16] R7 -- materialize v = 5
MOV R7 #0
JEQ R7 #0 => 0013
MOV R7 #1 -- 0011 jump target
STXDW [R10-16] R7 -- materialize v = 1
LDXDW R0 [R10-16]
EXIT R0 #0
]]
}
end)
it('nested condition with variable shadowing', function()
compile {
input = function (skb)
local v = 0
local tos = eth.ip.tos
if tos then
local v = 0 -- luacheck: ignore 231
if tos > 5 then
v = 5 -- changing shadowing variable
end
else
v = 1
end
return v
end,
expect = [[
LDB R0 skb[15]
MOV R1 #0
STXDW [R10-16] R1 -- materialize v = 0
JEQ R0 #0 => 0011 -- if not tos
MOV R7 #5
MOV R1 #0
STXDW [R10-32] R1 -- materialize shadowing variable
JGE R7 R0 => 0013 -- if 5 > tos
MOV R7 #0 -- erased 'v = 5' dead store
JEQ R7 #0 => 0013
MOV R7 #1 -- 0011 jump target
STXDW [R10-16] R7 -- materialize v = 1
LDXDW R0 [R10-16] -- 0013 jump target
EXIT R0 #0
]]
}
end)
it('condition materializes shadowing variable at the end of BB', function()
compile {
input = function (skb)
local v = time()
local v1 = 0 -- luacheck: ignore 231
if eth.ip.tos then
v1 = v
end
end,
expect = [[
CALL R0 #5 ; ktime_get_ns
STXDW [R10-16] R0
LDB R0 skb[15]
MOV R1 #0
STXDW [R10-24] R1 -- materialize v1 = 0
JEQ R0 #0 => 0009
LDXDW R7 [R10-16]
STXDW [R10-24] R7 -- v1 = v
MOV R0 #0
EXIT R0 #0
]]
}
end)
end)
end)
local ffi = require('ffi')
-- Define basic ctypes
ffi.cdef [[
struct bpf_insn {
uint8_t code; /* opcode */
uint8_t dst_reg:4; /* dest register */
uint8_t src_reg:4; /* source register */
uint16_t off; /* signed offset */
uint32_t imm; /* signed immediate constant */
};
]]
-- Inject mock ljsyscall for tests
package.loaded['syscall'] = {
bpf = function() error('mock') end,
c = { BPF_MAP = {}, BPF_PROG = {} },
abi = { arch = 'x64' },
}
package.loaded['syscall.helpers'] = {
strflag = function (tab)
local function flag(cache, str)
if type(str) ~= "string" then return str end
if #str == 0 then return 0 end
local s = str:upper()
if #s == 0 then return 0 end
local val = rawget(tab, s)
if not val then return nil end
cache[str] = val
return val
end
return setmetatable(tab, {__index = setmetatable({}, {__index = flag}), __call = function(t, a) return t[a] end})
end
}