Skip to content

Commit

Permalink
Merge branch 'main' into dev/hidetou/aig-comp
Browse files Browse the repository at this point in the history
  • Loading branch information
uenoku authored Dec 27, 2024
2 parents fb1c2a4 + 5b128a1 commit 333b1fe
Show file tree
Hide file tree
Showing 69 changed files with 1,544 additions and 482 deletions.
101 changes: 101 additions & 0 deletions frontends/PyCDE/integration_test/esi_advanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# REQUIRES: esi-runtime, esi-cosim, rtl-sim
# RUN: rm -rf %t
# RUN: mkdir %t && cd %t
# RUN: %PYTHON% %s %t 2>&1
# RUN: esi-cosim.py -- %PYTHON% %S/test_software/esi_advanced.py cosim env

import sys

from pycde import generator, Clock, Module, Reset, System
from pycde.bsp import get_bsp
from pycde.common import InputChannel, OutputChannel, Output
from pycde.types import Bits, UInt
from pycde import esi


class Merge(Module):
clk = Clock()
rst = Reset()
a = InputChannel(UInt(8))
b = InputChannel(UInt(8))

x = OutputChannel(UInt(8))

@generator
def build(ports):
chan = ports.a.type.merge(ports.a, ports.b)
ports.x = chan


class Join(Module):
clk = Clock()
rst = Reset()
a = InputChannel(UInt(8))
b = InputChannel(UInt(8))

x = OutputChannel(UInt(9))

@generator
def build(ports):
joined = ports.a.type.join(ports.a, ports.b)
ports.x = joined.transform(lambda x: x.a + x.b)


class Fork(Module):
clk = Clock()
rst = Reset()
a = InputChannel(UInt(8))

x = OutputChannel(UInt(8))
y = OutputChannel(UInt(8))

@generator
def build(ports):
x, y = ports.a.fork(ports.clk, ports.rst)
ports.x = x
ports.y = y


class Top(Module):
clk = Clock()
rst = Reset()

@generator
def build(ports):
clk = ports.clk
rst = ports.rst
merge_a = esi.ChannelService.from_host(esi.AppID("merge_a"),
UInt(8)).buffer(clk, rst, 1)
merge_b = esi.ChannelService.from_host(esi.AppID("merge_b"),
UInt(8)).buffer(clk, rst, 1)
merge = Merge("merge_i8",
clk=ports.clk,
rst=ports.rst,
a=merge_a,
b=merge_b)
esi.ChannelService.to_host(esi.AppID("merge_x"),
merge.x.buffer(clk, rst, 1))

join_a = esi.ChannelService.from_host(esi.AppID("join_a"),
UInt(8)).buffer(clk, rst, 1)
join_b = esi.ChannelService.from_host(esi.AppID("join_b"),
UInt(8)).buffer(clk, rst, 1)
join = Join("join_i8", clk=ports.clk, rst=ports.rst, a=join_a, b=join_b)
esi.ChannelService.to_host(
esi.AppID("join_x"),
join.x.buffer(clk, rst, 1).transform(lambda x: x.as_uint(16)))

fork_a = esi.ChannelService.from_host(esi.AppID("fork_a"),
UInt(8)).buffer(clk, rst, 1)
fork = Fork("fork_i8", clk=ports.clk, rst=ports.rst, a=fork_a)
esi.ChannelService.to_host(esi.AppID("fork_x"), fork.x.buffer(clk, rst, 1))
esi.ChannelService.to_host(esi.AppID("fork_y"), fork.y.buffer(clk, rst, 1))


if __name__ == "__main__":
bsp = get_bsp(sys.argv[2] if len(sys.argv) > 2 else None)
s = System(bsp(Top), name="ESIAdvanced", output_directory=sys.argv[1])
s.generate()
s.run_passes()
s.compile()
s.package()
53 changes: 53 additions & 0 deletions frontends/PyCDE/integration_test/test_software/esi_advanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import esiaccel as esi
import sys

platform = sys.argv[1]
acc = esi.AcceleratorConnection(platform, sys.argv[2])

d = acc.build_accelerator()

merge_a = d.ports[esi.AppID("merge_a")].write_port("data")
merge_a.connect()
merge_b = d.ports[esi.AppID("merge_b")].write_port("data")
merge_b.connect()
merge_x = d.ports[esi.AppID("merge_x")].read_port("data")
merge_x.connect()

for i in range(10, 15):
merge_a.write(i)
merge_b.write(i + 10)
x1 = merge_x.read()
x2 = merge_x.read()
print(f"merge_a: {i}, merge_b: {i + 10}, "
f"merge_x 1: {x1}, merge_x 2: {x2}")
assert x1 == i + 10 or x1 == i
assert x2 == i + 10 or x2 == i
assert x1 != x2

join_a = d.ports[esi.AppID("join_a")].write_port("data")
join_a.connect()
join_b = d.ports[esi.AppID("join_b")].write_port("data")
join_b.connect()
join_x = d.ports[esi.AppID("join_x")].read_port("data")
join_x.connect()

for i in range(15, 27):
join_a.write(i)
join_b.write(i + 10)
x = join_x.read()
print(f"join_a: {i}, join_b: {i + 10}, join_x: {x}")
assert x == (i + i + 10) & 0xFFFF

fork_a = d.ports[esi.AppID("fork_a")].write_port("data")
fork_a.connect()
fork_x = d.ports[esi.AppID("fork_x")].read_port("data")
fork_x.connect()
fork_y = d.ports[esi.AppID("fork_y")].read_port("data")
fork_y.connect()

for i in range(27, 33):
fork_a.write(i)
x = fork_x.read()
y = fork_y.read()
print(f"fork_a: {i}, fork_x: {x}, fork_y: {y}")
assert x == y
3 changes: 2 additions & 1 deletion frontends/PyCDE/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ requires = [

# MLIR build depends.
"numpy",
"pybind11>=2.9",
"pybind11>=2.11,<=2.12",
"nanobind==2.4.0",
"PyYAML",

# PyCDE depends
Expand Down
2 changes: 2 additions & 0 deletions frontends/PyCDE/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def run(self):
if "BUILD_TYPE" in os.environ:
cfg = os.environ["BUILD_TYPE"]
cmake_args = [
"-Wno-dev",
"-GNinja",
"-DCMAKE_INSTALL_PREFIX={}".format(os.path.abspath(cmake_install_dir)),
"-DPython3_EXECUTABLE={}".format(sys.executable.replace("\\", "/")),
"-DCMAKE_BUILD_TYPE={}".format(cfg), # not used on MSVC, but no harm
Expand Down
19 changes: 7 additions & 12 deletions frontends/PyCDE/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ add_mlir_python_modules(PyCDE
PyCDE_CIRCTPythonCAPI
)

install(TARGETS PyCDE_CIRCTPythonCAPI
DESTINATION python_packages/pycde/circt/_mlir_libs
RUNTIME_DEPENDENCIES
PRE_EXCLUDE_REGEXES ".*"
PRE_INCLUDE_REGEXES ".*zlib.*"
COMPONENT PyCDE
)
add_dependencies(PyCDE PyCDE_CIRCTPythonModules)
add_dependencies(install-PyCDE install-PyCDE_CIRCTPythonModules)

Expand All @@ -103,15 +110,3 @@ install(FILES ${esiprims}
DESTINATION python_packages/pycde
COMPONENT PyCDE
)

install(IMPORTED_RUNTIME_ARTIFACTS PyCDE_CIRCTPythonCAPI
RUNTIME_DEPENDENCY_SET PyCDE_RUNTIME_DEPS
DESTINATION python_packages/pycde/circt/_mlir_libs
COMPONENT PyCDE
)
install(RUNTIME_DEPENDENCY_SET PyCDE_RUNTIME_DEPS
DESTINATION python_packages/pycde/circt/_mlir_libs
PRE_EXCLUDE_REGEXES .*
PRE_INCLUDE_REGEXES zlib1
COMPONENT PyCDE
)
14 changes: 14 additions & 0 deletions frontends/PyCDE/src/pycde/bsp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,19 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional

from .cosim import CosimBSP
from .xrt import XrtBSP


def get_bsp(name: Optional[str] = None):
if name is None or name == "cosim":
return CosimBSP
elif name == "xrt":
return XrtBSP
elif name == "xrt_cosim":
from .xrt import XrtCosimBSP
return XrtCosimBSP
else:
raise ValueError(f"Unknown bsp type: {name}")
32 changes: 27 additions & 5 deletions frontends/PyCDE/src/pycde/handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Tuple

from .module import Module, ModuleLikeBuilderBase, PortError
from .signals import BitsSignal, ChannelSignal, ClockSignal, Signal
from .signals import (BitsSignal, ChannelSignal, ClockSignal, Signal,
_FromCirctValue)
from .system import System
from .support import get_user_loc, obj_to_typed_attribute
from .types import Channel
from .support import clog2, get_user_loc
from .types import Bits, Channel

from .circt.dialects import handshake as raw_handshake
from .circt import ir
Expand Down Expand Up @@ -82,7 +83,7 @@ def instantiate(self, module_inst, inputs, instance_name: str):
# If the input is a channel signal, the types must match.
if signal.type.inner_type != port.type:
raise ValueError(
f"Wrong type on input signal '{name}'. Got '{signal.type}',"
f"Wrong type on input signal '{name}'. Got '{signal.type.inner_type}',"
f" expected '{port.type}'")
assert port.idx is not None
circt_inputs[port.idx] = signal.value
Expand Down Expand Up @@ -124,3 +125,24 @@ class Func(Module):

BuilderType: type[ModuleLikeBuilderBase] = FuncBuilder
_builder: FuncBuilder


def demux(cond: BitsSignal, data: Signal) -> Tuple[Signal, Signal]:
"""Demux a signal based on a condition."""
condbr = raw_handshake.ConditionalBranchOp(cond.value, data.value)
return (_FromCirctValue(condbr.trueResult),
_FromCirctValue(condbr.falseResult))


def cmerge(*args: Signal) -> Tuple[Signal, BitsSignal]:
"""Merge multiple signals into one and the index of the signal."""
if len(args) == 0:
raise ValueError("cmerge must have at least one argument")
first = args[0]
for a in args[1:]:
if a.type != first.type:
raise ValueError("All arguments to cmerge must have the same type")
idx_type = Bits(clog2(len(args)))
cm = raw_handshake.ControlMergeOp(a.type._type, idx_type._type,
[a.value for a in args])
return (_FromCirctValue(cm.result), BitsSignal(cm.index, idx_type))
18 changes: 18 additions & 0 deletions frontends/PyCDE/src/pycde/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def name(self, new: str):
else:
self._name = new

def get_name(self, default: str = "") -> str:
return self.name if self.name is not None else default

@property
def appid(self) -> Optional[object]: # Optional AppID.
from .module import AppID
Expand Down Expand Up @@ -752,6 +755,21 @@ def transform(self, transform: Callable[[Signal], Signal]) -> ChannelSignal:
ready_wire.assign(ready)
return ret_chan

def fork(self, clk, rst) -> Tuple[ChannelSignal, ChannelSignal]:
"""Fork the channel into two channels, returning the two new channels."""
from .constructs import Wire
from .types import Bits
both_ready = Wire(Bits(1))
both_ready.name = self.get_name() + "_fork_both_ready"
data, valid = self.unwrap(both_ready)
valid_gate = both_ready & valid
a, a_rdy = self.type.wrap(data, valid_gate)
b, b_rdy = self.type.wrap(data, valid_gate)
abuf = a.buffer(clk, rst, 1)
bbuf = b.buffer(clk, rst, 1)
both_ready.assign(a_rdy & b_rdy)
return abuf, bbuf


class BundleSignal(Signal):
"""Signal for types.Bundle."""
Expand Down
3 changes: 1 addition & 2 deletions frontends/PyCDE/src/pycde/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ def get_instance(self,
# Then run all the passes to lower dialects which produce `hw.module`s.
"builtin.module(lower-handshake-to-dc)",
"builtin.module(dc-materialize-forks-sinks)",
"builtin.module(canonicalize)",
"builtin.module(lower-dc-to-hw)",
"builtin.module(map-arith-to-comb)",

# Run ESI manifest passes.
"builtin.module(esi-appid-hier{{top={tops} }}, esi-build-manifest{{top={tops} }})",
Expand All @@ -275,7 +275,6 @@ def get_instance(self,
# Instaniate hlmems, which could produce new esi connections.
"builtin.module(hw.module(lower-seq-hlmem))",
"builtin.module(lower-esi-to-physical)",
# TODO: support more than just cosim.
"builtin.module(lower-esi-bundles, lower-esi-ports)",
"builtin.module(lower-esi-to-hw{{platform={platform}}})",
"builtin.module(convert-fsm-to-sv)",
Expand Down
2 changes: 1 addition & 1 deletion frontends/PyCDE/src/pycde/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def extra_compile_args(self, pycde_system: System):
# lives easier and create a minimum timescale through the command-line.
cmd_file = os.path.join(pycde_system.output_directory, "cmds.f")
with open(cmd_file, "w+") as f:
f.write("+timescale+1ns/1ps")
f.write("+timescale+1ns/1ps\n")

return [f"-f{cmd_file}"]

Expand Down
Loading

0 comments on commit 333b1fe

Please sign in to comment.