Skip to content

Commit

Permalink
compiler: memory usage optimization around br_table (#2251)
Browse files Browse the repository at this point in the history
This optimizes the memory usage during compilation for 
br_table instructions. As you can see in the bench results below,
for some cases where lots of br_tables exists (the case named `zz`),
the compilation uses 10% less allocations and 5% less memory, hence
the slightly faster compilation.

```
goos: darwin
goarch: arm64
pkg: github.com/tetratelabs/wazero
                      │  old.txt   │             new.txt              │
                      │   sec/op   │   sec/op    vs base              │
Compilation/wazero-10   2.015 ± 2%   1.993 ± 0%  -1.09% (p=0.002 n=6)
Compilation/zig-10      4.200 ± 0%   4.161 ± 1%  -0.93% (p=0.004 n=6)
Compilation/zz-10       18.70 ± 0%   18.57 ± 0%  -0.69% (p=0.002 n=6)
geomean                 5.409        5.360       -0.90%

                      │   old.txt    │              new.txt               │
                      │     B/op     │     B/op      vs base              │
Compilation/wazero-10   297.5Mi ± 0%   287.1Mi ± 0%  -3.48% (p=0.002 n=6)
Compilation/zig-10      593.9Mi ± 0%   590.3Mi ± 0%  -0.61% (p=0.002 n=6)
Compilation/zz-10       582.6Mi ± 0%   553.7Mi ± 0%  -4.96% (p=0.002 n=6)
geomean                 468.7Mi        454.4Mi       -3.03%

                      │   old.txt   │              new.txt               │
                      │  allocs/op  │  allocs/op   vs base               │
Compilation/wazero-10   457.0k ± 0%   449.1k ± 0%   -1.72% (p=0.002 n=6)
Compilation/zig-10      275.8k ± 0%   273.8k ± 0%   -0.70% (p=0.002 n=6)
Compilation/zz-10       926.5k ± 0%   830.9k ± 0%  -10.32% (p=0.002 n=6)
geomean                 488.7k        467.5k        -4.35%
```

Signed-off-by: Takeshi Yoneda <t.y.mathetake@gmail.com>
  • Loading branch information
mathetake authored Jun 14, 2024
1 parent ec36887 commit b9571df
Show file tree
Hide file tree
Showing 11 changed files with 137 additions and 89 deletions.
3 changes: 2 additions & 1 deletion internal/engine/wazevo/backend/compiler_lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,12 @@ func (c *compiler) lowerBranches(br0, br1 *ssa.Instruction) {
}

if br0.Opcode() == ssa.OpcodeJump {
_, args, target := br0.BranchData()
_, args, targetBlockID := br0.BranchData()
argExists := len(args) != 0
if argExists && br1 != nil {
panic("BUG: critical edge split failed")
}
target := c.ssaBuilder.BasicBlock(targetBlockID)
if argExists && target.ReturnBlock() {
if len(args) > 0 {
c.mach.LowerReturns(args)
Expand Down
48 changes: 29 additions & 19 deletions internal/engine/wazevo/backend/isa/amd64/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ type (

labelResolutionPends []labelResolutionPend

// jmpTableTargets holds the labels of the jump table targets.
jmpTableTargets [][]uint32
consts []_const
// jmpTableTargetNext is the index to the jmpTableTargets slice to be used for the next jump table.
jmpTableTargetsNext int
consts []_const

constSwizzleMaskConstIndex, constSqmulRoundSatIndex,
constI8x16SHLMaskTableIndex, constI8x16LogicalSHRMaskTableIndex,
Expand Down Expand Up @@ -131,7 +134,7 @@ func (m *machine) Reset() {
m.maxRequiredStackSizeForCalls = 0

m.amodePool.Reset()
m.jmpTableTargets = m.jmpTableTargets[:0]
m.jmpTableTargetsNext = 0
m.constSwizzleMaskConstIndex = -1
m.constSqmulRoundSatIndex = -1
m.constI8x16SHLMaskTableIndex = -1
Expand Down Expand Up @@ -187,46 +190,53 @@ func (m *machine) LowerSingleBranch(b *ssa.Instruction) {
ectx := m.ectx
switch b.Opcode() {
case ssa.OpcodeJump:
_, _, targetBlk := b.BranchData()
_, _, targetBlkID := b.BranchData()
if b.IsFallthroughJump() {
return
}
jmp := m.allocateInstr()
target := ectx.GetOrAllocateSSABlockLabel(targetBlk)
target := ectx.GetOrAllocateSSABlockLabel(m.c.SSABuilder().BasicBlock(targetBlkID))
if target == backend.LabelReturn {
jmp.asRet()
} else {
jmp.asJmp(newOperandLabel(target))
}
m.insert(jmp)
case ssa.OpcodeBrTable:
index, target := b.BrTableData()
m.lowerBrTable(index, target)
index, targetBlkIDs := b.BrTableData()
m.lowerBrTable(index, targetBlkIDs)
default:
panic("BUG: unexpected branch opcode" + b.Opcode().String())
}
}

func (m *machine) addJmpTableTarget(targets []ssa.BasicBlock) (index int) {
// TODO: reuse the slice!
labels := make([]uint32, len(targets))
for j, target := range targets {
labels[j] = uint32(m.ectx.GetOrAllocateSSABlockLabel(target))
func (m *machine) addJmpTableTarget(targets ssa.Values) (index int) {
if m.jmpTableTargetsNext == len(m.jmpTableTargets) {
m.jmpTableTargets = append(m.jmpTableTargets, make([]uint32, 0, len(targets.View())))
}

index = m.jmpTableTargetsNext
m.jmpTableTargetsNext++
m.jmpTableTargets[index] = m.jmpTableTargets[index][:0]
for _, targetBlockID := range targets.View() {
target := m.c.SSABuilder().BasicBlock(ssa.BasicBlockID(targetBlockID))
m.jmpTableTargets[index] = append(m.jmpTableTargets[index],
uint32(m.ectx.GetOrAllocateSSABlockLabel(target)))
}
index = len(m.jmpTableTargets)
m.jmpTableTargets = append(m.jmpTableTargets, labels)
return
}

var condBranchMatches = [...]ssa.Opcode{ssa.OpcodeIcmp, ssa.OpcodeFcmp}

func (m *machine) lowerBrTable(index ssa.Value, targets []ssa.BasicBlock) {
func (m *machine) lowerBrTable(index ssa.Value, targets ssa.Values) {
_v := m.getOperand_Reg(m.c.ValueDefinition(index))
v := m.copyToTmp(_v.reg())

targetCount := len(targets.View())

// First, we need to do the bounds check.
maxIndex := m.c.AllocateVReg(ssa.TypeI32)
m.lowerIconst(maxIndex, uint64(len(targets)-1), false)
m.lowerIconst(maxIndex, uint64(targetCount-1), false)
cmp := m.allocateInstr().asCmpRmiR(true, newOperandReg(maxIndex), v, false)
m.insert(cmp)

Expand Down Expand Up @@ -255,23 +265,23 @@ func (m *machine) lowerBrTable(index ssa.Value, targets []ssa.BasicBlock) {

jmpTable := m.allocateInstr()
targetSliceIndex := m.addJmpTableTarget(targets)
jmpTable.asJmpTableSequence(targetSliceIndex, len(targets))
jmpTable.asJmpTableSequence(targetSliceIndex, targetCount)
m.insert(jmpTable)
}

// LowerConditionalBranch implements backend.Machine.
func (m *machine) LowerConditionalBranch(b *ssa.Instruction) {
exctx := m.ectx
cval, args, targetBlk := b.BranchData()
cval, args, targetBlkID := b.BranchData()
if len(args) > 0 {
panic(fmt.Sprintf(
"conditional branch shouldn't have args; likely a bug in critical edge splitting: from %s to %s",
exctx.CurrentSSABlk,
targetBlk,
targetBlkID,
))
}

target := exctx.GetOrAllocateSSABlockLabel(targetBlk)
target := exctx.GetOrAllocateSSABlockLabel(m.c.SSABuilder().BasicBlock(targetBlkID))
cvalDef := m.c.ValueDefinition(cval)

switch m.c.MatchInstrOneOf(cvalDef, condBranchMatches[:]) {
Expand Down
17 changes: 10 additions & 7 deletions internal/engine/wazevo/backend/isa/arm64/lower_instr.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ func (m *machine) LowerSingleBranch(br *ssa.Instruction) {
ectx := m.executableContext
switch br.Opcode() {
case ssa.OpcodeJump:
_, _, targetBlk := br.BranchData()
_, _, targetBlkID := br.BranchData()
if br.IsFallthroughJump() {
return
}
b := m.allocateInstr()
targetBlk := m.compiler.SSABuilder().BasicBlock(targetBlkID)
target := ectx.GetOrAllocateSSABlockLabel(targetBlk)
if target == labelReturn {
b.asRet()
Expand All @@ -40,7 +41,8 @@ func (m *machine) LowerSingleBranch(br *ssa.Instruction) {
}

func (m *machine) lowerBrTable(i *ssa.Instruction) {
index, targets := i.BrTableData()
index, targetBlockIDs := i.BrTableData()
targetBlockCount := len(targetBlockIDs.View())
indexOperand := m.getOperand_NR(m.compiler.ValueDefinition(index), extModeNone)

// Firstly, we have to do the bounds check of the index, and
Expand All @@ -50,7 +52,7 @@ func (m *machine) lowerBrTable(i *ssa.Instruction) {
// subs wzr, index, maxIndexReg
// csel adjustedIndex, maxIndexReg, index, hs ;; if index is higher or equal than maxIndexReg.
maxIndexReg := m.compiler.AllocateVReg(ssa.TypeI32)
m.lowerConstantI32(maxIndexReg, int32(len(targets)-1))
m.lowerConstantI32(maxIndexReg, int32(targetBlockCount-1))
subs := m.allocateInstr()
subs.asALU(aluOpSubS, xzrVReg, indexOperand, operandNR(maxIndexReg), false)
m.insert(subs)
Expand All @@ -61,23 +63,24 @@ func (m *machine) lowerBrTable(i *ssa.Instruction) {

brSequence := m.allocateInstr()

tableIndex := m.addJmpTableTarget(targets)
brSequence.asBrTableSequence(adjustedIndex, tableIndex, len(targets))
tableIndex := m.addJmpTableTarget(targetBlockIDs)
brSequence.asBrTableSequence(adjustedIndex, tableIndex, targetBlockCount)
m.insert(brSequence)
}

// LowerConditionalBranch implements backend.Machine.
func (m *machine) LowerConditionalBranch(b *ssa.Instruction) {
exctx := m.executableContext
cval, args, targetBlk := b.BranchData()
cval, args, targetBlkID := b.BranchData()
if len(args) > 0 {
panic(fmt.Sprintf(
"conditional branch shouldn't have args; likely a bug in critical edge splitting: from %s to %s",
exctx.CurrentSSABlk,
targetBlk,
targetBlkID,
))
}

targetBlk := m.compiler.SSABuilder().BasicBlock(targetBlkID)
target := exctx.GetOrAllocateSSABlockLabel(targetBlk)
cvalDef := m.compiler.ValueDefinition(cval)

Expand Down
23 changes: 15 additions & 8 deletions internal/engine/wazevo/backend/isa/arm64/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type (

// jmpTableTargets holds the labels of the jump table targets.
jmpTableTargets [][]uint32
// jmpTableTargetNext is the index to the jmpTableTargets slice to be used for the next jump table.
jmpTableTargetsNext int

// spillSlotSize is the size of the stack slot in bytes used for spilling registers.
// During the execution of the function, the stack looks like:
Expand Down Expand Up @@ -151,7 +153,7 @@ func (m *machine) Reset() {
m.unresolvedAddressModes = m.unresolvedAddressModes[:0]
m.maxRequiredStackSizeForCalls = 0
m.executableContext.Reset()
m.jmpTableTargets = m.jmpTableTargets[:0]
m.jmpTableTargetsNext = 0
m.amodePool.Reset()
}

Expand Down Expand Up @@ -508,13 +510,18 @@ func (m *machine) frameSize() int64 {
return s
}

func (m *machine) addJmpTableTarget(targets []ssa.BasicBlock) (index int) {
// TODO: reuse the slice!
labels := make([]uint32, len(targets))
for j, target := range targets {
labels[j] = uint32(m.executableContext.GetOrAllocateSSABlockLabel(target))
func (m *machine) addJmpTableTarget(targets ssa.Values) (index int) {
if m.jmpTableTargetsNext == len(m.jmpTableTargets) {
m.jmpTableTargets = append(m.jmpTableTargets, make([]uint32, 0, len(targets.View())))
}

index = m.jmpTableTargetsNext
m.jmpTableTargetsNext++
m.jmpTableTargets[index] = m.jmpTableTargets[index][:0]
for _, targetBlockID := range targets.View() {
target := m.compiler.SSABuilder().BasicBlock(ssa.BasicBlockID(targetBlockID))
m.jmpTableTargets[index] = append(m.jmpTableTargets[index],
uint32(m.executableContext.GetOrAllocateSSABlockLabel(target)))
}
index = len(m.jmpTableTargets)
m.jmpTableTargets = append(m.jmpTableTargets, labels)
return
}
4 changes: 3 additions & 1 deletion internal/engine/wazevo/backend/isa/arm64/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func newSetupWithMockContext() (*mockCompiler, ssa.Builder, *machine) {
m := NewBackend().(*machine)
m.SetCompiler(ctx)
ssaB := ssa.NewBuilder()
ctx.ssaBuilder = ssaB
blk := ssaB.AllocateBasicBlock()
ssaB.SetCurrentBlock(blk)
return ctx, ssaB, m
Expand All @@ -57,6 +58,7 @@ type mockCompiler struct {
definitions map[ssa.Value]*backend.SSAValueDefinition
sigs map[ssa.SignatureID]*ssa.Signature
typeOf map[regalloc.VRegID]ssa.Type
ssaBuilder ssa.Builder
relocs []backend.RelocationInfo
buf []byte
}
Expand All @@ -68,7 +70,7 @@ func (m *mockCompiler) GetFunctionABI(sig *ssa.Signature) *backend.FunctionABI {
panic("implement me")
}

func (m *mockCompiler) SSABuilder() ssa.Builder { return nil }
func (m *mockCompiler) SSABuilder() ssa.Builder { return m.ssaBuilder }

func (m *mockCompiler) LoopNestingForestRoots() []ssa.BasicBlock { panic("TODO") }

Expand Down
13 changes: 7 additions & 6 deletions internal/engine/wazevo/frontend/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -4068,31 +4068,32 @@ func (c *Compiler) lowerBrTable(labels []uint32, index ssa.Value) {
numArgs = len(f.blockType.Results)
}

targets := make([]ssa.BasicBlock, len(labels))
varPool := builder.VarLengthPool()
trampolineBlockIDs := varPool.Allocate(len(labels))

// We need trampoline blocks since depending on the target block structure, we might end up inserting moves before jumps,
// which cannot be done with br_table. Instead, we can do such per-block moves in the trampoline blocks.
// At the linking phase (very end of the backend), we can remove the unnecessary jumps, and therefore no runtime overhead.
currentBlk := builder.CurrentBlock()
for i, l := range labels {
for _, l := range labels {
// Args are always on the top of the stack. Note that we should not share the args slice
// among the jump instructions since the args are modified during passes (e.g. redundant phi elimination).
args := c.nPeekDup(numArgs)
targetBlk, _ := state.brTargetArgNumFor(l)
trampoline := builder.AllocateBasicBlock()
builder.SetCurrentBlock(trampoline)
c.insertJumpToBlock(args, targetBlk)
targets[i] = trampoline
trampolineBlockIDs = trampolineBlockIDs.Append(builder.VarLengthPool(), ssa.Value(trampoline.ID()))
}
builder.SetCurrentBlock(currentBlk)

// If the target block has no arguments, we can just jump to the target block.
brTable := builder.AllocateInstruction()
brTable.AsBrTable(index, targets)
brTable.AsBrTable(index, trampolineBlockIDs)
builder.InsertInstruction(brTable)

for _, trampoline := range targets {
builder.Seal(trampoline)
for _, trampolineID := range trampolineBlockIDs.View() {
builder.Seal(builder.BasicBlock(ssa.BasicBlockID(trampolineID)))
}
}

Expand Down
21 changes: 10 additions & 11 deletions internal/engine/wazevo/ssa/basic_block.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ type BasicBlock interface {
// The returned Value is the definition of the param in this block.
Param(i int) Value

// InsertInstruction inserts an instruction that implements Value into the tail of this block.
InsertInstruction(raw *Instruction)

// Root returns the root instruction of this block.
Root() *Instruction

Expand Down Expand Up @@ -208,8 +205,8 @@ func (bb *basicBlock) Sealed() bool {
return bb.sealed
}

// InsertInstruction implements BasicBlock.InsertInstruction.
func (bb *basicBlock) InsertInstruction(next *Instruction) {
// insertInstruction implements BasicBlock.InsertInstruction.
func (bb *basicBlock) insertInstruction(b *builder, next *Instruction) {
current := bb.currentInstr
if current != nil {
current.next = next
Expand All @@ -221,12 +218,12 @@ func (bb *basicBlock) InsertInstruction(next *Instruction) {

switch next.opcode {
case OpcodeJump, OpcodeBrz, OpcodeBrnz:
target := next.blk.(*basicBlock)
target.addPred(bb, next)
target := BasicBlockID(next.rValue)
b.basicBlock(target).addPred(bb, next)
case OpcodeBrTable:
for _, _target := range next.targets {
target := _target.(*basicBlock)
target.addPred(bb, next)
for _, _target := range next.rValues.View() {
target := BasicBlockID(_target)
b.basicBlock(target).addPred(bb, next)
}
}
}
Expand Down Expand Up @@ -339,7 +336,9 @@ func (bb *basicBlock) validate(b *builder) {
if len(bb.preds) > 0 {
for _, pred := range bb.preds {
if pred.branch.opcode != OpcodeBrTable {
if target := pred.branch.blk; target != bb {
blockID := int(pred.branch.rValue)
target := b.basicBlocksPool.View(blockID)
if target != bb {
panic(fmt.Sprintf("BUG: '%s' is not branch to %s, but to %s",
pred.branch.Format(b), bb.Name(), target.Name()))
}
Expand Down
17 changes: 16 additions & 1 deletion internal/engine/wazevo/ssa/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ type Builder interface {

// InsertZeroValue inserts a zero value constant instruction of the given type.
InsertZeroValue(t Type)

// BasicBlock returns the BasicBlock of the given ID.
BasicBlock(id BasicBlockID) BasicBlock
}

// NewBuilder returns a new Builder implementation.
Expand Down Expand Up @@ -214,6 +217,18 @@ type redundantParam struct {
uniqueValue Value
}

// BasicBlock implements Builder.BasicBlock.
func (b *builder) BasicBlock(id BasicBlockID) BasicBlock {
return b.basicBlock(id)
}

func (b *builder) basicBlock(id BasicBlockID) *basicBlock {
if id == basicBlockIDReturnBlock {
return b.returnBlk
}
return b.basicBlocksPool.View(int(id))
}

// InsertZeroValue implements Builder.InsertZeroValue.
func (b *builder) InsertZeroValue(t Type) {
if b.zeros[t].Valid() {
Expand Down Expand Up @@ -362,7 +377,7 @@ func (b *builder) Idom(blk BasicBlock) BasicBlock {

// InsertInstruction implements Builder.InsertInstruction.
func (b *builder) InsertInstruction(instr *Instruction) {
b.currentBB.InsertInstruction(instr)
b.currentBB.insertInstruction(b, instr)

if l := b.currentSourceOffset; l.Valid() {
// Emit the source offset info only when the instruction has side effect because
Expand Down
Loading

0 comments on commit b9571df

Please sign in to comment.