1
1
from collections import defaultdict
2
2
from dataclasses import dataclass , field
3
- from typing import overload
3
+ from typing import cast , overload
4
4
5
5
from xdsl .backend .register_queue import RegisterQueue
6
+ from xdsl .dialects .builtin import IntAttr
6
7
from xdsl .dialects .riscv import FloatRegisterType , IntRegisterType , Registers
7
8
8
9
@@ -30,14 +31,17 @@ class RiscvRegisterQueue(RegisterQueue[IntRegisterType | FloatRegisterType]):
30
31
_fj_idx : int = 0
31
32
"""Next `fj` register index."""
32
33
33
- reserved_int_registers : defaultdict [IntRegisterType , int ] = field (
34
- default_factory = lambda : defaultdict [IntRegisterType , int ](lambda : 0 )
35
- | {r : 1 for r in RiscvRegisterQueue .DEFAULT_RESERVED_REGISTERS }
34
+ reserved_int_registers : defaultdict [int , int ] = field (
35
+ default_factory = lambda : defaultdict [int , int ](lambda : 0 )
36
+ | {
37
+ cast (IntAttr , r .index ).data : 1
38
+ for r in RiscvRegisterQueue .DEFAULT_RESERVED_REGISTERS
39
+ }
36
40
)
37
41
"Integer registers unavailable to be used by the register allocator."
38
42
39
- reserved_float_registers : defaultdict [FloatRegisterType , int ] = field (
40
- default_factory = lambda : defaultdict [FloatRegisterType , int ](lambda : 0 )
43
+ reserved_float_registers : defaultdict [int , int ] = field (
44
+ default_factory = lambda : defaultdict [int , int ](lambda : 0 )
41
45
)
42
46
"Floating-point registers unavailable to be used by the register allocator."
43
47
@@ -55,13 +59,16 @@ def push(self, reg: IntRegisterType | FloatRegisterType) -> None:
55
59
"""
56
60
Return a register to be made available for allocation.
57
61
"""
58
- if reg in self .reserved_int_registers or reg in self .reserved_float_registers :
59
- return
60
- if not reg .is_allocated :
62
+ if not isinstance (reg .index , IntAttr ):
61
63
raise ValueError ("Cannot push an unallocated register" )
64
+
62
65
if isinstance (reg , IntRegisterType ):
66
+ if reg .index .data in self .reserved_int_registers :
67
+ return
63
68
self .available_int_registers .append (reg )
64
69
else :
70
+ if reg .index .data in self .reserved_float_registers :
71
+ return
65
72
self .available_float_registers .append (reg )
66
73
67
74
@overload
@@ -97,7 +104,8 @@ def pop(
97
104
else self .reserved_float_registers
98
105
)
99
106
100
- assert reg not in reserved_registers , (
107
+ assert isinstance (reg .index , IntAttr )
108
+ assert reg .index .data not in reserved_registers , (
101
109
f"Cannot pop a reserved register ({ reg .register_name .data } ), it must have been reserved while available."
102
110
)
103
111
return reg
@@ -110,28 +118,30 @@ def reserve_register(self, reg: IntRegisterType | FloatRegisterType) -> None:
110
118
It is invalid to reserve a register that is available, and popping it before
111
119
unreserving a register will result in an AssertionError.
112
120
"""
121
+ assert isinstance (reg .index , IntAttr )
113
122
if isinstance (reg , IntRegisterType ):
114
- self .reserved_int_registers [reg ] += 1
123
+ self .reserved_int_registers [reg . index . data ] += 1
115
124
if isinstance (reg , FloatRegisterType ):
116
- self .reserved_float_registers [reg ] += 1
125
+ self .reserved_float_registers [reg . index . data ] += 1
117
126
118
127
def unreserve_register (self , reg : IntRegisterType | FloatRegisterType ) -> None :
119
128
"""
120
129
Decrease the reservation count for a register. If the reservation count is 0, make
121
130
the register available for allocation.
122
131
"""
132
+ assert isinstance (reg .index , IntAttr )
123
133
if isinstance (reg , IntRegisterType ):
124
- if reg not in self .reserved_int_registers :
134
+ if reg . index . data not in self .reserved_int_registers :
125
135
raise ValueError (f"Cannot unreserve register { reg .register_name } " )
126
- self .reserved_int_registers [reg ] -= 1
127
- if not self .reserved_int_registers [reg ]:
128
- del self .reserved_int_registers [reg ]
136
+ self .reserved_int_registers [reg . index . data ] -= 1
137
+ if not self .reserved_int_registers [reg . index . data ]:
138
+ del self .reserved_int_registers [reg . index . data ]
129
139
if isinstance (reg , FloatRegisterType ):
130
- if reg not in self .reserved_float_registers :
140
+ if reg . index . data not in self .reserved_float_registers :
131
141
raise ValueError (f"Cannot unreserve register { reg .register_name } " )
132
- self .reserved_float_registers [reg ] -= 1
133
- if not self .reserved_float_registers [reg ]:
134
- del self .reserved_float_registers [reg ]
142
+ self .reserved_float_registers [reg . index . data ] -= 1
143
+ if not self .reserved_float_registers [reg . index . data ]:
144
+ del self .reserved_float_registers [reg . index . data ]
135
145
136
146
def limit_registers (self , limit : int ) -> None :
137
147
"""
0 commit comments