diff --git a/calyx-py/calyx/builder.py b/calyx-py/calyx/builder.py index 3bbf0d867e..d795faa833 100644 --- a/calyx-py/calyx/builder.py +++ b/calyx-py/calyx/builder.py @@ -48,6 +48,8 @@ def import_(self, filename: str): class ComponentBuilder: """Builds Calyx components definitions.""" + next_gen_idx = 0 + def __init__( self, prog: Builder, @@ -71,6 +73,15 @@ def __init__( for cell in cells: self.index[cell.id.name] = CellBuilder(cell) self.continuous = GroupBuilder(None, self) + self.next_gen_idx = 0 + + def generate_name(self, prefix: str) -> str: + """Generate a unique name with the given prefix.""" + while True: + self.next_gen_idx += 1 + name = f"{prefix}_{self.next_gen_idx}" + if name not in self.index: + return name def input(self, name: str, size: int) -> ExprBuilder: """Declare an input port on the component. @@ -180,8 +191,8 @@ def cell( self, name: str, comp: Union[ast.CompInst, ComponentBuilder], - is_external=False, - is_ref=False, + is_external: bool = False, + is_ref: bool = False, ) -> CellBuilder: """Declare a cell in the component. Returns a cell builder.""" # If we get a (non-primitive) component builder, instantiate it @@ -223,16 +234,18 @@ def comp_instance( return self.cell(cell_name, ast.CompInst(comp_name, [])) - def reg(self, name: str, size: int, is_ref=False) -> CellBuilder: + def reg(self, name: str, size: int, is_ref: bool = False) -> CellBuilder: """Generate a StdReg cell.""" return self.cell(name, ast.Stdlib.register(size), False, is_ref) - def wire(self, name: str, size: int, is_ref=False) -> CellBuilder: - """Generate a StdReg cell.""" + def wire(self, name: str, size: int, is_ref: bool = False) -> CellBuilder: + """Generate a StdWire cell.""" return self.cell(name, ast.Stdlib.wire(size), False, is_ref) - def slice(self, name: str, in_width: int, out_width, is_ref=False) -> CellBuilder: - """Generate a StdReg cell.""" + def slice( + self, name: str, in_width: int, out_width, is_ref: bool = False + ) -> CellBuilder: + """Generate a StdSlice cell.""" return self.cell(name, ast.Stdlib.slice(in_width, out_width), False, is_ref) def const(self, name: str, width: int, value: int) -> CellBuilder: @@ -245,8 +258,8 @@ def mem_d1( bitwidth: int, len: int, idx_size: int, - is_external=False, - is_ref=False, + is_external: bool = False, + is_ref: bool = False, ) -> CellBuilder: """Generate a StdMemD1 cell.""" return self.cell( @@ -259,8 +272,8 @@ def seq_mem_d1( bitwidth: int, len: int, idx_size: int, - is_external=False, - is_ref=False, + is_external: bool = False, + is_ref: bool = False, ) -> CellBuilder: """Generate a SeqMemD1 cell.""" self.prog.import_("primitives/memories.futil") @@ -268,53 +281,64 @@ def seq_mem_d1( name, ast.Stdlib.seq_mem_d1(bitwidth, len, idx_size), is_external, is_ref ) - def add(self, name: str, size: int, signed=False) -> CellBuilder: - """Generate a StdAdd cell.""" + def binary( + self, + operation: str, + size: int, + name: Optional[str] = None, + signed: bool = False, + ) -> CellBuilder: + """Generate a binary cell of the kind specified in `operation`.""" self.prog.import_("primitives/binary_operators.futil") - return self.cell(name, ast.Stdlib.op("add", size, signed)) + name = name or self.generate_name(operation) + assert isinstance(name, str) + return self.cell(name, ast.Stdlib.op(operation, size, signed)) + + def add(self, size: int, name: str = None, signed: bool = False) -> CellBuilder: + """Generate a StdAdd cell.""" + return self.binary("add", size, name, signed) - def sub(self, name: str, size: int, signed=False): + def sub(self, size: int, name: str = None, signed: bool = False) -> CellBuilder: """Generate a StdSub cell.""" - self.prog.import_("primitives/binary_operators.futil") - return self.cell(name, ast.Stdlib.op("sub", size, signed)) + return self.binary("sub", size, name, signed) - def gt(self, name: str, size: int, signed=False): + def gt(self, size: int, name: str = None, signed: bool = False) -> CellBuilder: """Generate a StdGt cell.""" - self.prog.import_("primitives/binary_operators.futil") - return self.cell(name, ast.Stdlib.op("gt", size, signed)) + return self.binary("gt", size, name, signed) - def lt(self, name: str, size: int, signed=False): + def lt(self, size: int, name: str = None, signed: bool = False) -> CellBuilder: """Generate a StdLt cell.""" - self.prog.import_("primitives/binary_operators.futil") - return self.cell(name, ast.Stdlib.op("lt", size, signed)) + return self.binary("lt", size, name, signed) - def eq(self, name: str, size: int, signed=False): + def eq(self, size: int, name: str = None, signed: bool = False) -> CellBuilder: """Generate a StdEq cell.""" - self.prog.import_("primitives/binary_operators.futil") - return self.cell(name, ast.Stdlib.op("eq", size, signed)) + return self.binary("eq", size, name, signed) - def neq(self, name: str, size: int, signed=False): + def neq(self, size: int, name: str = None, signed: bool = False) -> CellBuilder: """Generate a StdNeq cell.""" - self.prog.import_("primitives/binary_operators.futil") - return self.cell(name, ast.Stdlib.op("neq", size, signed)) + return self.binary("neq", size, name, signed) - def ge(self, name: str, size: int, signed=False): + def ge(self, size: int, name: str = None, signed: bool = False) -> CellBuilder: """Generate a StdGe cell.""" - self.prog.import_("primitives/binary_operators.futil") - return self.cell(name, ast.Stdlib.op("ge", size, signed)) + return self.binary("ge", size, name, signed) - def le(self, name: str, size: int, signed=False): + def le(self, size: int, name: str = None, signed: bool = False) -> CellBuilder: """Generate a StdLe cell.""" - self.prog.import_("primitives/binary_operators.futil") - return self.cell(name, ast.Stdlib.op("le", size, signed)) + return self.binary("le", size, name, signed) + + def logic(self, operation, size: int, name: str = None) -> CellBuilder: + """Generate a logical operator cell, of the flavor specified in `operation`.""" + name = name or self.generate_name(operation) + assert isinstance(name, str) + return self.cell(name, ast.Stdlib.op(operation, size, False)) - def and_(self, name: str, size: int) -> CellBuilder: + def and_(self, size: int, name: str = None) -> CellBuilder: """Generate a StdAnd cell.""" - return self.cell(name, ast.Stdlib.op("and", size, False)) + return self.logic("and", size, name) - def not_(self, name: str, size: int) -> CellBuilder: + def not_(self, size: int, name: str = None) -> CellBuilder: """Generate a StdNot cell.""" - return self.cell(name, ast.Stdlib.op("not", size, False)) + return self.logic("not", size, name) def pipelined_mult(self, name: str) -> CellBuilder: """Generate a pipelined multiplier.""" @@ -618,7 +642,7 @@ def port(self, name: str) -> ExprBuilder: return ExprBuilder(ast.Atom(ast.CompPort(self._cell.id, name))) def is_primitive(self, prim_name) -> bool: - """Check if the cell is an instance of the primitive {prim_name}.""" + """Check if the cell is an instance of the primitive `prim_name`.""" return ( isinstance(self._cell.comp, ast.CompInst) and self._cell.comp.id == prim_name @@ -632,6 +656,11 @@ def is_seq_mem_d1(self) -> bool: """Check if the cell is a SeqMemD1 cell.""" return self.is_primitive("seq_mem_d1") + @property + def name(self) -> str: + """Get the name of the cell.""" + return self._cell.id.name + @classmethod def unwrap_id(cls, obj): if isinstance(obj, cls): diff --git a/calyx-py/calyx/builder_util.py b/calyx-py/calyx/builder_util.py index e9b1a740da..03e83c646f 100644 --- a/calyx-py/calyx/builder_util.py +++ b/calyx-py/calyx/builder_util.py @@ -2,19 +2,20 @@ import calyx.builder as cb -def insert_comb_group(comp: cb.ComponentBuilder, left, right, cell, groupname): - """Accepts a cell that performs some computation on values {left} and {right}. - Creates a combinational group {groupname} that wires up the cell with these ports. +def insert_comb_group(comp: cb.ComponentBuilder, left, right, cell, groupname=None): + """Accepts a cell that performs some computation on values `left` and `right`. + Creates a combinational group that wires up the cell with these ports. Returns the cell and the combintational group. """ + groupname = groupname or f"{cell.name}_group" with comp.comb_group(groupname) as comb_group: cell.left = left cell.right = right return cell, comb_group -def insert_eq(comp: cb.ComponentBuilder, left, right, cellname, width): - """Inserts wiring into component {comp} to check if {left} == {right}. +def insert_eq(comp: cb.ComponentBuilder, left, right, width, cellname=None): + """Inserts wiring into component `comp` to check if `left` == `right`. = std_eq(); ... @@ -25,12 +26,11 @@ def insert_eq(comp: cb.ComponentBuilder, left, right, cellname, width): Returns handles to the cell and the combinational group. """ - eq_cell = comp.eq(cellname, width) - return insert_comb_group(comp, left, right, eq_cell, f"{cellname}_group") + return insert_comb_group(comp, left, right, comp.eq(width, cellname)) -def insert_neq(comp: cb.ComponentBuilder, left, right, cellname, width): - """Inserts wiring into component {comp} to check if {left} != {right}. +def insert_neq(comp: cb.ComponentBuilder, left, right, width, cellname=None): + """Inserts wiring into component `comp` to check if `left` != `right`. = std_neq(); ... @@ -41,12 +41,11 @@ def insert_neq(comp: cb.ComponentBuilder, left, right, cellname, width): Returns handles to the cell and the combinational group. """ - neq_cell = comp.neq(cellname, width) - return insert_comb_group(comp, left, right, neq_cell, f"{cellname}_group") + return insert_comb_group(comp, left, right, comp.neq(width, cellname)) -def insert_lt(comp: cb.ComponentBuilder, left, right, cellname, width): - """Inserts wiring into component {comp} to check if {left} < {right}. +def insert_lt(comp: cb.ComponentBuilder, left, right, width, cellname=None): + """Inserts wiring into component `comp` to check if `left` < `right`. = std_lt(); ... @@ -57,12 +56,11 @@ def insert_lt(comp: cb.ComponentBuilder, left, right, cellname, width): Returns handles to the cell and the combinational group. """ - lt_cell = comp.lt(cellname, width) - return insert_comb_group(comp, left, right, lt_cell, f"{cellname}_group") + return insert_comb_group(comp, left, right, comp.lt(width, cellname)) -def insert_le(comp: cb.ComponentBuilder, left, right, cellname, width): - """Inserts wiring into component {comp} to check if {left} <= {right}. +def insert_le(comp: cb.ComponentBuilder, left, right, width, cellname=None): + """Inserts wiring into component `comp` to check if `left` <= `right`. = std_le(); ... @@ -73,12 +71,11 @@ def insert_le(comp: cb.ComponentBuilder, left, right, cellname, width): Returns handles to the cell and the combinational group. """ - le_cell = comp.le(cellname, width) - return insert_comb_group(comp, left, right, le_cell, f"{cellname}_group") + return insert_comb_group(comp, left, right, comp.le(width, cellname)) -def insert_gt(comp: cb.ComponentBuilder, left, right, cellname, width): - """Inserts wiring into component {comp} to check if {left} > {right}. +def insert_gt(comp: cb.ComponentBuilder, left, right, width, cellname=None): + """Inserts wiring into component `comp` to check if `left` > `right`. = std_gt(); ... @@ -89,12 +86,11 @@ def insert_gt(comp: cb.ComponentBuilder, left, right, cellname, width): Returns handles to the cell and the combinational group. """ - gt_cell = comp.gt(cellname, width) - return insert_comb_group(comp, left, right, gt_cell, f"{cellname}_group") + return insert_comb_group(comp, left, right, comp.gt(width, cellname)) -def insert_add(comp: cb.ComponentBuilder, left, right, cellname, width): - """Inserts wiring into component {comp} to compute {left} + {right}. +def insert_add(comp: cb.ComponentBuilder, left, right, width, cellname=None): + """Inserts wiring into component `comp` to check if `left` > `right`. = std_add(); ... @@ -105,12 +101,11 @@ def insert_add(comp: cb.ComponentBuilder, left, right, cellname, width): Returns handles to the cell and the combinational group. """ - add_cell = comp.add(cellname, width) - return insert_comb_group(comp, left, right, add_cell, f"{cellname}_group") + return insert_comb_group(comp, left, right, comp.add(width, cellname)) -def insert_sub(comp: cb.ComponentBuilder, left, right, cellname, width): - """Inserts wiring into component {comp} to compute {left} - {right}. +def insert_sub(comp: cb.ComponentBuilder, left, right, width, cellname=None): + """Inserts wiring into component `comp` to check if `left` > `right`. = std_sub(); ... @@ -121,16 +116,15 @@ def insert_sub(comp: cb.ComponentBuilder, left, right, cellname, width): Returns handles to the cell and the combinational group. """ - sub_cell = comp.sub(cellname, width) - return insert_comb_group(comp, left, right, sub_cell, f"{cellname}_group") + return insert_comb_group(comp, left, right, comp.sub(width, cellname)) def insert_bitwise_flip_reg(comp: cb.ComponentBuilder, reg, cellname, width): - """Inserts wiring into component {comp} to bitwise-flip the contents of {reg}. + """Inserts wiring into component `comp` to bitwise-flip the contents of `reg`. Returns a handle to the group that does this. """ - not_cell = comp.not_(cellname, width) + not_cell = comp.not_(width, cellname) with comp.group(f"{cellname}_group") as not_group: not_cell.in_ = reg.out reg.write_en = 1 @@ -140,14 +134,14 @@ def insert_bitwise_flip_reg(comp: cb.ComponentBuilder, reg, cellname, width): def insert_incr(comp: cb.ComponentBuilder, reg, cellname, val=1): - """Inserts wiring into component {comp} to increment register {reg} by {val}. - 1. Within component {comp}, creates a group called {cellname}_group. - 2. Within the group, adds a cell {cellname} that computes sums. - 3. Puts the values {reg} and {val} into the cell. - 4. Then puts the answer of the computation back into {reg}. + """Inserts wiring into component `comp` to increment register `reg` by `val`. + 1. Within component `comp`, creates a group called `cellname`_group. + 2. Within the group, adds a cell `cellname` that computes sums. + 3. Puts the values `reg` and `val` into the cell. + 4. Then puts the answer of the computation back into `reg`. 5. Returns the group that does this. """ - add_cell = comp.add(cellname, 32) + add_cell = comp.add(32, cellname) with comp.group(f"{cellname}_group") as incr_group: add_cell.left = reg.out add_cell.right = cb.const(32, val) @@ -158,14 +152,14 @@ def insert_incr(comp: cb.ComponentBuilder, reg, cellname, val=1): def insert_decr(comp: cb.ComponentBuilder, reg, cellname, val=1): - """Inserts wiring into component {comp} to decrement register {reg} by {val}. - 1. Within component {comp}, creates a group called {cellname}_group. - 2. Within the group, adds a cell {cellname} that computes differences. - 3. Puts the values {reg} and {val} into the cell. - 4. Then puts the answer of the computation back into {reg}. + """Inserts wiring into component `comp` to decrement register `reg` by `val`. + 1. Within component `comp`, creates a group called `cellname`_group. + 2. Within the group, adds a cell `cellname` that computes differences. + 3. Puts the values `reg` and `val` into the cell. + 4. Then puts the answer of the computation back into `reg`. 5. Returns the group that does this. """ - sub_cell = comp.sub(cellname, 32) + sub_cell = comp.sub(32, cellname) with comp.group(f"{cellname}_group") as decr_group: sub_cell.left = reg.out sub_cell.right = cb.const(32, val) @@ -177,8 +171,8 @@ def insert_decr(comp: cb.ComponentBuilder, reg, cellname, val=1): def insert_reg_store(comp: cb.ComponentBuilder, reg, val, group): """Stores a value in a register. - 1. Within component {comp}, creates a group called {group}. - 2. Within {group}, sets the register {reg} to {val}. + 1. Within component `comp`, creates a group called `group`. + 2. Within `group`, sets the register `reg` to `val`. 3. Returns the group that does this. """ with comp.group(group) as reg_grp: @@ -190,9 +184,9 @@ def insert_reg_store(comp: cb.ComponentBuilder, reg, val, group): def mem_load_std_d1(comp: cb.ComponentBuilder, mem, i, reg, group): """Loads a value from one memory (std_d1) into a register. - 1. Within component {comp}, creates a group called {group}. - 2. Within {group}, reads from memory {mem} at address {i}. - 3. Writes the value into register {reg}. + 1. Within component `comp`, creates a group called `group`. + 2. Within `group`, reads from memory `mem` at address `i`. + 3. Writes the value into register `reg`. 4. Returns the group that does this. """ assert mem.is_std_mem_d1() @@ -206,9 +200,9 @@ def mem_load_std_d1(comp: cb.ComponentBuilder, mem, i, reg, group): def mem_store_std_d1(comp: cb.ComponentBuilder, mem, i, val, group): """Stores a value into a (std_d1) memory. - 1. Within component {comp}, creates a group called {group}. - 2. Within {group}, reads from {val}. - 3. Writes the value into memory {mem} at address i. + 1. Within component `comp`, creates a group called `group`. + 2. Within `group`, reads from `val`. + 3. Writes the value into memory `mem` at address `i`. 4. Returns the group that does this. """ assert mem.is_std_mem_d1() @@ -224,8 +218,8 @@ def mem_read_seq_d1(comp: cb.ComponentBuilder, mem, i, group): """Given a seq_mem_d1, reads from memory at address i. Note that this does not write the value anywhere. - 1. Within component {comp}, creates a group called {group}. - 2. Within {group}, reads from memory {mem} at address {i}, + 1. Within component `comp`, creates a group called `group`. + 2. Within `group`, reads from memory `mem` at address `i`, thereby "latching" the value. 3. Returns the group that does this. """ @@ -241,9 +235,9 @@ def mem_write_seq_d1_to_reg(comp: cb.ComponentBuilder, mem, reg, group): """Given a seq_mem_d1 that is already assumed to have a latched value, reads the latched value and writes it to a register. - 1. Within component {comp}, creates a group called {group}. - 2. Within {group}, reads from memory {mem}. - 3. Writes the value into register {reg}. + 1. Within component `comp`, creates a group called `group`. + 2. Within `group`, reads from memory `mem`. + 3. Writes the value into register `reg`. 4. Returns the group that does this. """ assert mem.is_seq_mem_d1() @@ -257,9 +251,9 @@ def mem_write_seq_d1_to_reg(comp: cb.ComponentBuilder, mem, reg, group): def mem_store_seq_d1(comp: cb.ComponentBuilder, mem, i, val, group): """Given a seq_mem_d1, stores a value into memory at address i. - 1. Within component {comp}, creates a group called {group}. - 2. Within {group}, reads from {val}. - 3. Writes the value into memory {mem} at address i. + 1. Within component `comp`, creates a group called `group`. + 2. Within `group`, reads from `val`. + 3. Writes the value into memory `mem` at address i. 4. Returns the group that does this. """ assert mem.is_seq_mem_d1() @@ -273,9 +267,9 @@ def mem_store_seq_d1(comp: cb.ComponentBuilder, mem, i, val, group): def insert_mem_load_to_mem(comp: cb.ComponentBuilder, mem, i, ans, j, group): """Loads a value from one std_mem_d1 memory into another. - 1. Within component {comp}, creates a group called {group}. - 2. Within {group}, reads from memory {mem} at address {i}. - 3. Writes the value into memory {ans} at address {j}. + 1. Within component `comp`, creates a group called `group`. + 2. Within `group`, reads from memory `mem` at address `i`. + 3. Writes the value into memory `ans` at address `j`. 4. Returns the group that does this. """ assert mem.is_std_mem_d1() and ans.is_std_mem_d1() @@ -295,15 +289,15 @@ def insert_add_store_in_reg( right, ans_reg=None, ): - """Inserts wiring into component {comp} to compute {left} + {right} and - store it in {ans_reg}. - 1. Within component {comp}, creates a group called {cellname}_group. - 2. Within {group}, create a cell {cellname} that computes sums. - 3. Puts the values of {left} and {right} into the cell. - 4. Then puts the answer of the computation into {ans_reg}. + """Inserts wiring into component `comp` to compute `left` + `right` and + store it in `ans_reg`. + 1. Within component `comp`, creates a group called `cellname`_group. + 2. Within `group`, create a cell `cellname` that computes sums. + 3. Puts the values of `left` and `right` into the cell. + 4. Then puts the answer of the computation into `ans_reg`. 4. Returns the summing group and the register. """ - add_cell = comp.add(cellname, 32) + add_cell = comp.add(32, cellname) ans_reg = ans_reg or comp.reg(f"reg_{cellname}", 32) with comp.group(f"{cellname}_group") as adder_group: add_cell.left = left @@ -322,15 +316,15 @@ def insert_sub_store_in_reg( width, ans_reg=None, ): - """Adds wiring into component {comp} to compute {left} - {right} - and store it in {ans_reg}. - 1. Within component {comp}, creates a group called {cellname}_group. - 2. Within {group}, create a cell {cellname} that computes differences. - 3. Puts the values of {left} and {right} into {cell}. - 4. Then puts the answer of the computation into {ans_reg}. + """Adds wiring into component `comp` to compute `left` - `right` + and store it in `ans_reg`. + 1. Within component `comp`, creates a group called `cellname`_group. + 2. Within `group`, create a cell `cellname` that computes differences. + 3. Puts the values of `left` and `right` into `cell`. + 4. Then puts the answer of the computation into `ans_reg`. 4. Returns the subtracting group and the register. """ - sub_cell = comp.sub(cellname, width) + sub_cell = comp.sub(width, cellname) ans_reg = ans_reg or comp.reg(f"reg_{cellname}", width) with comp.group(f"{cellname}_group") as sub_group: sub_cell.left = left diff --git a/calyx-py/calyx/queue_call.py b/calyx-py/calyx/queue_call.py index 5714e0227b..a0fe133a88 100644 --- a/calyx-py/calyx/queue_call.py +++ b/calyx-py/calyx/queue_call.py @@ -21,9 +21,7 @@ def insert_raise_err_if_i_eq_max_cmds(prog): i = raise_err_if_i_eq_max_cmds.input("i", 32) err = raise_err_if_i_eq_max_cmds.reg("err", 1, is_ref=True) - i_eq_max_cmds = util.insert_eq( - raise_err_if_i_eq_max_cmds, i, MAX_CMDS, "i_eq_MAX_CMDS", 32 - ) + i_eq_max_cmds = util.insert_eq(raise_err_if_i_eq_max_cmds, i, MAX_CMDS, 32) raise_err = util.insert_reg_store(raise_err_if_i_eq_max_cmds, err, 1, "raise_err") raise_err_if_i_eq_max_cmds.control += [ @@ -92,8 +90,8 @@ def insert_main(prog, queue): incr_i = util.insert_incr(main, i, "incr_i") # i++ incr_j = util.insert_incr(main, j, "incr_j") # j++ - err_eq_0 = util.insert_eq(main, err.out, 0, "err_eq_0", 1) # is `err` flag down? - cmd_le_1 = util.insert_le(main, cmd.out, 1, "cmd_le_1", 2) # cmd <= 1 + err_eq_0 = util.insert_eq(main, err.out, 0, 1) # is `err` flag down? + cmd_le_1 = util.insert_le(main, cmd.out, 1, 2) # cmd <= 1 read_cmd = util.mem_read_seq_d1(main, commands, i.out, "read_cmd_phase1") write_cmd_to_reg = util.mem_write_seq_d1_to_reg( diff --git a/calyx-py/test/builder_example.py b/calyx-py/test/builder_example.py index 3179940561..71488ea484 100644 --- a/calyx-py/test/builder_example.py +++ b/calyx-py/test/builder_example.py @@ -13,7 +13,7 @@ def add_main_component(prog): lhs = main.reg("lhs", 32) rhs = main.reg("rhs", 32) sum = main.reg("sum", 32) - add = main.add("add", 32) + add = main.add(32, "add") # ANCHOR_END: cells # ANCHOR: bare diff --git a/calyx-py/test/correctness/arbiter_6.py b/calyx-py/test/correctness/arbiter_6.py index 294f77856b..78fa3208e3 100644 --- a/calyx-py/test/correctness/arbiter_6.py +++ b/calyx-py/test/correctness/arbiter_6.py @@ -30,10 +30,10 @@ def add_wrap2(prog): j_mod_4 = wrap.reg("j_mod_4", 32) # Additional cells and groups to compute equality and lt - i_eq_0_cell, i_eq_0_grp = util.insert_eq(wrap, i, 0, "i_eq_0", 32) - i_eq_1_cell, i_eq_1_group = util.insert_eq(wrap, i, 1, "i_eq_1", 32) - j_lt_4_cell, j_lt_4_group = util.insert_lt(wrap, j, 4, "j_lt_4", 32) - j_lt_8_cell, j_lt_8_group = util.insert_lt(wrap, j, 8, "j_lt_8", 32) + i_eq_0_cell, i_eq_0_grp = util.insert_eq(wrap, i, 0, 32) + i_eq_1_cell, i_eq_1_group = util.insert_eq(wrap, i, 1, 32) + j_lt_4_cell, j_lt_4_group = util.insert_lt(wrap, j, 4, 32) + j_lt_8_cell, j_lt_8_group = util.insert_lt(wrap, j, 8, 32) # Load `j` unchanged into `j_mod_4`. unchanged = util.insert_reg_store(wrap, j_mod_4, j, "j_unchanged") @@ -126,10 +126,10 @@ def add_wrap3(prog): j_mod_4 = wrap.reg("j_mod_4", 32) # Additional cells to compute equality, and lt - i_eq_0_cell, i_eq_0_group = util.insert_eq(wrap, i, 0, "i_eq_0", 32) - i_eq_1_cell, i_eq_1_group = util.insert_eq(wrap, i, 1, "i_eq_1", 32) - i_eq_2_cell, i_eq_2_group = util.insert_eq(wrap, i, 2, "i_eq_2", 32) - j_lt_4_cell, j_lt_4_group = util.insert_lt(wrap, j, 4, "j_lt_4", 32) + i_eq_0_cell, i_eq_0_group = util.insert_eq(wrap, i, 0, 32) + i_eq_1_cell, i_eq_1_group = util.insert_eq(wrap, i, 1, 32) + i_eq_2_cell, i_eq_2_group = util.insert_eq(wrap, i, 2, 32) + j_lt_4_cell, j_lt_4_group = util.insert_lt(wrap, j, 4, 32) # Load `j` unchanged into `j_mod_4`. unchanged = util.insert_reg_store(wrap, j_mod_4, j, "j_unchanged") diff --git a/calyx-py/test/correctness/fifo.py b/calyx-py/test/correctness/fifo.py index c83e8345b9..c6d27983a1 100644 --- a/calyx-py/test/correctness/fifo.py +++ b/calyx-py/test/correctness/fifo.py @@ -36,21 +36,15 @@ def insert_fifo(prog, name): len = fifo.reg("len", 32) # The length of the FIFO. - # Cells and groups to compute equality. - cmd_eq_0 = util.insert_eq(fifo, cmd, 0, "cmd_eq_0", 2) - cmd_eq_1 = util.insert_eq(fifo, cmd, 1, "cmd_eq_1", 2) - cmd_eq_2 = util.insert_eq(fifo, cmd, 2, "cmd_eq_2", 2) - - write_eq_max_queue_len = util.insert_eq( - fifo, write.out, MAX_QUEUE_LEN, "write_eq_MAX_QUEUE_LEN", 32 - ) - read_eq_max_queue_len = util.insert_eq( - fifo, read.out, MAX_QUEUE_LEN, "read_eq_MAX_QUEUE_LEN", 32 - ) - len_eq_0 = util.insert_eq(fifo, len.out, 0, "len_eq_0", 32) - len_eq_max_queue_len = util.insert_eq( - fifo, len.out, MAX_QUEUE_LEN, "len_eq_MAX_QUEUE_LEN", 32 - ) + # Cells and groups to compute equality + cmd_eq_0 = util.insert_eq(fifo, cmd, 0, 2) + cmd_eq_1 = util.insert_eq(fifo, cmd, 1, 2) + cmd_eq_2 = util.insert_eq(fifo, cmd, 2, 2) + + write_eq_max_queue_len = util.insert_eq(fifo, write.out, MAX_QUEUE_LEN, 32) + read_eq_max_queue_len = util.insert_eq(fifo, read.out, MAX_QUEUE_LEN, 32) + len_eq_0 = util.insert_eq(fifo, len.out, 0, 32) + len_eq_max_queue_len = util.insert_eq(fifo, len.out, MAX_QUEUE_LEN, 32) # Cells and groups to increment read and write registers write_incr = util.insert_incr(fifo, write, "write_incr") # write++ diff --git a/calyx-py/test/correctness/pifo.py b/calyx-py/test/correctness/pifo.py index 8688deffd2..18623c647e 100644 --- a/calyx-py/test/correctness/pifo.py +++ b/calyx-py/test/correctness/pifo.py @@ -19,7 +19,7 @@ def insert_flow_inference(comp: cb.ComponentBuilder, cmd, flow, boundary, group) 4. Then puts the answer of the computation into {flow}. 5. Returns the group that does this. """ - cell = comp.lt("flow_inf", 32) + cell = comp.lt(32) with comp.group(group) as infer_flow_grp: cell.left = boundary cell.right = cmd @@ -111,19 +111,17 @@ def insert_pifo(prog, name, queue_l, queue_r, boundary): hot = pifo.reg("hot", 1) # Some equality checks. - hot_eq_0 = util.insert_eq(pifo, hot.out, 0, "hot_eq_0", 1) - hot_eq_1 = util.insert_eq(pifo, hot.out, 1, "hot_eq_1", 1) - flow_eq_0 = util.insert_eq(pifo, flow.out, 0, "flow_eq_0", 1) - flow_eq_1 = util.insert_eq(pifo, flow.out, 1, "flow_eq_1", 1) - len_eq_0 = util.insert_eq(pifo, len.out, 0, "len_eq_0", 32) - len_eq_max_queue_len = util.insert_eq( - pifo, len.out, MAX_QUEUE_LEN, "len_eq_MAX_QUEUE_LEN", 32 - ) - cmd_eq_0 = util.insert_eq(pifo, cmd, 0, "cmd_eq_0", 2) - cmd_eq_1 = util.insert_eq(pifo, cmd, 1, "cmd_eq_1", 2) - cmd_eq_2 = util.insert_eq(pifo, cmd, 2, "cmd_eq_2", 2) - err_eq_0 = util.insert_eq(pifo, err.out, 0, "err_eq_0", 1) - err_neq_0 = util.insert_neq(pifo, err.out, cb.const(1, 0), "err_neq_0", 1) + hot_eq_0 = util.insert_eq(pifo, hot.out, 0, 1) + hot_eq_1 = util.insert_eq(pifo, hot.out, 1, 1) + flow_eq_0 = util.insert_eq(pifo, flow.out, 0, 1) + flow_eq_1 = util.insert_eq(pifo, flow.out, 1, 1) + len_eq_0 = util.insert_eq(pifo, len.out, 0, 32) + len_eq_max_queue_len = util.insert_eq(pifo, len.out, MAX_QUEUE_LEN, 32) + cmd_eq_0 = util.insert_eq(pifo, cmd, 0, 2) + cmd_eq_1 = util.insert_eq(pifo, cmd, 1, 2) + cmd_eq_2 = util.insert_eq(pifo, cmd, 2, 2) + err_eq_0 = util.insert_eq(pifo, err.out, 0, 1) + err_neq_0 = util.insert_neq(pifo, err.out, cb.const(1, 0), 1) flip_hot = util.insert_bitwise_flip_reg(pifo, hot, "flip_hot", 1) raise_err = util.insert_reg_store(pifo, err, 1, "raise_err") # set `err` to 1 diff --git a/frontends/mrxl/mrxl/gen_calyx.py b/frontends/mrxl/mrxl/gen_calyx.py index 775ed9de0c..151e0f9b82 100644 --- a/frontends/mrxl/mrxl/gen_calyx.py +++ b/frontends/mrxl/mrxl/gen_calyx.py @@ -46,7 +46,7 @@ def incr_group(comp: cb.ComponentBuilder, idx: cb.CellBuilder, suffix: str) -> s """ # ANCHOR: incr_group group_name = f"incr_idx_{suffix}" - adder = comp.add(f"incr_{suffix}", 32) + adder = comp.add(32) with comp.group(group_name) as incr: adder.left = idx.out adder.right = 1 @@ -138,7 +138,7 @@ def expr_to_port(expr: ast.BaseExpr): if body.operation == "mul": operation = comp.cell(f"mul_{s_idx}", Stdlib.op("mult_pipe", 32, signed=False)) else: - operation = comp.add(f"add_{s_idx}", 32) + operation = comp.add(32) with comp.group(f"reduce{s_idx}") as evl: inp = comp.get_cell(f"{bind.src}_b0") inp.addr0 = idx.out diff --git a/frontends/mrxl/mrxl/map.py b/frontends/mrxl/mrxl/map.py index cad0faeb29..0a172b8d11 100644 --- a/frontends/mrxl/mrxl/map.py +++ b/frontends/mrxl/mrxl/map.py @@ -68,7 +68,7 @@ def expr_to_port(expr: ast.BaseExpr): f"mul_{suffix}", Stdlib.op("mult_pipe", 32, signed=False) ) else: - operation = comp.add(f"add_{suffix}", 32) + operation = comp.add(32, f"add_{suffix}") # ANCHOR_END: map_op assert ( diff --git a/frontends/systolic-lang/gen-systolic.py b/frontends/systolic-lang/gen-systolic.py index 6bcac49ecd..55631c1d1e 100755 --- a/frontends/systolic-lang/gen-systolic.py +++ b/frontends/systolic-lang/gen-systolic.py @@ -65,7 +65,7 @@ def pe(prog: cb.Builder, leaky_relu): mul = comp.pipelined_fp_smult("mul", BITWIDTH, INTWIDTH, FRACWIDTH) # No leaky relu means integer operations else: - add = comp.add("add", BITWIDTH) + add = comp.add(BITWIDTH, "add") # XXX: pipelined mult assumes 32 bit multiplication mul = comp.pipelined_mult("mul") @@ -118,7 +118,7 @@ def instantiate_indexor(comp: cb.ComponentBuilder, prefix, width) -> cb.CellBuil name = NAME_SCHEME["index name"].format(prefix=prefix) reg = comp.reg(name, width) - add = comp.add(f"{prefix}_add", width) + add = comp.add(width, f"{prefix}_add") init_name = NAME_SCHEME["index init"].format(prefix=prefix) with comp.static_group(init_name, 1): @@ -362,7 +362,7 @@ def try_build_calyx_add(comp, obj): if type(obj) == CalyxAdd: add_str = str(obj) if comp.try_get_cell(add_str) is None: - add = comp.add(add_str, BITWIDTH) + add = comp.add(BITWIDTH, add_str) with comp.static_group(add_str + "_group", 1): add.left = obj.port add.right = obj.const @@ -390,7 +390,7 @@ def instantiate_idx_cond_groups(comp: cb.ComponentBuilder, leaky_relu): and that sets cond_reg to idx + 1 < iter_limit """ idx = comp.reg("idx", BITWIDTH) - add = comp.add("idx_add", BITWIDTH) + add = comp.add(BITWIDTH, "idx_add") cond_reg = comp.reg("cond_reg", 1) with comp.static_group("init_idx", 1): idx.in_ = 0 @@ -408,7 +408,7 @@ def instantiate_idx_cond_groups(comp: cb.ComponentBuilder, leaky_relu): # operations are finished yet if not leaky_relu: iter_limit = comp.get_cell("iter_limit") - lt_iter_limit = comp.lt("lt_iter_limit", BITWIDTH) + lt_iter_limit = comp.lt(BITWIDTH, "lt_iter_limit") with comp.static_group("lt_iter_limit_group", 1): lt_iter_limit.left = add.out lt_iter_limit.right = iter_limit.out @@ -424,7 +424,7 @@ def init_dyn_vals(comp: cb.ComponentBuilder, depth_port, rem_iter_limit, leaky_r If leaky_relu, we do not need to check iteration limit. """ min_depth_4 = comp.reg("min_depth_4", BITWIDTH) - lt_depth_4 = comp.lt("lt_depth_4", BITWIDTH) + lt_depth_4 = comp.lt(BITWIDTH, "lt_depth_4") with comp.static_group("init_min_depth", 1): lt_depth_4.left = depth_port lt_depth_4.right = 4 @@ -433,7 +433,7 @@ def init_dyn_vals(comp: cb.ComponentBuilder, depth_port, rem_iter_limit, leaky_r min_depth_4.write_en = 1 if not leaky_relu: iter_limit = comp.reg("iter_limit", BITWIDTH) - iter_limit_add = comp.add("iter_limit_add", BITWIDTH) + iter_limit_add = comp.add(BITWIDTH, "iter_limit_add") with comp.static_group("init_iter_limit", 1): iter_limit_add.left = rem_iter_limit iter_limit_add.right = depth_port @@ -471,7 +471,7 @@ def instantiate_idx_between(comp: cb.ComponentBuilder, lo, hi) -> list: ge = ( comp.get_cell(index_ge) if comp.try_get_cell(index_ge) is not None - else comp.ge(index_ge, BITWIDTH) + else comp.ge(BITWIDTH, index_ge) ) with comp.static_group(group_str, 1): ge.left = idx_add.out @@ -481,7 +481,7 @@ def instantiate_idx_between(comp: cb.ComponentBuilder, lo, hi) -> list: lt = ( comp.get_cell(index_lt) if comp.try_get_cell(index_lt) is not None - else comp.lt(index_lt, BITWIDTH) + else comp.lt(BITWIDTH, index_lt) ) # if lo == 0, then only need to check if reg < hi if type(lo) == int and lo == 0: @@ -495,9 +495,9 @@ def instantiate_idx_between(comp: cb.ComponentBuilder, lo, hi) -> list: ge = ( comp.get_cell(index_ge) if comp.try_get_cell(index_ge) is not None - else comp.ge(index_ge, BITWIDTH) + else comp.ge(BITWIDTH, index_ge) ) - and_ = comp.and_(comb_str, 1) + and_ = comp.and_(1, comb_str) with comp.static_group(group_str, 1): ge.left = idx_add.out ge.right = lo_value @@ -564,7 +564,7 @@ def build_assignment( # either when a) value is positive or b) multiply operation has finished. go_next = comp.wire(f"relu_r{row}_go_next", BITWIDTH) # Increments idx_reg. - incr = comp.add(f"relu_r{row}_incr", BITWIDTH) + incr = comp.add(BITWIDTH, f"relu_r{row}_incr") # Performs multiplication for leaky relu. fp_mult = comp.fp_sop( f"relu_r{row}_val_mult", "mult_pipe", BITWIDTH, INTWIDTH, FRACWIDTH