Skip to content

Commit

Permalink
apps.xdp.pf_ebpf_codegen: add eBPF backend for pflua
Browse files Browse the repository at this point in the history
  • Loading branch information
eugeneia committed Feb 7, 2020
1 parent cd4d773 commit 48ff71c
Showing 1 changed file with 373 additions and 0 deletions.
373 changes: 373 additions & 0 deletions src/apps/xdp/pf_ebpf_codegen.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,373 @@
-- Use of this source code is governed by the Apache 2.0 license; see COPYING.

module(...,package.seeall)

-- This module implements code generation for the XDP/eBPF backend of
-- Pflua. It takes the result of instruction selection (selection.lua)
-- and register allocation (regalloc.lua) and generates a function with
-- eBPF bytecode.

local parse = require('pf.parse').parse
local expand = require('pf.expand').expand
local optimize = require('pf.optimize').optimize
local anf = require('pf.anf')
local ssa = require('pf.ssa')
local sel = require("pf.selection")
local ra = require("pf.regalloc")
local bpf = require("apps.xdp.bpf")

local c, f, m, a, s, j = bpf.c, bpf.f, bpf.m, bpf.a, bpf.s, bpf.j

local tobit, band, bor, rshift = bit.tobit, bit.band, bit.bor, bit.rshift

-- Register description handed to the pflua register allocator
-- (see ra.allocate in compile below).
--
-- Registers that are deliberately kept out of the allocator's hands:
--   * r0: holds the pointer to the start of the packet payload
--   * r2: scratch register used while performing length checks
--
-- Registers with a dedicated role:
--   * r1 is marked callee-save: it holds the xdp_md context, which we
--     wish to preserve
--   * r3 is "len": data_end is stored here (used in length checks)
local ebpf_regs = {
   len = 3,
   callee_regs = { 1 },
   caller_regs = { 9, 8, 7, 6, 5, 4, 3 },
}

-- Generate an eBPF XDP program that returns XDP_PASS unless the filter
-- expression matches, and otherwise falls through so that a further eBPF
-- program can be appended and executed.
--
-- Arguments:
--   ir    - sequence of instructions as produced by instruction selection;
--           each instruction is a table whose first element is its type
--   alloc - register allocation result: maps variable names to register
--           numbers and carries the fields `len`, `spills`, `callee_saves`
--
-- Returns a sequence of instruction tables with the fields op/dst/src/off/imm
-- (unset fields are left nil). A 64-bit immediate load occupies two
-- consecutive entries, the second of which only carries `imm`.
--
-- Raises an error for anything not yet implemented (callee saves, spills,
-- unknown instruction types or comparison operators).
function codegen (ir, alloc)
   -- Conditional jump opcodes by comparison operator
   local cjmp_conds = {
      ["="] = j.JEQ, ["!="] = j.JNE,
      [">="] = j.JGE, ["<="] = j.JLE,
      [">"] = j.JGT, ["<"] = j.JLT
   }
   -- ALU operations taking two registers: dst = dst <op> src
   local alu_reg_ops = {
      ["add"] = a.ADD, ["sub"] = a.SUB, ["mul"] = a.MUL, ["div"] = a.DIV,
      ["and"] = a.AND, ["or"] = a.OR, ["xor"] = a.XOR,
      ["shl"] = a.LSH, ["shr"] = a.RSH
   }
   -- ALU operations taking a register and an immediate: dst = dst <op> imm
   local alu_imm_ops = {
      ["add-i"] = a.ADD, ["sub-i"] = a.SUB, ["mul-i"] = a.MUL,
      ["and-i"] = a.AND, ["or-i"] = a.OR, ["xor-i"] = a.XOR,
      ["shl-i"] = a.LSH, ["shr-i"] = a.RSH
   }

   -- Callee-save registers would have to be pushed here (and popped in
   -- reverse order on exit); this is not implemented yet.
   if next(alloc.callee_saves) ~= nil then
      error("NYI: callee saves")
   end

   -- Spilled variables would need stack space; not implemented yet.
   if next(alloc.spills) ~= nil then
      error("NYI: spilled space")
   end

   -- if the length variable got spilled, we would need to explicitly
   -- initialize the stack slot for it
   if alloc.spills["len"] then
      error("NYI: spilled length")
   end

   -- tr is the emitted instruction sequence, pc the index of the next slot
   local pc, tr = 1, {}
   local function emit (ins)
      tr[pc] = ins
      pc = pc+1
   end

   -- labels maps a label index to the pc it refers to. Index 0 is reserved
   -- for the true-label and index 1 for the false-label; numeric IR labels
   -- are shifted by label_offset to make room for them.
   local label_offset, labels = 2, {}

   -- Pending comparison: eBPF fuses compare and branch into one
   -- instruction, so "cmp" only records its operands here and the
   -- following "cjmp" completes and emits the instruction.
   local cmp
   local function emit_cjmp (cond, target)
      assert(cmp, "cjmp needs preceeding cmp")
      local jmp = cmp; cmp = nil
      jmp.op = bor(c.JMP, cond, jmp.op)
      -- off temporarily holds a label index; it is resolved to a relative
      -- offset once all labels are known (see fixup below)
      if target == "true-label" then
         jmp.off = 0
      elseif target == "false-label" then
         jmp.off = 1
      else
         jmp.off = label_offset+target
      end
      emit(jmp)
   end

   -- Selects the load size flag for a load of the given width in bytes
   local function load_size (bytes)
      if bytes == 1 then return f.B
      elseif bytes == 2 then return f.H
      else return f.W end
   end

   -- Setup: move data start and end pointers into r0 and r(alloc.len)
   -- r0 = ((struct xdp_md *)ctx)->data
   emit{ op=bor(c.LDX, f.W, m.MEM), dst=0, src=1, off=0 }
   -- r(alloc.len) = ((struct xdp_md *)ctx)->data_end
   emit{ op=bor(c.LDX, f.W, m.MEM), dst=alloc.len, src=1, off=4 }

   for idx, instr in ipairs(ir) do
      local itype = instr[1]

      --- FIXME: handle spills

      -- the core code generation logic starts here
      if itype == "label" then
         local lnum = instr[2]
         labels[label_offset+lnum] = pc

      elseif itype == "cjmp" then
         local op, target = instr[2], instr[3]
         local cond = cjmp_conds[op]
         if not cond then
            -- do not silently drop the branch (and leave cmp pending)
            error(string.format("NYI cjmp condition %s", tostring(op)))
         end
         emit_cjmp(cond, target)

      elseif itype == "jmp" then
         local target = instr[2]
         -- next_instr is nil if this jump is the last IR instruction
         local next_instr = ir[idx+1]
         local next_type = next_instr and next_instr[1]
         if (type(target) == "number" and
             next_type == "label" and
             next_instr[2] == target) then
            -- the jump target is immediately after this in the instruction
            -- sequence: don't output anything
         elseif target == "true-label" then
            if next_type ~= "ret-true" then
               emit{ op=bor(c.JMP, j.JA), off=0 }
            end
         elseif target == "false-label" then
            if next_type ~= "ret-false" then
               emit{ op=bor(c.JMP, j.JA), off=1 }
            end
         else
            emit{ op=bor(c.JMP, j.JA), off=label_offset+target }
         end

      elseif itype == "cmp" and instr[2] == "len" then
         local lhs_reg = alloc.len
         local rhs = instr[3]
         assert(rhs ~= "len", "NYI: cmp with rhs len")

         -- Perform eBPF friendly length check: instead of comparing a
         -- length, compare data_end against (data start + rhs).
         -- mov r2, r0
         emit{ op=bor(c.ALU64, a.MOV, s.X), dst=2, src=0 }
         -- add r2, rhs
         if type(rhs) == "number" then
            emit{ op=bor(c.ALU64, a.ADD, s.K), dst=2, imm=rhs }
         else
            emit{ op=bor(c.ALU64, a.ADD, s.X), dst=2, src=alloc[rhs] }
         end
         -- cmp r(alloc.len), r2
         cmp = { op=s.X, dst=lhs_reg, src=2 }

      elseif itype == "cmp" then
         -- the lhs should never be an immediate so this should be non-nil
         local lhs_reg = assert(alloc[instr[2]])
         local rhs = instr[3]
         assert(rhs ~= "len", "NYI: cmp with rhs len")

         if type(rhs) == "number" then
            cmp = { op=s.K, dst=lhs_reg, imm=rhs }
         else
            cmp = { op=s.X, dst=lhs_reg, src=alloc[rhs] }
         end

      elseif itype == "load" then
         local target = alloc[instr[2]]
         assert(not alloc.spills[instr[2]], "NYI: load spill")
         local offset = instr[3]
         local size = load_size(instr[4])

         if type(offset) == "number" then
            -- NOTE(review): no src is given here; presumably the assembler
            -- defaults it to r0 (the packet pointer) -- confirm in bpf.asm
            emit{ op=bor(c.LDX, size, m.MEM), dst=target, off=offset }
         else
            local reg = alloc[offset]
            assert(not alloc.spills[offset], "NYI: load spill")

            -- Temporarily turn the offset register into a pointer, load
            -- through it, then undo the addition.
            -- NOTE(review): as above, ADD/SUB carry no src and presumably
            -- operate on r0 -- confirm in bpf.asm
            emit{ op=bor(c.ALU64, a.ADD, s.X), dst=reg }
            emit{ op=bor(c.LDX, size, m.MEM), dst=target, src=reg }
            emit{ op=bor(c.ALU64, a.SUB, s.X), dst=reg }
         end

      elseif itype == "mov" then
         local dst = alloc[instr[2]]
         assert(not alloc.spills[instr[2]], "NYI: mov spill")
         local arg = instr[3]

         if type(arg) == "number" then
            emit{ op=bor(c.ALU, a.MOV, s.K), dst=dst, imm=arg }
         else
            assert(not alloc.spills[arg], "NYI: mov spill")
            emit{ op=bor(c.ALU64, a.MOV, s.X), dst=dst, src=alloc[arg] }
         end

      elseif itype == "mov64" then
         local dst = alloc[instr[2]]
         local imm = instr[3]
         -- Upper 32 bits of the immediate. For plain Lua numbers
         -- bit.rshift only honors the low 5 bits of the shift count, so
         -- rshift(imm, 32) would be a shift by zero; divide instead.
         local hi
         if type(imm) == "number" then
            hi = tobit(math.floor(imm / 2^32))
         else
            hi = rshift(imm, 32)
         end
         -- 64-bit immediate load: two slots, low half first
         emit{ op=bor(c.LD, f.DW, m.IMM), dst=dst, src=s.K, imm=tobit(imm) }
         emit{ imm=hi }

      elseif alu_reg_ops[itype] then
         -- add, sub, mul, div, and, or, xor, shl, shr: each maps directly
         -- onto the corresponding eBPF ALU64 register/register operation
         local reg1, reg2 = alloc[instr[2]], alloc[instr[3]]
         emit{ op=bor(c.ALU64, alu_reg_ops[itype], s.X), dst=reg1, src=reg2 }

      elseif alu_imm_ops[itype] then
         -- add-i, sub-i, mul-i, and-i, or-i, xor-i, shl-i, shr-i
         local reg = alloc[instr[2]]
         assert(type(reg) == "number")
         assert(type(instr[3]) == "number")
         emit{ op=bor(c.ALU64, alu_imm_ops[itype], s.K), dst=reg, imm=instr[3] }

      elseif itype == "ntohs" then
         local reg = alloc[instr[2]]
         emit{ op=bor(c.ALU, a.END, a.BE), dst=reg, imm=16 }

      elseif itype == "ntohl" then
         local reg = alloc[instr[2]]
         emit{ op=bor(c.ALU, a.END, a.BE), dst=reg, imm=32 }

      elseif itype == "uint32" then
         -- 32-bit AND of the register with itself
         local reg = alloc[instr[2]]
         emit{ op=bor(c.ALU, a.AND, s.X), dst=reg, src=reg }

      elseif itype == "ret-true" then
         labels[0] = pc
         -- In the end, we will turn this into a jump to the first instruction
         -- beyond the end of the emitted sequence.
         emit{ op=bor(c.JMP, j.JA) }

      elseif itype == "ret-false" then
         labels[1] = pc
         -- r0 = XDP_PASS
         emit{ op=bor(c.ALU, a.MOV, s.K), dst=0, imm=2 }
         -- EXIT:
         emit{ op=bor(c.JMP, j.EXIT) }

      elseif itype == "nop" then
         -- don't output anything

      else
         error(string.format("NYI instruction %s", itype))
      end
   end

   -- Fixup true-label, part 1: if the true-label jump is the very last
   -- instruction it would jump to the instruction following it, so it can
   -- simply be removed. labels[0] then points just beyond the sequence.
   local true_label = labels[0]
   local true_is_last = (true_label == #tr)
   if true_is_last then
      tr[true_label] = nil
   end

   -- Fixup jump offsets: up to here `off` of a jump holds a label index;
   -- replace it with the offset relative to the following instruction.
   for pc, ins in ipairs(tr) do
      -- the second slot of a 64-bit immediate load has no op
      if ins.op and ins.off and band(ins.op, c.JMP) == c.JMP then
         local target = assert(labels[ins.off], "jump to undefined label")
         ins.off = target - (pc+1)
      end
   end

   -- Fixup true-label, part 2: set the jump offset to the first instruction
   -- beyond the emitted sequence. This must happen after the loop above
   -- because the value is already a final relative offset and must not be
   -- interpreted as a label index.
   if true_label and not true_is_last then
      tr[true_label].off = #tr - true_label
   end

   return tr
end

-- Compile a pflang filter expression to eBPF code by running it through the
-- pflua pipeline (parse, expand for "EN10MB", optimize, ANF, SSA,
-- instruction selection, register allocation) and finally codegen.
--
-- If dump is truthy, the register allocation, the selected instructions,
-- the filter itself and a disassembly of the assembled code are printed.
--
-- Returns the instruction sequence produced by codegen.
function compile(filter, dump)
   local ast = parse(filter)
   local expr = optimize(expand(ast, "EN10MB"))
   local ssa_form = ssa.convert_ssa(anf.convert_anf(expr))
   local ir = sel.select(ssa_form)
   local alloc = ra.allocate(ir, ebpf_regs)
   local code = codegen(ir, alloc)
   if dump then
      local lib = require("core.lib")
      lib.print_object(alloc)
      lib.print_object(ir)
      print(filter)
      bpf.dis(bpf.asm(code))
   end
   return code
end

-- Compile a handful of filter expressions with dumping enabled. This only
-- checks that compilation completes; the generated code is printed, not
-- verified.
function selftest()
   local filters = {
      "ip proto esp or ip proto 99 or arp",
      "ip6[6] = 50 or ip6[6] = 99 or "..
         "(ip6[6] = 58 and (ip6[40] = 135 or ip6[40] = 136))",
      "1 = 2"
   }
   for _, filter in ipairs(filters) do
      compile(filter, "dump")
   end
end

0 comments on commit 48ff71c

Please sign in to comment.