Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ssa: removes map use for Value aliasing #2285

Merged
merged 2 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions internal/engine/wazevo/backend/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,6 @@ type compiler struct {
ssaValueToVRegs [] /* VRegID to */ regalloc.VReg
// ssaValueDefinitions maps ssa.ValueID to its definition.
ssaValueDefinitions []SSAValueDefinition
// ssaValueRefCounts is a cached list obtained by ssa.Builder.ValueRefCounts().
ssaValueRefCounts []int
// returnVRegs is the list of virtual registers that store the return values.
returnVRegs []regalloc.VReg
varEdges [][2]regalloc.VReg
Expand Down Expand Up @@ -206,8 +204,7 @@ func (c *compiler) setCurrentGroupID(gid ssa.InstructionGroupID) {
// assignVirtualRegisters assigns a virtual register to each ssa.ValueID Valid in the ssa.Builder.
func (c *compiler) assignVirtualRegisters() {
builder := c.ssaBuilder
refCounts := builder.ValueRefCounts()
c.ssaValueRefCounts = refCounts
refCounts := builder.ValuesInfo()

need := len(refCounts)
if need >= len(c.ssaValueToVRegs) {
Expand Down Expand Up @@ -242,7 +239,7 @@ func (c *compiler) assignVirtualRegisters() {
c.ssaValueDefinitions[id] = SSAValueDefinition{
Instr: cur,
N: 0,
RefCount: refCounts[id],
RefCount: refCounts[id].RefCount,
}
c.ssaTypeOfVRegID[vReg.ID()] = ssaTyp
N++
Expand All @@ -255,7 +252,7 @@ func (c *compiler) assignVirtualRegisters() {
c.ssaValueDefinitions[id] = SSAValueDefinition{
Instr: cur,
N: N,
RefCount: refCounts[id],
RefCount: refCounts[id].RefCount,
}
c.ssaTypeOfVRegID[vReg.ID()] = ssaTyp
N++
Expand Down
3 changes: 2 additions & 1 deletion internal/engine/wazevo/backend/compiler_lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,10 @@ func (c *compiler) lowerFunctionArguments(entry ssa.BasicBlock) {
mach := c.mach

c.tmpVals = c.tmpVals[:0]
data := c.ssaBuilder.ValuesInfo()
for i := 0; i < entry.Params(); i++ {
p := entry.Param(i)
if c.ssaValueRefCounts[p.ID()] > 0 {
if data[p.ID()].RefCount > 0 {
c.tmpVals = append(c.tmpVals, p)
} else {
// If the argument is not used, we can just pass an invalid value.
Expand Down
2 changes: 1 addition & 1 deletion internal/engine/wazevo/backend/vdef.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type SSAValueDefinition struct {
// N is the index of the return value in the instr's return values list.
N int
// RefCount is the number of references to the result.
RefCount int
RefCount uint32
}

func (d *SSAValueDefinition) IsFromInstr() bool {
Expand Down
48 changes: 32 additions & 16 deletions internal/engine/wazevo/ssa/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ type Builder interface {
// Returns nil if there's no unseen BasicBlock.
BlockIteratorNext() BasicBlock

// ValueRefCounts returns the map of ValueID to its reference count.
// The returned slice must not be modified.
ValueRefCounts() []int
// ValuesInfo returns the data per Value used to lower the SSA in backend.
// This is indexed by ValueID.
ValuesInfo() []ValueInfo

// BlockIteratorReversePostOrderBegin is almost the same as BlockIteratorBegin except it returns the BasicBlock in the reverse post-order.
// This is available after RunPasses is run.
Expand Down Expand Up @@ -143,7 +143,6 @@ func NewBuilder() Builder {
varLengthPool: wazevoapi.NewVarLengthPool[Value](),
valueAnnotations: make(map[ValueID]string),
signatures: make(map[SignatureID]*Signature),
valueIDAliases: make(map[ValueID]Value),
returnBlk: &basicBlock{id: basicBlockIDReturnBlock},
}
}
Expand All @@ -166,12 +165,11 @@ type builder struct {
// nextVariable is used by builder.AllocateVariable.
nextVariable Variable

valueIDAliases map[ValueID]Value
// valueAnnotations contains the annotations for each Value, only used for debugging.
valueAnnotations map[ValueID]string

// valueRefCounts is used to lower the SSA in backend, and will be calculated
// by the last SSA-level optimization pass.
valueRefCounts []int
// valuesInfo contains the data per Value used to lower the SSA in backend. This is indexed by ValueID.
valuesInfo []ValueInfo

// dominators stores the immediate dominator of each BasicBlock.
// The index is blockID of the BasicBlock.
Expand Down Expand Up @@ -206,6 +204,13 @@ type builder struct {
zeros [typeEnd]Value
}

// ValueInfo contains the data per Value used to lower the SSA in backend.
type ValueInfo struct {
// RefCount is the reference count of the Value.
RefCount uint32
alias Value
}

// redundantParam is a pair of the index of the redundant parameter and the Value.
// This is used to eliminate the redundant parameters in the optimization pass.
type redundantParam struct {
Expand Down Expand Up @@ -285,8 +290,7 @@ func (b *builder) Init(s *Signature) {

for v := ValueID(0); v < b.nextValueID; v++ {
delete(b.valueAnnotations, v)
delete(b.valueIDAliases, v)
b.valueRefCounts[v] = 0
b.valuesInfo[v] = ValueInfo{alias: ValueInvalid}
b.valueIDToInstruction[v] = nil
}
b.nextValueID = 0
Expand Down Expand Up @@ -676,15 +680,24 @@ func (b *builder) blockIteratorReversePostOrderNext() *basicBlock {
}
}

// ValueRefCounts implements Builder.ValueRefCounts.
func (b *builder) ValueRefCounts() []int {
return b.valueRefCounts
// ValuesInfo implements Builder.ValuesInfo.
func (b *builder) ValuesInfo() []ValueInfo {
return b.valuesInfo
}

// alias records the alias of the given values. The alias(es) will be
// eliminated in the optimization pass via resolveArgumentAlias.
func (b *builder) alias(dst, src Value) {
b.valueIDAliases[dst.ID()] = src
did := int(dst.ID())
if did >= len(b.valuesInfo) {
l := did + 1 - len(b.valuesInfo)
b.valuesInfo = append(b.valuesInfo, make([]ValueInfo, l)...)
view := b.valuesInfo[len(b.valuesInfo)-l:]
for i := range view {
view[i].alias = ValueInvalid
}
}
b.valuesInfo[did].alias = src
}

// resolveArgumentAlias resolves the alias of the arguments of the given instruction.
Expand All @@ -709,10 +722,13 @@ func (b *builder) resolveArgumentAlias(instr *Instruction) {

// resolveAlias resolves the alias of the given value.
func (b *builder) resolveAlias(v Value) Value {
info := b.valuesInfo
l := ValueID(len(info))
// Some aliases are chained, so we need to resolve them recursively.
for {
if src, ok := b.valueIDAliases[v.ID()]; ok {
v = src
vid := v.ID()
if vid < l && info[vid].alias.Valid() {
v = info[vid].alias
} else {
break
}
Expand Down
26 changes: 26 additions & 0 deletions internal/engine/wazevo/ssa/builder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package ssa

import (
"testing"

"github.com/tetratelabs/wazero/internal/testing/require"
)

func TestBuilder_resolveAlias(t *testing.T) {
b := NewBuilder().(*builder)
v1 := b.allocateValue(TypeI32)
v2 := b.allocateValue(TypeI32)
v3 := b.allocateValue(TypeI32)
v4 := b.allocateValue(TypeI32)
v5 := b.allocateValue(TypeI32)

b.alias(v1, v2)
b.alias(v2, v3)
b.alias(v3, v4)
b.alias(v4, v5)
require.Equal(t, v5, b.resolveAlias(v1))
require.Equal(t, v5, b.resolveAlias(v2))
require.Equal(t, v5, b.resolveAlias(v3))
require.Equal(t, v5, b.resolveAlias(v4))
require.Equal(t, v5, b.resolveAlias(v5))
}
12 changes: 9 additions & 3 deletions internal/engine/wazevo/ssa/pass.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,13 @@ func passRedundantPhiEliminationOpt(b *builder) {
// TODO: the algorithm here might not be efficient. Get back to this later.
func passDeadCodeEliminationOpt(b *builder) {
nvid := int(b.nextValueID)
if nvid >= len(b.valueRefCounts) {
b.valueRefCounts = append(b.valueRefCounts, make([]int, nvid-len(b.valueRefCounts)+1)...)
if nvid >= len(b.valuesInfo) {
l := nvid - len(b.valuesInfo) + 1
b.valuesInfo = append(b.valuesInfo, make([]ValueInfo, l)...)
view := b.valuesInfo[len(b.valuesInfo)-l:]
for i := range view {
view[i].alias = ValueInvalid
}
}
if nvid >= len(b.valueIDToInstruction) {
b.valueIDToInstruction = append(b.valueIDToInstruction, make([]*Instruction, nvid-len(b.valueIDToInstruction)+1)...)
Expand Down Expand Up @@ -356,7 +361,8 @@ func (b *builder) incRefCount(id ValueID, from *Instruction) {
if wazevoapi.SSALoggingEnabled {
fmt.Printf("v%d referenced from %v\n", id, from.Format(b))
}
b.valueRefCounts[id]++
info := &b.valuesInfo[id]
info.RefCount++
}

// passNopInstElimination eliminates the instructions which is essentially a no-op.
Expand Down
6 changes: 3 additions & 3 deletions internal/engine/wazevo/ssa/pass_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,9 @@ blk2: () <-- (blk1)
require.True(t, jmp.live)
require.True(t, ret.live)

require.Equal(t, 1, b.valueRefCounts[refOnceVal.ID()])
require.Equal(t, 1, b.valueRefCounts[addRes.ID()])
require.Equal(t, 3, b.valueRefCounts[refThriceVal.ID()])
require.Equal(t, uint32(1), b.valuesInfo[refOnceVal.ID()].RefCount)
require.Equal(t, uint32(1), b.valuesInfo[addRes.ID()].RefCount)
require.Equal(t, uint32(3), b.valuesInfo[refThriceVal.ID()].RefCount)
}
},
before: `
Expand Down
2 changes: 1 addition & 1 deletion internal/engine/wazevo/ssa/vs.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (v Value) formatWithType(b Builder) (ret string) {
if wazevoapi.SSALoggingEnabled { // This is useful to check live value analysis bugs.
if bd := b.(*builder); bd.donePostBlockLayoutPasses {
id := v.ID()
ret += fmt.Sprintf("(ref=%d)", bd.valueRefCounts[id])
ret += fmt.Sprintf("(ref=%d)", bd.valuesInfo[id].RefCount)
}
}
return ret
Expand Down
Loading