Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Calyx Binary Floating Point AddF Operator #7089

Merged
merged 3 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions include/circt/Dialect/Calyx/CalyxPrimitives.td
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,49 @@ def AndLibOp : CombinationalArithBinaryLibraryOp<"and"> {}
def OrLibOp : CombinationalArithBinaryLibraryOp<"or"> {}
def XorLibOp : CombinationalArithBinaryLibraryOp<"xor"> {}

class ArithBinaryFloatingPointLibraryOp<string mnemonic> : ArithBinaryLibraryOp<mnemonic, [
SameTypeConstraint<"left", "out">]> {}

def AddFNOp : ArithBinaryFloatingPointLibraryOp<"addFN"> {
let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control, I1:$subOp,
AnyFloat:$left, AnyFloat:$right, AnySignlessInteger:$roundingMode, AnyFloat:$out,
AnySignlessInteger:$exceptionalFlags, I1:$done);

let extraClassDefinition = [{
SmallVector<StringRef> $cppClass::portNames() {
return {clkPort, resetPort, goPort, "control", "subOp",
"left", "right", "roundingMode", "out", "exceptionalFlags", donePort
};
}

SmallVector<Direction> $cppClass::portDirections() {
return {Input, Input, Input, Input, Input, Input, Input, Input, Output, Output, Output};
}

void $cppClass::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
getCellAsmResultNames(setNameFn, *this, this->portNames());
}

bool $cppClass::isCombinational() { return false; }

SmallVector<DictionaryAttr> $cppClass::portAttributes() {
IntegerAttr isSet = IntegerAttr::get(IntegerType::get(getContext(), 1), 1);
NamedAttrList go, clk, reset, done;
go.append(goPort, isSet);
clk.append(clkPort, isSet);
reset.append(resetPort, isSet);
done.append(donePort, isSet);
return {clk.getDictionary(getContext()), reset.getDictionary(getContext()),
go.getDictionary(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), done.getDictionary(getContext()),
DictionaryAttr::get(getContext())
};
}
}];
}

def MuxLibOp : CalyxLibraryOp<"mux", [
Combinational, SameTypeConstraint<"tru", "fal">, SameTypeConstraint<"tru", "out">
]> {
Expand Down
37 changes: 35 additions & 2 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#include <variant>

Expand Down Expand Up @@ -281,6 +282,9 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp,
AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
/// floating point
AddFOp,
/// others
SelectOp, IndexCastOp, CallOp>(
[&](auto op) { return buildOp(rewriter, op).succeeded(); })
.template Case<FuncOp, scf::ConditionOp>([&](auto) {
Expand Down Expand Up @@ -314,6 +318,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
LogicalResult buildOp(PatternRewriter &rewriter, DivSIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, RemUIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, RemSIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, AddFOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShRUIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShRSIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShLIOp op) const;
Expand Down Expand Up @@ -409,7 +414,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
// Pass the result from the Operation to the Calyx primitive.
op.getResult().replaceAllUsesWith(out);
auto reg = createRegister(
op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(),
op.getLoc(), rewriter, getComponent(), width,
getState<ComponentLoweringState>().getUniqueName(opName));
// Operation pipelines are not combinational, so a GroupOp is required.
auto group = createGroupForOp<calyx::GroupOp>(rewriter, op);
Expand All @@ -434,6 +439,19 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
// The group is done when the register write is complete.
rewriter.create<calyx::GroupDoneOp>(loc, reg.getDone());

if (isa<calyx::AddFNOp>(opPipe)) {
auto opFN = cast<calyx::AddFNOp>(opPipe);
hw::ConstantOp subOp;
if (isa<arith::AddFOp>(op)) {
subOp = createConstant(loc, rewriter, getComponent(), /*width=*/1,
/*subtract=*/0);
} else {
subOp = createConstant(loc, rewriter, getComponent(), /*width=*/1,
/*subtract=*/1);
}
rewriter.create<calyx::AssignOp>(loc, opFN.getSubOp(), subOp);
}

// Register the values for the pipeline.
getState<ComponentLoweringState>().registerEvaluatingGroup(out, group);
getState<ComponentLoweringState>().registerEvaluatingGroup(opPipe.getLeft(),
Expand Down Expand Up @@ -666,6 +684,21 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
/*out=*/remPipe.getOut());
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
AddFOp addf) const {
Location loc = addf.getLoc();
Type width = addf.getResult().getType();
IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
five = rewriter.getIntegerType(5);
auto addFN =
getState<ComponentLoweringState>()
.getNewLibraryOpInstance<calyx::AddFNOp>(
rewriter, loc,
{one, one, one, one, one, width, width, three, width, five, one});
return buildLibraryBinaryPipeOp<calyx::AddFNOp>(rewriter, addf, addFN,
addFN.getOut());
}

template <typename TAllocOp>
static LogicalResult buildAllocOp(ComponentLoweringState &componentState,
PatternRewriter &rewriter, TAllocOp allocOp) {
Expand Down Expand Up @@ -1868,7 +1901,7 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase<SCFToCalyxPass> {
ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, TruncIOp,
CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp,
RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp, FuncOp,
ExtSIOp, CallOp>();
ExtSIOp, CallOp, AddFOp>();

RewritePatternSet legalizePatterns(&getContext());
legalizePatterns.add<DummyPattern>(&getContext());
Expand Down
45 changes: 45 additions & 0 deletions lib/Dialect/Calyx/Export/CalyxEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ struct ImportTracker {
static constexpr std::string_view sFloat = "float";
return {sFloat};
})
.Case<AddFNOp>([&](auto op) -> FailureOr<StringRef> {
jiahanxie353 marked this conversation as resolved.
Show resolved Hide resolved
static constexpr std::string_view sFloatingPoint = "float/addFN";
return {sFloatingPoint};
})
.Default([&](auto op) {
auto diag = op->emitOpError() << "not supported for emission";
return diag;
Expand Down Expand Up @@ -288,6 +292,9 @@ struct Emitter {
void emitLibraryPrimTypedByFirstOutputPort(
Operation *op, std::optional<StringRef> calyxLibName = {});

// Emits a library floating point primitives
void emitLibraryFloatingPoint(Operation *op);

private:
/// Used to track which imports are required for this program.
ImportTracker importTracker;
Expand Down Expand Up @@ -668,6 +675,7 @@ void Emitter::emitComponent(ComponentInterface op) {
emitLibraryPrimTypedByFirstOutputPort(
op, /*calyxLibName=*/{"std_sdiv_pipe"});
})
.Case<AddFNOp>([&](auto op) { emitLibraryFloatingPoint(op); })
.Default([&](auto op) {
emitOpError(op, "not supported for emission inside component");
});
Expand Down Expand Up @@ -964,6 +972,43 @@ void Emitter::emitLibraryPrimTypedByFirstOutputPort(
<< LParen() << bitWidth << RParen() << semicolonEndL();
}

void Emitter::emitLibraryFloatingPoint(Operation *op) {
auto cell = cast<CellInterface>(op);
unsigned bitWidth =
cell.getOutputPorts()[0].getType().getIntOrFloatBitWidth();
// Since Calyx interacts with HardFloat, we'll also only be using expWidth and
// sigWidth. See
// http://www.jhauser.us/arithmetic/HardFloat-1/doc/HardFloat-Verilog.html
unsigned expWidth, sigWidth;
switch (bitWidth) {
case 16:
expWidth = 5;
sigWidth = 11;
break;
case 32:
expWidth = 8;
sigWidth = 24;
break;
case 64:
expWidth = 11;
sigWidth = 53;
break;
case 128:
expWidth = 15;
sigWidth = 113;
break;
default:
op->emitError("The supported bitwidths are 16, 32, 64, and 128");
return;
}

StringRef opName = op->getName().getStringRef();
indent() << getAttributes(op, /*atFormat=*/true) << cell.instanceName()
<< space() << equals() << space() << removeCalyxPrefix(opName)
<< LParen() << expWidth << comma() << sigWidth << comma() << bitWidth
<< RParen() << semicolonEndL();
}

void Emitter::emitAssignment(AssignOp op) {

emitValue(op.getDest(), /*isIndented=*/true);
Expand Down
8 changes: 4 additions & 4 deletions lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -657,10 +657,10 @@ void InlineCombGroups::recurseInlineCombGroups(
// LateSSAReplacement)
if (isa<BlockArgument>(src) ||
isa<calyx::RegisterOp, calyx::MemoryOp, calyx::SeqMemoryOp,
calyx::ConstantOp, hw::ConstantOp, mlir::arith::ConstantOp,
calyx::MultPipeLibOp, calyx::DivUPipeLibOp, calyx::DivSPipeLibOp,
calyx::RemSPipeLibOp, calyx::RemUPipeLibOp, mlir::scf::WhileOp,
calyx::InstanceOp>(src.getDefiningOp()))
hw::ConstantOp, mlir::arith::ConstantOp, calyx::MultPipeLibOp,
calyx::DivUPipeLibOp, calyx::DivSPipeLibOp, calyx::RemSPipeLibOp,
calyx::RemUPipeLibOp, mlir::scf::WhileOp, calyx::InstanceOp,
calyx::ConstantOp, calyx::AddFNOp>(src.getDefiningOp()))
continue;

auto srcCombGroup = dyn_cast<calyx::CombGroupOp>(
Expand Down
24 changes: 24 additions & 0 deletions test/Conversion/SCFToCalyx/convert_simple.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,27 @@ module {
return %arg0, %0, %1 : f32, i32, f32
}
}

// -----

// Test floating point add

// CHECK: calyx.group @bb0_0 {
// CHECK-DAG: calyx.assign %std_addFN_0.left = %in0 : f32
// CHECK-DAG: calyx.assign %std_addFN_0.right = %cst : f32
// CHECK-DAG: calyx.assign %addf_0_reg.in = %std_addFN_0.out : f32
// CHECK-DAG: calyx.assign %addf_0_reg.write_en = %std_addFN_0.done : i1
// CHECK-DAG: %0 = comb.xor %std_addFN_0.done, %true : i1
// CHECK-DAG: calyx.assign %std_addFN_0.go = %0 ? %true : i1
// CHECK-DAG: calyx.assign %std_addFN_0.subOp = %false : i1
// CHECK-DAG: calyx.group_done %addf_0_reg.done : i1
// CHECK-DAG: }

module {
func.func @main(%arg0 : f32) -> f32 {
%0 = arith.constant 4.2 : f32
%1 = arith.addf %arg0, %0 : f32

return %1 : f32
}
}
10 changes: 0 additions & 10 deletions test/Conversion/SCFToCalyx/errors.mlir
Original file line number Diff line number Diff line change
@@ -1,15 +1,5 @@
// RUN: circt-opt --lower-scf-to-calyx %s -split-input-file -verify-diagnostics

module {
func.func @f(%arg0 : f32, %arg1 : f32) -> f32 {
// expected-error @+1 {{failed to legalize operation 'arith.addf' that was explicitly marked illegal}}
%2 = arith.addf %arg0, %arg1 : f32
return %2 : f32
}
}

// -----

// expected-error @+1 {{Module contains multiple functions, but no top level function was set. Please see --top-level-function}}
module {
func.func @f1() {
Expand Down
51 changes: 51 additions & 0 deletions test/Dialect/Calyx/emit.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,54 @@ module attributes {calyx.entrypoint = "main"} {
} {toplevel}
}


// -----

module attributes {calyx.entrypoint = "main"} {
// CHECK: import "primitives/float/addFN.futil";
calyx.component @main(%in0: f32, %clk: i1 {clk}, %reset: i1 {reset}, %go: i1 {go}) -> (%out0: f32, %done: i1 {done}) {
// CHECK: std_addFN_0 = std_addFN(8, 24, 32);
%cst = calyx.constant {sym_name = "cst_0"} 4.200000e+00 : f32
%true = hw.constant true
%false = hw.constant false
%addf_0_reg.in, %addf_0_reg.write_en, %addf_0_reg.clk, %addf_0_reg.reset, %addf_0_reg.out, %addf_0_reg.done = calyx.register @addf_0_reg : f32, i1, i1, i1, f32, i1
%std_addFN_0.clk, %std_addFN_0.reset, %std_addFN_0.go, %std_addFN_0.control, %std_addFN_0.subOp, %std_addFN_0.left, %std_addFN_0.right, %std_addFN_0.roundingMode, %std_addFN_0.out, %std_addFN_0.exceptionalFlags, %std_addFN_0.done = calyx.std_addFN @std_addFN_0 : i1, i1, i1, i1, i1, f32, f32, i3, f32, i5, i1
%ret_arg0_reg.in, %ret_arg0_reg.write_en, %ret_arg0_reg.clk, %ret_arg0_reg.reset, %ret_arg0_reg.out, %ret_arg0_reg.done = calyx.register @ret_arg0_reg : f32, i1, i1, i1, f32, i1
calyx.wires {
calyx.assign %out0 = %ret_arg0_reg.out : f32

// CHECK-LABEL: group bb0_0 {
// CHECK-NEXT: std_addFN_0.left = in0;
// CHECK-NEXT: std_addFN_0.right = cst_0.out;
// CHECK-NEXT: addf_0_reg.in = std_addFN_0.out;
// CHECK-NEXT: addf_0_reg.write_en = std_addFN_0.done;
// CHECK-NEXT: std_addFN_0.go = !std_addFN_0.done ? 1'd1;
// CHECK-NEXT: std_addFN_0.subOp = 1'd0;
// CHECK-NEXT: bb0_0[done] = addf_0_reg.done;
// CHECK-NEXT: }
calyx.group @bb0_0 {
calyx.assign %std_addFN_0.left = %in0 : f32
calyx.assign %std_addFN_0.right = %cst : f32
calyx.assign %addf_0_reg.in = %std_addFN_0.out : f32
calyx.assign %addf_0_reg.write_en = %std_addFN_0.done : i1
%0 = comb.xor %std_addFN_0.done, %true : i1
calyx.assign %std_addFN_0.go = %0 ? %true : i1
calyx.assign %std_addFN_0.subOp = %false : i1
calyx.group_done %addf_0_reg.done : i1
}
calyx.group @ret_assign_0 {
calyx.assign %ret_arg0_reg.in = %std_addFN_0.out : f32
calyx.assign %ret_arg0_reg.write_en = %true : i1
calyx.group_done %ret_arg0_reg.done : i1
}
}
calyx.control {
calyx.seq {
calyx.seq {
calyx.enable @bb0_0
calyx.enable @ret_assign_0
}
}
}
} {toplevel}
}
Loading