-
Notifications
You must be signed in to change notification settings - Fork 305
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PyCDE] Add fork, join, and merge channel functions (#8011)
- <ChannelSignal>.fork creates two new channels, waits until they are both available, then accepts an input. Also buffer the output channels to avoid combinational loops. - Channel.join waits on two channels then creates a message on the one output channel containing a struct with field 'a' equal to input channel A's content and likewise for channel B. - Channel.merge funnels two channels together into a single output stream. This is functionality which really should be handled by the DC dialect but it's not ready for primetime.
- Loading branch information
Showing
6 changed files
with
332 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# REQUIRES: esi-runtime, esi-cosim, rtl-sim | ||
# RUN: rm -rf %t | ||
# RUN: mkdir %t && cd %t | ||
# RUN: %PYTHON% %s %t 2>&1 | ||
# RUN: esi-cosim.py -- %PYTHON% %S/test_software/esi_advanced.py cosim env | ||
|
||
import sys | ||
|
||
from pycde import generator, Clock, Module, Reset, System | ||
from pycde.bsp import get_bsp | ||
from pycde.common import InputChannel, OutputChannel, Output | ||
from pycde.types import Bits, UInt | ||
from pycde import esi | ||
|
||
|
||
class Merge(Module): | ||
clk = Clock() | ||
rst = Reset() | ||
a = InputChannel(UInt(8)) | ||
b = InputChannel(UInt(8)) | ||
|
||
x = OutputChannel(UInt(8)) | ||
|
||
@generator | ||
def build(ports): | ||
chan = ports.a.type.merge(ports.a, ports.b) | ||
ports.x = chan | ||
|
||
|
||
class Join(Module): | ||
clk = Clock() | ||
rst = Reset() | ||
a = InputChannel(UInt(8)) | ||
b = InputChannel(UInt(8)) | ||
|
||
x = OutputChannel(UInt(9)) | ||
|
||
@generator | ||
def build(ports): | ||
joined = ports.a.type.join(ports.a, ports.b) | ||
ports.x = joined.transform(lambda x: x.a + x.b) | ||
|
||
|
||
class Fork(Module): | ||
clk = Clock() | ||
rst = Reset() | ||
a = InputChannel(UInt(8)) | ||
|
||
x = OutputChannel(UInt(8)) | ||
y = OutputChannel(UInt(8)) | ||
|
||
@generator | ||
def build(ports): | ||
x, y = ports.a.fork(ports.clk, ports.rst) | ||
ports.x = x | ||
ports.y = y | ||
|
||
|
||
class Top(Module): | ||
clk = Clock() | ||
rst = Reset() | ||
|
||
@generator | ||
def build(ports): | ||
clk = ports.clk | ||
rst = ports.rst | ||
merge_a = esi.ChannelService.from_host(esi.AppID("merge_a"), | ||
UInt(8)).buffer(clk, rst, 1) | ||
merge_b = esi.ChannelService.from_host(esi.AppID("merge_b"), | ||
UInt(8)).buffer(clk, rst, 1) | ||
merge = Merge("merge_i8", | ||
clk=ports.clk, | ||
rst=ports.rst, | ||
a=merge_a, | ||
b=merge_b) | ||
esi.ChannelService.to_host(esi.AppID("merge_x"), | ||
merge.x.buffer(clk, rst, 1)) | ||
|
||
join_a = esi.ChannelService.from_host(esi.AppID("join_a"), | ||
UInt(8)).buffer(clk, rst, 1) | ||
join_b = esi.ChannelService.from_host(esi.AppID("join_b"), | ||
UInt(8)).buffer(clk, rst, 1) | ||
join = Join("join_i8", clk=ports.clk, rst=ports.rst, a=join_a, b=join_b) | ||
esi.ChannelService.to_host( | ||
esi.AppID("join_x"), | ||
join.x.buffer(clk, rst, 1).transform(lambda x: x.as_uint(16))) | ||
|
||
fork_a = esi.ChannelService.from_host(esi.AppID("fork_a"), | ||
UInt(8)).buffer(clk, rst, 1) | ||
fork = Fork("fork_i8", clk=ports.clk, rst=ports.rst, a=fork_a) | ||
esi.ChannelService.to_host(esi.AppID("fork_x"), fork.x.buffer(clk, rst, 1)) | ||
esi.ChannelService.to_host(esi.AppID("fork_y"), fork.y.buffer(clk, rst, 1)) | ||
|
||
|
||
if __name__ == "__main__": | ||
bsp = get_bsp(sys.argv[2] if len(sys.argv) > 2 else None) | ||
s = System(bsp(Top), name="ESIAdvanced", output_directory=sys.argv[1]) | ||
s.generate() | ||
s.run_passes() | ||
s.compile() | ||
s.package() |
53 changes: 53 additions & 0 deletions
53
frontends/PyCDE/integration_test/test_software/esi_advanced.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import esiaccel as esi | ||
import sys | ||
|
||
platform = sys.argv[1] | ||
acc = esi.AcceleratorConnection(platform, sys.argv[2]) | ||
|
||
d = acc.build_accelerator() | ||
|
||
merge_a = d.ports[esi.AppID("merge_a")].write_port("data") | ||
merge_a.connect() | ||
merge_b = d.ports[esi.AppID("merge_b")].write_port("data") | ||
merge_b.connect() | ||
merge_x = d.ports[esi.AppID("merge_x")].read_port("data") | ||
merge_x.connect() | ||
|
||
for i in range(10, 15): | ||
merge_a.write(i) | ||
merge_b.write(i + 10) | ||
x1 = merge_x.read() | ||
x2 = merge_x.read() | ||
print(f"merge_a: {i}, merge_b: {i + 10}, " | ||
f"merge_x 1: {x1}, merge_x 2: {x2}") | ||
assert x1 == i + 10 or x1 == i | ||
assert x2 == i + 10 or x2 == i | ||
assert x1 != x2 | ||
|
||
join_a = d.ports[esi.AppID("join_a")].write_port("data") | ||
join_a.connect() | ||
join_b = d.ports[esi.AppID("join_b")].write_port("data") | ||
join_b.connect() | ||
join_x = d.ports[esi.AppID("join_x")].read_port("data") | ||
join_x.connect() | ||
|
||
for i in range(15, 27): | ||
join_a.write(i) | ||
join_b.write(i + 10) | ||
x = join_x.read() | ||
print(f"join_a: {i}, join_b: {i + 10}, join_x: {x}") | ||
assert x == (i + i + 10) & 0xFFFF | ||
|
||
fork_a = d.ports[esi.AppID("fork_a")].write_port("data") | ||
fork_a.connect() | ||
fork_x = d.ports[esi.AppID("fork_x")].read_port("data") | ||
fork_x.connect() | ||
fork_y = d.ports[esi.AppID("fork_y")].read_port("data") | ||
fork_y.connect() | ||
|
||
for i in range(27, 33): | ||
fork_a.write(i) | ||
x = fork_x.read() | ||
y = fork_y.read() | ||
print(f"fork_a: {i}, fork_x: {x}, fork_y: {y}") | ||
assert x == y |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# RUN: %PYTHON% %s | FileCheck %s | ||
|
||
from pycde import generator, Clock, Module, Reset | ||
from pycde.common import InputChannel, OutputChannel | ||
from pycde.testing import unittestmodule | ||
from pycde.types import Bits, UInt | ||
|
||
# CHECK-LABEL: hw.module @Merge(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<i8>, in %b : !esi.channel<i8>, out x : !esi.channel<i8>) | ||
# CHECK-NEXT: %rawOutput, %valid = esi.unwrap.vr %a, [[R1:%.+]] : i8 | ||
# CHECK-NEXT: %rawOutput_0, %valid_1 = esi.unwrap.vr %b, [[R2:%.+]] : i8 | ||
# CHECK-NEXT: %true = hw.constant true | ||
# CHECK-NEXT: [[R0:%.+]] = comb.xor bin %valid, %true : i1 | ||
# CHECK-NEXT: [[R1]] = comb.and bin %valid, %ready : i1 | ||
# CHECK-NEXT: [[R2]] = comb.and bin [[R0]], %ready : i1 | ||
# CHECK-NEXT: [[R3:%.+]] = comb.and bin %valid, %valid : i1 | ||
# CHECK-NEXT: [[R4:%.+]] = comb.and bin [[R0]], %valid_1 : i1 | ||
# CHECK-NEXT: [[R5:%.+]] = comb.or bin [[R3]], [[R4]] : i1 | ||
# CHECK-NEXT: [[R6:%.+]] = comb.mux bin %valid, %rawOutput, %rawOutput_0 | ||
# CHECK-NEXT: %chanOutput, %ready = esi.wrap.vr [[R6]], [[R5]] : i8 | ||
# CHECK-NEXT: hw.output %chanOutput : !esi.channel<i8> | ||
|
||
|
||
@unittestmodule() | ||
class Merge(Module): | ||
clk = Clock() | ||
rst = Reset() | ||
a = InputChannel(Bits(8)) | ||
b = InputChannel(Bits(8)) | ||
|
||
x = OutputChannel(Bits(8)) | ||
|
||
@generator | ||
def build(ports): | ||
chan = ports.a.type.merge(ports.a, ports.b) | ||
ports.x = chan | ||
|
||
|
||
# CHECK-LABEL: hw.module @Join(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<ui8>, in %b : !esi.channel<ui8>, out x : !esi.channel<ui9>) | ||
# CHECK-NEXT: %rawOutput, %valid = esi.unwrap.vr %a, [[R2:%.+]] : ui8 | ||
# CHECK-NEXT: %rawOutput_0, %valid_1 = esi.unwrap.vr %b, [[R2]] : ui8 | ||
# CHECK-NEXT: [[R0:%.+]] = comb.and bin %valid, %valid_1 : i1 | ||
# CHECK-NEXT: [[R1:%.+]] = hw.struct_create (%rawOutput, %rawOutput_0) : !hw.struct<a: ui8, b: ui8> | ||
# CHECK-NEXT: %chanOutput, %ready = esi.wrap.vr [[R1]], [[R0]] : !hw.struct<a: ui8, b: ui8> | ||
# CHECK-NEXT: [[R2]] = comb.and bin %ready, [[R0]] : i1 | ||
# CHECK-NEXT: %rawOutput_2, %valid_3 = esi.unwrap.vr %chanOutput, %ready_7 : !hw.struct<a: ui8, b: ui8> | ||
# CHECK-NEXT: %a_4 = hw.struct_extract %rawOutput_2["a"] : !hw.struct<a: ui8, b: ui8> | ||
# CHECK-NEXT: %b_5 = hw.struct_extract %rawOutput_2["b"] : !hw.struct<a: ui8, b: ui8> | ||
# CHECK-NEXT: [[R3:%.+]] = hwarith.add %a_4, %b_5 : (ui8, ui8) -> ui9 | ||
# CHECK-NEXT: %chanOutput_6, %ready_7 = esi.wrap.vr [[R3]], %valid_3 : ui9 | ||
# CHECK-NEXT: hw.output %chanOutput_6 : !esi.channel<ui9> | ||
@unittestmodule(run_passes=True, emit_outputs=True) | ||
class Join(Module): | ||
clk = Clock() | ||
rst = Reset() | ||
a = InputChannel(UInt(8)) | ||
b = InputChannel(UInt(8)) | ||
|
||
x = OutputChannel(UInt(9)) | ||
|
||
@generator | ||
def build(ports): | ||
joined = ports.a.type.join(ports.a, ports.b) | ||
ports.x = joined.transform(lambda x: x.a + x.b) | ||
|
||
|
||
# CHECK-LABEL: hw.module @Fork(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<ui8>, out x : !esi.channel<ui8>, out y : !esi.channel<ui8>) | ||
# CHECK-NEXT: %rawOutput, %valid = esi.unwrap.vr %a, [[R3:%.+]] : ui8 | ||
# CHECK-NEXT: [[R0:%.+]] = comb.and bin [[R3]], %valid : i1 | ||
# CHECK-NEXT: %chanOutput, %ready = esi.wrap.vr %rawOutput, [[R0]] : ui8 | ||
# CHECK-NEXT: %chanOutput_0, %ready_1 = esi.wrap.vr %rawOutput, [[R0]] : ui8 | ||
# CHECK-NEXT: [[R1:%.+]] = esi.buffer %clk, %rst, %chanOutput {stages = 1 : i64} : ui8 | ||
# CHECK-NEXT: [[R2:%.+]] = esi.buffer %clk, %rst, %chanOutput_0 {stages = 1 : i64} : ui8 | ||
# CHECK-NEXT: [[R3]] = comb.and bin %ready, %ready_1 : i1 | ||
# CHECK-NEXT: hw.output [[R1]], [[R2]] : !esi.channel<ui8>, !esi.channel<ui8> | ||
@unittestmodule(run_passes=True, emit_outputs=True) | ||
class Fork(Module): | ||
clk = Clock() | ||
rst = Reset() | ||
a = InputChannel(UInt(8)) | ||
|
||
x = OutputChannel(UInt(8)) | ||
y = OutputChannel(UInt(8)) | ||
|
||
@generator | ||
def build(ports): | ||
x, y = ports.a.fork(ports.clk, ports.rst) | ||
ports.x = x | ||
ports.y = y |