-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathssa2bril.py
148 lines (121 loc) · 4.79 KB
/
ssa2bril.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
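Convert SSA-form Bril programs back to non-SSA form.

Reads a Bril program as JSON from stdin, lowers each phi node to assignments
placed in its predecessor blocks, removes the phi nodes, and prints the result
as JSON.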
"""
import sys
import json
import copy
from typing import cast
from collections import defaultdict
import click
from typing_bril import (
Program,
SSAProgram,
SSAValue,
SSAInstruction,
Variable,
Effect,
Value,
Constant,
)
from basic_blocks import (
BasicBlock,
BasicBlockFunction,
BasicBlockProgram,
program_from_basic_block_program,
)
from ssa_basic_blocks import (
SSABasicBlock,
SSABasicBlockFunction,
SSABasicBlockProgram,
ssa_basic_block_program_from_ssa_program,
)
from cfg import (
control_flow_graph_from_instructions,
)
from bril_labeler import index_to_label_dict_get, apply_labels
from bril_extract import phi_nodes_get
from bril_analyze import is_terminator
def ssa_bb_func_to_bb_func(
ssa_bb_func: SSABasicBlockFunction,
) -> BasicBlockFunction:
"""Given SSA basic blocks to a function compute the non-SSA basic blocks"""
ssa_bb_func = copy.deepcopy(ssa_bb_func)
index_to_label = index_to_label_dict_get(cast(BasicBlockFunction, ssa_bb_func))
ssa_bb_func["instrs"] = cast(
list[SSABasicBlock],
apply_labels(cast(BasicBlockFunction, ssa_bb_func)["instrs"], index_to_label),
)
cfg = control_flow_graph_from_instructions(
cast(list[BasicBlock], ssa_bb_func["instrs"])
)
for block_index, ssa_basic_block in enumerate(ssa_bb_func["instrs"]):
phi_nodes = phi_nodes_get(ssa_basic_block)
for phi_node in phi_nodes:
dest = phi_node["dest"]
type_ = phi_node["type"]
for predecessor_index in cfg.predecessors(block_index):
predecessor_label = index_to_label[predecessor_index]
instructions: list[SSAInstruction] = []
try:
arg = next(
arg
for arg, label in zip(phi_node["args"], phi_node["labels"])
if label == predecessor_label
)
instructions.append(SSAValue(op="id", type=type_, dest=dest, args=[arg]))
except StopIteration:
# Not in Phi node
if isinstance(type_, dict):
if "ptr" in type_:
instructions.append(Constant(
op="const", type="int", dest=Variable(f"{dest}.size"), value=1))
instructions.append(Value(
op="alloc", args=[Variable(f"{dest}.size")], dest=dest, type=type_))
instructions.append(Effect(
op="free", args=[dest]))
else:
constant_value: int | float | bool
match type_:
case "int":
constant_value = int()
case "float":
constant_value = float()
case "bool":
constant_value = bool()
case _:
raise ValueError(f"Encountered unsupported type: {type_}")
instructions.append(Constant(
op="const", type=type_, dest=dest, value=constant_value))
block_to_insert = ssa_bb_func["instrs"][predecessor_index]
if len(block_to_insert) <= 0 or not is_terminator(block_to_insert[-1]):
block_to_insert.extend(instructions)
else:
for instruction in instructions:
block_to_insert.insert(-1, instruction)
for block_index, ssa_basic_block in enumerate(ssa_bb_func["instrs"]):
# Delete phi nodes
ssa_bb_func["instrs"][block_index] = list(
filter(
lambda instr: "op" not in instr
or ("op" in instr and cast(SSAValue, instr)["op"] != "phi"),
ssa_basic_block,
)
)
bb_function = cast(BasicBlockFunction, ssa_bb_func)
return bb_function
def ssa_to_bril(ssa_program: SSAProgram) -> Program:
    """Take an SSA-form Bril program out of SSA form, function by function.

    Each function is converted with ``ssa_bb_func_to_bb_func``. All other
    top-level fields of the program are carried over from a deep copy.
    """
    # The type checker rejects this call ("Program is not Program"), so the
    # result is accepted as-is with an ignore.
    ssa_bb_program: SSABasicBlockProgram = ssa_basic_block_program_from_ssa_program(ssa_program)  # type: ignore
    bb_program = cast(BasicBlockProgram, copy.deepcopy(ssa_bb_program))
    # Replace the copied functions in place with their converted counterparts.
    bb_program["functions"][:] = [
        ssa_bb_func_to_bb_func(ssa_func) for ssa_func in ssa_bb_program["functions"]
    ]
    return cast(Program, program_from_basic_block_program(bb_program))
@click.command()
def main() -> None:
    """Read an SSA-form Bril program (JSON) from stdin and print it in non-SSA form."""
    # stdin holds the SSA program; ssa_to_bril returns the plain Program
    # (the original annotations had these two types swapped).
    ssa_program: SSAProgram = json.load(sys.stdin)
    program: Program = ssa_to_bril(ssa_program)
    print(json.dumps(program))


if __name__ == "__main__":
    main()