Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Add f4E2M1FN type #108877

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/mlir-c/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type);
/// Returns the bitwidth of a floating-point type.
MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type);

/// Returns the typeID of an Float4E2M1FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat4E2M1FNTypeGetTypeID(void);

/// Checks whether the given type is an f4E2M1FN type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat4E2M1FN(MlirType type);

/// Creates an f4E2M1FN type in the given context. The type is owned by the
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx);

/// Returns the typeID of an Float6E2M3FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void);

Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class Builder {
Attribute metadata = Attribute());

// Types.
FloatType getFloat4E2M1FNType();
FloatType getFloat6E2M3FNType();
FloatType getFloat6E3M2FNType();
FloatType getFloat8E5M2Type();
Expand Down
15 changes: 10 additions & 5 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class FloatType : public Type {
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
static FloatType getFloat8E3M4(MLIRContext *ctx);
static FloatType getFloat4E2M1FN(MLIRContext *ctx);
static FloatType getFloat6E2M3FN(MLIRContext *ctx);
static FloatType getFloat6E3M2FN(MLIRContext *ctx);

Expand Down Expand Up @@ -415,11 +416,15 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
}

inline bool FloatType::classof(Type type) {
return llvm::isa<Float6E2M3FNType, Float6E3M2FNType, Float8E5M2Type,
Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
Float64Type, Float80Type, Float128Type>(type);
return llvm::isa<
Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType, Float8E5M2Type,
Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType, Float8E4M3FNUZType,
Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type, Float16Type,
FloatTF32Type, Float32Type, Float64Type, Float80Type, Float128Type>(type);
}

inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {
return Float4E2M1FNType::get(ctx);
}

inline FloatType FloatType::getFloat6E2M3FN(MLIRContext *ctx) {
Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,27 @@ def Builtin_Float8E3M4 : Builtin_FloatType<"Float8E3M4", "f8E3M4"> {
}];
}

//===----------------------------------------------------------------------===//
// Float4E2M1FNType

def Builtin_Float4E2M1FN : Builtin_FloatType<"Float4E2M1FN", "f4E2M1FN"> {
let summary = "4-bit floating point with 2-bit exponent and 1-bit mantissa";
let description = [{
An 4-bit floating point type with 1 sign bit, 2 bits exponent and 1 bit
mantissa. This is not a standard type as defined by IEEE-754, but it
follows similar conventions with the following characteristics:

* bit encoding: S1E2M1
* exponent bias: 1
* infinities: Not supported
* NaNs: Not supported
* denormals when exponent is 0

Open Compute Project (OCP) microscaling formats (MX) specification:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
}];
}

//===----------------------------------------------------------------------===//
// Float6E2M3FNType

Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,8 @@ def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
BuildableType<"$_builder.getFloat8E3M4Type()">;
def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
BuildableType<"$_builder.getFloat4E2M1FNType()">;
def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
BuildableType<"$_builder.getFloat6E2M3FNType()">;
def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class Type {
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
bool isIndex() const;
bool isFloat4E2M1FN() const;
bool isFloat6E2M3FN() const;
bool isFloat6E3M2FN() const;
bool isFloat8E5M2() const;
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/AsmParser/TokenKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ TOK_KEYWORD(f8E5M2FNUZ)
TOK_KEYWORD(f8E4M3FNUZ)
TOK_KEYWORD(f8E4M3B11FNUZ)
TOK_KEYWORD(f8E3M4)
TOK_KEYWORD(f4E2M1FN)
TOK_KEYWORD(f6E2M3FN)
TOK_KEYWORD(f6E3M2FN)
TOK_KEYWORD(f128)
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/AsmParser/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_tuple:
case Token::kw_vector:
case Token::inttype:
case Token::kw_f4E2M1FN:
case Token::kw_f6E2M3FN:
case Token::kw_f6E3M2FN:
case Token::kw_f8E5M2:
Expand Down Expand Up @@ -305,6 +306,9 @@ Type Parser::parseNonFunctionType() {
}

// float-type
case Token::kw_f4E2M1FN:
consumeToken(Token::kw_f4E2M1FN);
return builder.getFloat4E2M1FNType();
case Token::kw_f6E2M3FN:
consumeToken(Token::kw_f6E2M3FN);
return builder.getFloat6E2M3FNType();
Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/Bindings/Python/IRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,27 @@ class PyFloatType : public PyConcreteType<PyFloatType> {
}
};

/// Floating Point Type subclass - Float4E2M1FNType.
class PyFloat4E2M1FNType
: public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat4E2M1FNTypeGetTypeID;
static constexpr const char *pyClassName = "Float4E2M1FNType";
using PyConcreteType::PyConcreteType;

static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](DefaultingPyMlirContext context) {
MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
return PyFloat4E2M1FNType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float4_e2m1fn type.");
}
};

/// Floating Point Type subclass - Float6E2M3FNType.
class PyFloat6E2M3FNType
: public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
Expand Down Expand Up @@ -922,6 +943,7 @@ void mlir::python::populateIRTypes(py::module &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
PyIndexType::bind(m);
PyFloat4E2M1FNType::bind(m);
PyFloat6E2M3FNType::bind(m);
PyFloat6E3M2FNType::bind(m);
PyFloat8E4M3FNType::bind(m);
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/CAPI/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ unsigned mlirFloatTypeGetWidth(MlirType type) {
return llvm::cast<FloatType>(unwrap(type)).getWidth();
}

MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() {
return wrap(Float4E2M1FNType::getTypeID());
}

bool mlirTypeIsAFloat4E2M1FN(MlirType type) {
return unwrap(type).isFloat4E2M1FN();
}

MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat4E2M1FN(unwrap(ctx)));
}

MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
return wrap(Float6E2M3FNType::getTypeID());
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
type.isFloat6E2M3FN() || type.isFloat6E3M2FN())
type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN())
return IntegerType::get(&getContext(), type.getWidth());
return type;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
StringRef name) {
Builder b(ctx);
return llvm::StringSwitch<std::optional<FloatType>>(name)
.Case("f4E2M1FN", b.getFloat4E2M1FNType())
.Case("f6E2M3FN", b.getFloat6E2M3FNType())
.Case("f6E3M2FN", b.getFloat6E3M2FNType())
.Case("f8E5M2", b.getFloat8E5M2Type())
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2575,6 +2575,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
opaqueTy.getTypeData());
})
.Case<IndexType>([&](Type) { os << "index"; })
.Case<Float4E2M1FNType>([&](Type) { os << "f4E2M1FN"; })
.Case<Float6E2M3FNType>([&](Type) { os << "f6E2M3FN"; })
.Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; })
.Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
// Types.
//===----------------------------------------------------------------------===//

FloatType Builder::getFloat4E2M1FNType() {
return FloatType::getFloat4E2M1FN(context);
}

FloatType Builder::getFloat6E2M3FNType() {
return FloatType::getFloat6E2M3FN(context);
}
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ unsigned FloatType::getWidth() {

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
if (llvm::isa<Float4E2M1FNType>(*this))
return APFloat::Float4E2M1FN();
if (llvm::isa<Float6E2M3FNType>(*this))
return APFloat::Float6E2M3FN();
if (llvm::isa<Float6E3M2FNType>(*this))
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ class MLIRContextImpl {
llvm::DenseMap<StringRef, AbstractType *> nameToType;

/// Cached Type Instances.
Float4E2M1FNType f4E2M1FNTy;
Float6E2M3FNType f6E2M3FNTy;
Float6E3M2FNType f6E3M2FNTy;
Float8E5M2Type f8E5M2Ty;
Expand Down Expand Up @@ -315,6 +316,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)

//// Types.
/// Floating-point Types.
impl->f4E2M1FNTy = TypeUniquer::get<Float4E2M1FNType>(this);
impl->f6E2M3FNTy = TypeUniquer::get<Float6E2M3FNType>(this);
impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
Expand Down Expand Up @@ -1017,6 +1019,9 @@ AbstractType::lookup(StringRef name, MLIRContext *context) {
/// This should not be used directly.
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }

Float4E2M1FNType Float4E2M1FNType::get(MLIRContext *context) {
return context->getImpl().f4E2M1FNTy;
}
Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) {
return context->getImpl().f6E2M3FNTy;
}
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/IR/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,

MLIRContext *Type::getContext() const { return getDialect().getContext(); }

bool Type::isFloat4E2M1FN() const { return llvm::isa<Float4E2M1FNType>(*this); }
bool Type::isFloat6E2M3FN() const { return llvm::isa<Float6E2M3FNType>(*this); }
bool Type::isFloat6E3M2FN() const { return llvm::isa<Float6E3M2FNType>(*this); }
bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
Expand Down
14 changes: 14 additions & 0 deletions mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ __all__ = [
"F32Type",
"F64Type",
"FlatSymbolRefAttr",
"Float4E2M1FNType",
"Float6E2M3FNType",
"Float6E3M2FNType",
"Float8E3M4Type",
Expand Down Expand Up @@ -1542,6 +1543,19 @@ class FlatSymbolRefAttr(Attribute):
Returns the value of the FlatSymbolRef attribute as a string
"""

class Float4E2M1FNType(FloatType):
static_typeid: ClassVar[TypeID]
@staticmethod
def get(context: Optional[Context] = None) -> Float4E2M1FNType:
"""
Create a float4_e2m1fn type.
"""
@staticmethod
def isinstance(other: Type) -> bool: ...
def __init__(self, cast_from_type: Type) -> None: ...
@property
def typeid(self) -> TypeID: ...

class Float6E2M3FNType(FloatType):
static_typeid: ClassVar[TypeID]
@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions mlir/python/mlir/extras/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
F16Type,
F32Type,
F64Type,
Float4E2M1FNType,
Float6E2M3FNType,
Float6E3M2FNType,
Float8E3M4Type,
Expand Down Expand Up @@ -76,6 +77,7 @@ def ui(width):
f8E4M3FN = lambda: Float8E4M3FNType.get()
f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
f8E3M4 = lambda: Float8E3M4Type.get()
f4E2M1FN = lambda: Float4E2M1FNType.get()
f6E2M3FN = lambda: Float6E2M3FNType.get()
f6E3M2FN = lambda: Float6E3M2FNType.get()

Expand Down
4 changes: 4 additions & 0 deletions mlir/test/IR/attribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func.func @any_attr_of_fail() {
//===----------------------------------------------------------------------===//

func.func @float_attrs_pass() {
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f4E2M1FN
float_attr = 2. : f4E2M1FN
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f6E2M3FN
float_attr = 2. : f6E2M3FN
Expand Down
3 changes: 3 additions & 0 deletions mlir/test/Target/LLVMIR/llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ llvm.mlir.global internal @int_global_undef() : i64
// CHECK: @externally_initialized_global = internal externally_initialized global i32 0
llvm.mlir.global internal @externally_initialized_global(0 : i32) {externally_initialized} : i32

// CHECK: @f4E2M1FN_global_as_i4 = internal global i4 3
llvm.mlir.global internal @f4E2M1FN_global_as_i4(1.5 : f4E2M1FN) : i4

// CHECK: @f6E2M3FN_global_as_i6 = internal global i6 12
llvm.mlir.global internal @f6E2M3FN_global_as_i6(1.5 : f6E2M3FN) : i6

Expand Down
9 changes: 9 additions & 0 deletions mlir/test/python/ir/builtin_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def testTypeIsInstance():
def testFloatTypeSubclasses():
ctx = Context()
# CHECK: True
print(isinstance(Type.parse("f4E2M1FN", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("f6E2M3FN", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("f6E3M2FN", ctx), FloatType))
Expand Down Expand Up @@ -237,6 +239,8 @@ def testIndexType():
@run
def testFloatType():
with Context():
# CHECK: float: f4E2M1FN
print("float:", Float4E2M1FNType.get())
# CHECK: float: f6E2M3FN
print("float:", Float6E2M3FNType.get())
# CHECK: float: f6E3M2FN
Expand Down Expand Up @@ -617,6 +621,7 @@ def testTypeIDs():
types = [
(IntegerType, IntegerType.get_signless(16)),
(IndexType, IndexType.get()),
(Float4E2M1FNType, Float4E2M1FNType.get()),
(Float6E2M3FNType, Float6E2M3FNType.get()),
(Float6E3M2FNType, Float6E3M2FNType.get()),
(Float8E3M4Type, Float8E3M4Type.get()),
Expand Down Expand Up @@ -644,6 +649,7 @@ def testTypeIDs():

# CHECK: IntegerType(i16)
# CHECK: IndexType(index)
# CHECK: Float4E2M1FNType(f4E2M1FN)
# CHECK: Float6E2M3FNType(f6E2M3FN)
# CHECK: Float6E3M2FNType(f6E3M2FN)
# CHECK: Float8E3M4Type(f8E3M4)
Expand Down Expand Up @@ -725,6 +731,9 @@ def print_downcasted(typ):
# CHECK: F64Type
# CHECK: F64Type(f64)
print_downcasted(F64Type.get())
# CHECK: Float4E2M1FNType
# CHECK: Float4E2M1FNType(f4E2M1FN)
print_downcasted(Float4E2M1FNType.get())
# CHECK: Float6E2M3FNType
# CHECK: Float6E2M3FNType(f6E2M3FN)
print_downcasted(Float6E2M3FNType.get())
Expand Down
1 change: 1 addition & 0 deletions mlir/utils/lldb-scripts/mlirDataFormatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
"mlir::CallSiteLoc": '"loc(callsite(...))"',
"mlir::FusedLoc": '"loc(fused<...>[...])"',
"mlir::UnknownLoc": '"loc(unknown)"',
"mlir::Float4E2M1FNType": '"f4E2M1FN"',
"mlir::Float6E2M3FNType": '"f6E2M3FN"',
"mlir::Float6E3M2FNType": '"f6E3M2FN"',
"mlir::Float8E5M2Type": '"f8E5M2"',
Expand Down
2 changes: 1 addition & 1 deletion mlir/utils/tree-sitter-mlir/grammar.js
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ const common = {
token(seq(choice('si', 'ui', 'i'), /[1-9]/, repeat(/[0-9]/))),
float_type : $ => token(
choice('f16', 'f32', 'f64', 'f80', 'f128', 'bf16', 'f8E3M4', 'f8E4M3FN',
'f8E4M3', 'f8E5M2', 'f6E2M3FN', 'f6E3M2FN')),
'f8E4M3', 'f8E5M2', 'f4E2M1FN', 'f6E2M3FN', 'f6E3M2FN')),
index_type : $ => token('index'),
none_type : $ => token('none'),
complex_type : $ => seq(token('complex'), '<', $._prim_type, '>'),
Expand Down
Loading