Skip to content

Commit e294346

Browse files
authored
backend: use register index for allocation (#4047)
Now that the index uniquely identifies the register, we can use it in the register queue, helping avoid collisions between different names for the same register.
1 parent fef6e7c commit e294346

File tree

2 files changed

+43
-50
lines changed

2 files changed

+43
-50
lines changed

tests/backend/riscv/test_register_queue.py

+13-30
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from xdsl.backend.riscv.riscv_register_queue import RiscvRegisterQueue
66
from xdsl.dialects import riscv
7+
from xdsl.dialects.builtin import IntAttr
78

89

910
def test_default_reserved_registers():
@@ -51,39 +52,21 @@ def test_push_register():
5152
def test_reserve_register():
5253
register_queue = RiscvRegisterQueue()
5354

54-
register_queue.reserve_register(riscv.IntRegisterType.infinite_register(0))
55-
assert (
56-
register_queue.reserved_int_registers[
57-
riscv.IntRegisterType.infinite_register(0)
58-
]
59-
== 1
60-
)
55+
j0 = riscv.IntRegisterType.infinite_register(0)
56+
assert isinstance(j0.index, IntAttr)
6157

62-
register_queue.reserve_register(riscv.IntRegisterType.infinite_register(0))
63-
assert (
64-
register_queue.reserved_int_registers[
65-
riscv.IntRegisterType.infinite_register(0)
66-
]
67-
== 2
68-
)
58+
register_queue.reserve_register(j0)
59+
assert register_queue.reserved_int_registers[j0.index.data] == 1
6960

70-
register_queue.unreserve_register(riscv.IntRegisterType.infinite_register(0))
71-
assert (
72-
register_queue.reserved_int_registers[
73-
riscv.IntRegisterType.infinite_register(0)
74-
]
75-
== 1
76-
)
61+
register_queue.reserve_register(j0)
62+
assert register_queue.reserved_int_registers[j0.index.data] == 2
7763

78-
register_queue.unreserve_register(riscv.IntRegisterType.infinite_register(0))
79-
assert (
80-
riscv.IntRegisterType.infinite_register(0)
81-
not in register_queue.reserved_int_registers
82-
)
83-
assert (
84-
riscv.IntRegisterType.infinite_register(0)
85-
not in register_queue.available_int_registers
86-
)
64+
register_queue.unreserve_register(j0)
65+
assert register_queue.reserved_int_registers[j0.index.data] == 1
66+
67+
register_queue.unreserve_register(j0)
68+
assert j0 not in register_queue.reserved_int_registers
69+
assert j0 not in register_queue.available_int_registers
8770

8871
# Check assertion error when reserving an available register
8972
reg = register_queue.pop(riscv.IntRegisterType)

xdsl/backend/riscv/riscv_register_queue.py

+30-20
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from collections import defaultdict
22
from dataclasses import dataclass, field
3-
from typing import overload
3+
from typing import cast, overload
44

55
from xdsl.backend.register_queue import RegisterQueue
6+
from xdsl.dialects.builtin import IntAttr
67
from xdsl.dialects.riscv import FloatRegisterType, IntRegisterType, Registers
78

89

@@ -30,14 +31,17 @@ class RiscvRegisterQueue(RegisterQueue[IntRegisterType | FloatRegisterType]):
3031
_fj_idx: int = 0
3132
"""Next `fj` register index."""
3233

33-
reserved_int_registers: defaultdict[IntRegisterType, int] = field(
34-
default_factory=lambda: defaultdict[IntRegisterType, int](lambda: 0)
35-
| {r: 1 for r in RiscvRegisterQueue.DEFAULT_RESERVED_REGISTERS}
34+
reserved_int_registers: defaultdict[int, int] = field(
35+
default_factory=lambda: defaultdict[int, int](lambda: 0)
36+
| {
37+
cast(IntAttr, r.index).data: 1
38+
for r in RiscvRegisterQueue.DEFAULT_RESERVED_REGISTERS
39+
}
3640
)
3741
"Integer registers unavailable to be used by the register allocator."
3842

39-
reserved_float_registers: defaultdict[FloatRegisterType, int] = field(
40-
default_factory=lambda: defaultdict[FloatRegisterType, int](lambda: 0)
43+
reserved_float_registers: defaultdict[int, int] = field(
44+
default_factory=lambda: defaultdict[int, int](lambda: 0)
4145
)
4246
"Floating-point registers unavailable to be used by the register allocator."
4347

@@ -55,13 +59,16 @@ def push(self, reg: IntRegisterType | FloatRegisterType) -> None:
5559
"""
5660
Return a register to be made available for allocation.
5761
"""
58-
if reg in self.reserved_int_registers or reg in self.reserved_float_registers:
59-
return
60-
if not reg.is_allocated:
62+
if not isinstance(reg.index, IntAttr):
6163
raise ValueError("Cannot push an unallocated register")
64+
6265
if isinstance(reg, IntRegisterType):
66+
if reg.index.data in self.reserved_int_registers:
67+
return
6368
self.available_int_registers.append(reg)
6469
else:
70+
if reg.index.data in self.reserved_float_registers:
71+
return
6572
self.available_float_registers.append(reg)
6673

6774
@overload
@@ -97,7 +104,8 @@ def pop(
97104
else self.reserved_float_registers
98105
)
99106

100-
assert reg not in reserved_registers, (
107+
assert isinstance(reg.index, IntAttr)
108+
assert reg.index.data not in reserved_registers, (
101109
f"Cannot pop a reserved register ({reg.register_name.data}), it must have been reserved while available."
102110
)
103111
return reg
@@ -110,28 +118,30 @@ def reserve_register(self, reg: IntRegisterType | FloatRegisterType) -> None:
110118
It is invalid to reserve a register that is available, and popping it before
111119
unreserving a register will result in an AssertionError.
112120
"""
121+
assert isinstance(reg.index, IntAttr)
113122
if isinstance(reg, IntRegisterType):
114-
self.reserved_int_registers[reg] += 1
123+
self.reserved_int_registers[reg.index.data] += 1
115124
if isinstance(reg, FloatRegisterType):
116-
self.reserved_float_registers[reg] += 1
125+
self.reserved_float_registers[reg.index.data] += 1
117126

118127
def unreserve_register(self, reg: IntRegisterType | FloatRegisterType) -> None:
119128
"""
120129
Decrease the reservation count for a register. If the reservation count is 0, make
121130
the register available for allocation.
122131
"""
132+
assert isinstance(reg.index, IntAttr)
123133
if isinstance(reg, IntRegisterType):
124-
if reg not in self.reserved_int_registers:
134+
if reg.index.data not in self.reserved_int_registers:
125135
raise ValueError(f"Cannot unreserve register {reg.register_name}")
126-
self.reserved_int_registers[reg] -= 1
127-
if not self.reserved_int_registers[reg]:
128-
del self.reserved_int_registers[reg]
136+
self.reserved_int_registers[reg.index.data] -= 1
137+
if not self.reserved_int_registers[reg.index.data]:
138+
del self.reserved_int_registers[reg.index.data]
129139
if isinstance(reg, FloatRegisterType):
130-
if reg not in self.reserved_float_registers:
140+
if reg.index.data not in self.reserved_float_registers:
131141
raise ValueError(f"Cannot unreserve register {reg.register_name}")
132-
self.reserved_float_registers[reg] -= 1
133-
if not self.reserved_float_registers[reg]:
134-
del self.reserved_float_registers[reg]
142+
self.reserved_float_registers[reg.index.data] -= 1
143+
if not self.reserved_float_registers[reg.index.data]:
144+
del self.reserved_float_registers[reg.index.data]
135145

136146
def limit_registers(self, limit: int) -> None:
137147
"""

0 commit comments

Comments
 (0)