-
Notifications
You must be signed in to change notification settings - Fork 11.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] Add f8E8M0FNU type #111028
[MLIR] Add f8E8M0FNU type #111028
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Sergey Kozub (sergey-kozub) Changes: This PR adds
f8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0 - 127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111
Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1

Related PRs:
Patch is 20.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/111028.diff 24 Files Affected:
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 6dc25a56b8e614..6875fab7bf7961 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -179,6 +179,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx);
+/// Returns the typeID of an Float8E8M0FNU type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID(void);
+
+/// Checks whether the given type is an f8E8M0FNU type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E8M0FNU(MlirType type);
+
+/// Creates an f8E8M0FNU type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx);
+
/// Returns the typeID of an BFloat16 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index ee5d7879625309..04a8bddc3cd59a 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -70,6 +70,7 @@ class Builder {
FloatType getFloat8E4M3FNUZType();
FloatType getFloat8E4M3B11FNUZType();
FloatType getFloat8E3M4Type();
+ FloatType getFloat8E8M0FNUType();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getTF32Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 91e68b4066dd67..25535408f4528a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -70,6 +70,7 @@ class FloatType : public Type {
static FloatType getFloat4E2M1FN(MLIRContext *ctx);
static FloatType getFloat6E2M3FN(MLIRContext *ctx);
static FloatType getFloat6E3M2FN(MLIRContext *ctx);
+ static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);
@@ -416,11 +417,12 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
}
inline bool FloatType::classof(Type type) {
- return llvm::isa<
- Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType, Float8E5M2Type,
- Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType, Float8E4M3FNUZType,
- Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type, Float16Type,
- FloatTF32Type, Float32Type, Float64Type, Float80Type, Float128Type>(type);
+ return llvm::isa<Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
+ Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
+ Float8E5M2FNUZType, Float8E4M3FNUZType,
+ Float8E4M3B11FNUZType, Float8E3M4Type, Float8E8M0FNUType,
+ BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
+ Float64Type, Float80Type, Float128Type>(type);
}
inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {
@@ -463,6 +465,10 @@ inline FloatType FloatType::getFloat8E3M4(MLIRContext *ctx) {
return Float8E3M4Type::get(ctx);
}
+inline FloatType FloatType::getFloat8E8M0FNU(MLIRContext *ctx) {
+ return Float8E8M0FNUType::get(ctx);
+}
+
inline FloatType FloatType::getBF16(MLIRContext *ctx) {
return BFloat16Type::get(ctx);
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index b2b41b16beec29..dca228097d782d 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -296,6 +296,29 @@ def Builtin_Float6E3M2FN : Builtin_FloatType<"Float6E3M2FN", "f6E3M2FN"> {
}];
}
+//===----------------------------------------------------------------------===//
+// Float8E8M0FNUType
+
+def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> {
+ let summary = "8-bit floating point with 8-bit exponent, no mantissa or sign";
+ let description = [{
+ An 8-bit floating point type with no sign bit, 8 bits exponent and no
+ mantissa. This is not a standard type as defined by IEEE-754; it is intended
+ to be used for representing scaling factors, so it cannot represent zeros
+ and negative numbers. The values it can represent are powers of two with
+ exponents in the range [-127, 127], and NaN.
+
+ * bit encoding: S0E8M0
+ * exponent bias: 127
+ * infinities: Not supported
+ * NaNs: Supported with all bits set to 1
+ * denormals: Not supported
+
+ Open Compute Project (OCP) microscaling formats (MX) specification:
+ https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+ }];
+}
+
//===----------------------------------------------------------------------===//
// BFloat16Type
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 211385245555ad..48e4c24f838652 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -353,6 +353,8 @@ def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
BuildableType<"$_builder.getFloat6E2M3FNType()">;
def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
BuildableType<"$_builder.getFloat6E3M2FNType()">;
+def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
+ BuildableType<"$_builder.getFloat8E8M0FNUType()">;
def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
"complex-type", "::mlir::ComplexType">;
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 1b52b97f29b5f5..acd0f894abbbe6 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -135,6 +135,7 @@ class Type {
bool isFloat8E4M3FNUZ() const;
bool isFloat8E4M3B11FNUZ() const;
bool isFloat8E3M4() const;
+ bool isFloat8E8M0FNU() const;
bool isBF16() const;
bool isF16() const;
bool isTF32() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 2b29177b7dff0f..49da8c3dea5fa5 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -104,6 +104,7 @@ TOK_KEYWORD(f8E3M4)
TOK_KEYWORD(f4E2M1FN)
TOK_KEYWORD(f6E2M3FN)
TOK_KEYWORD(f6E3M2FN)
+TOK_KEYWORD(f8E8M0FNU)
TOK_KEYWORD(f128)
TOK_KEYWORD(false)
TOK_KEYWORD(floordiv)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 60903a86ff8ce1..c614eb39b364be 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -49,6 +49,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_f8E4M3FNUZ:
case Token::kw_f8E4M3B11FNUZ:
case Token::kw_f8E3M4:
+ case Token::kw_f8E8M0FNU:
case Token::kw_bf16:
case Token::kw_f16:
case Token::kw_tf32:
@@ -336,6 +337,9 @@ Type Parser::parseNonFunctionType() {
case Token::kw_f8E3M4:
consumeToken(Token::kw_f8E3M4);
return builder.getFloat8E3M4Type();
+ case Token::kw_f8E8M0FNU:
+ consumeToken(Token::kw_f8E8M0FNU);
+ return builder.getFloat8E8M0FNUType();
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
return builder.getBF16Type();
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 5a369b5d4938cb..6f192bc4bffeef 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -331,6 +331,27 @@ class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
}
};
+/// Floating Point Type subclass - Float8E8M0FNUType.
+class PyFloat8E8M0FNUType
+ : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E8M0FNUTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E8M0FNUType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
+ return PyFloat8E8M0FNUType(context->getRef(), t);
+ },
+ py::arg("context") = py::none(), "Create a float8_e8m0fnu type.");
+ }
+};
+
/// Floating Point Type subclass - BF16Type.
class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
public:
@@ -953,6 +974,7 @@ void mlir::python::populateIRTypes(py::module &m) {
PyFloat8E4M3B11FNUZType::bind(m);
PyFloat8E5M2FNUZType::bind(m);
PyFloat8E3M4Type::bind(m);
+ PyFloat8E8M0FNUType::bind(m);
PyBF16Type::bind(m);
PyF16Type::bind(m);
PyTF32Type::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index efc1e857a39c7a..252ff54afe0c5d 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -205,6 +205,18 @@ MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E3M4(unwrap(ctx)));
}
+MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() {
+ return wrap(Float8E8M0FNUType::getTypeID());
+}
+
+bool mlirTypeIsAFloat8E8M0FNU(MlirType type) {
+ return unwrap(type).isFloat8E8M0FNU();
+}
+
+MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) {
+ return wrap(FloatType::getFloat8E8M0FNU(unwrap(ctx)));
+}
+
MlirTypeID mlirBFloat16TypeGetTypeID() {
return wrap(BFloat16Type::getTypeID());
}
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index fd6369b5bb4ee5..5a92fa839e9847 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -250,7 +250,8 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
- type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN())
+ type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
+ type.isFloat8E8M0FNU())
return IntegerType::get(&getContext(), type.getWidth());
return type;
}
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index c0aa16cc0da407..67dcce454f028b 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -370,6 +370,7 @@ std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
.Case("f8E3M4", b.getFloat8E3M4Type())
+ .Case("f8E8M0FNU", b.getFloat8E8M0FNUType())
.Case("bf16", b.getBF16Type())
.Case("f16", b.getF16Type())
.Case("f32", b.getF32Type())
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 7f95f5ace8c00f..96fb66d53fb835 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2588,6 +2588,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
.Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
.Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
.Case<Float8E3M4Type>([&](Type) { os << "f8E3M4"; })
+ .Case<Float8E8M0FNUType>([&](Type) { os << "f8E8M0FNU"; })
.Case<BFloat16Type>([&](Type) { os << "bf16"; })
.Case<Float16Type>([&](Type) { os << "f16"; })
.Case<FloatTF32Type>([&](Type) { os << "tf32"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7aed415343e551..a9bc3c0ef65a23 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -74,6 +74,10 @@ FloatType Builder::getFloat8E3M4Type() {
return FloatType::getFloat8E3M4(context);
}
+FloatType Builder::getFloat8E8M0FNUType() {
+ return FloatType::getFloat8E8M0FNU(context);
+}
+
FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
FloatType Builder::getF16Type() { return FloatType::getF16(context); }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 782a32b3074680..25e9f80c9963cb 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -121,6 +121,8 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
return APFloat::Float8E4M3B11FNUZ();
if (llvm::isa<Float8E3M4Type>(*this))
return APFloat::Float8E3M4();
+ if (llvm::isa<Float8E8M0FNUType>(*this))
+ return APFloat::Float8E8M0FNU();
if (llvm::isa<BFloat16Type>(*this))
return APFloat::BFloat();
if (llvm::isa<Float16Type>(*this))
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f45de17dd24910..f05666fcde207b 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -231,6 +231,7 @@ class MLIRContextImpl {
Float8E4M3FNUZType f8E4M3FNUZTy;
Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
Float8E3M4Type f8E3M4Ty;
+ Float8E8M0FNUType f8E8M0FNUTy;
BFloat16Type bf16Ty;
Float16Type f16Ty;
FloatTF32Type tf32Ty;
@@ -326,6 +327,7 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting)
impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
impl->f8E3M4Ty = TypeUniquer::get<Float8E3M4Type>(this);
+ impl->f8E8M0FNUTy = TypeUniquer::get<Float8E8M0FNUType>(this);
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
@@ -1049,6 +1051,9 @@ Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
return context->getImpl().f8E3M4Ty;
}
+Float8E8M0FNUType Float8E8M0FNUType::get(MLIRContext *context) {
+ return context->getImpl().f8E8M0FNUTy;
+}
BFloat16Type BFloat16Type::get(MLIRContext *context) {
return context->getImpl().bf16Ty;
}
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index efefbc299a91f3..e190902b2e4898 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -49,6 +49,9 @@ bool Type::isFloat8E4M3FNUZ() const {
bool Type::isFloat8E4M3B11FNUZ() const {
return llvm::isa<Float8E4M3B11FNUZType>(*this);
}
+bool Type::isFloat8E8M0FNU() const {
+ return llvm::isa<Float8E8M0FNUType>(*this);
+}
bool Type::isFloat8E3M4() const { return llvm::isa<Float8E3M4Type>(*this); }
bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 41ed84e0467254..fb7efb8cd28a5e 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -117,6 +117,7 @@ __all__ = [
"Float8E4M3Type",
"Float8E5M2FNUZType",
"Float8E5M2Type",
+ "Float8E8M0FNUType",
"FloatAttr",
"FloatTF32Type",
"FloatType",
@@ -1660,6 +1661,19 @@ class Float8E5M2Type(FloatType):
@property
def typeid(self) -> TypeID: ...
+class Float8E8M0FNUType(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> Float8E8M0FNUType:
+ """
+ Create a float8_e8m0fnu type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
class FloatAttr(Attribute):
static_typeid: ClassVar[TypeID]
@staticmethod
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index 5b24a6d526f2f8..34eee1edb57ff5 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -20,6 +20,7 @@
Float8E4M3FNType,
Float8E4M3Type,
Float8E5M2Type,
+ Float8E8M0FNUType,
FunctionType,
IndexType,
IntegerType,
@@ -80,6 +81,7 @@ def ui(width):
f4E2M1FN = lambda: Float4E2M1FNType.get()
f6E2M3FN = lambda: Float6E2M3FNType.get()
f6E3M2FN = lambda: Float6E3M2FNType.get()
+f8E8M0FNU = lambda: Float8E8M0FNUType.get()
none = lambda: NoneType.get()
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 31a4663f72e6e9..a62de3f5004d73 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -76,6 +76,10 @@ func.func @float_attrs_pass() {
// CHECK: float_attr = 2.000000e+00 : f8E3M4
float_attr = 2. : f8E3M4
} : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 2.000000e+00 : f8E8M0FNU
+ float_attr = 2. : f8E8M0FNU
+ } : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f16
float_attr = 2. : f16
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 327c9f05f4c72c..c884f83cb4d32d 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -72,6 +72,9 @@ llvm.mlir.global internal @f8E5M2FNUZ_global_as_i8(1.5 : f8E5M2FNUZ) : i8
// CHECK: @f8E4M3B11FNUZ_global_as_i8 = internal global i8 92
llvm.mlir.global internal @f8E4M3B11FNUZ_global_as_i8(1.5 : f8E4M3B11FNUZ) : i8
+// CHECK: @f8E8M0FNU_global_as_i8 = internal global i8 127
+llvm.mlir.global internal @f8E8M0FNU_global_as_i8(1.0 : f8E8M0FNU) : i8
+
// CHECK: @bf16_global_as_i16 = internal global i16 16320
llvm.mlir.global internal @bf16_global_as_i16(1.5 : bf16) : i16
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 6154a6ff9e9aed..48ddc8359ca0a1 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -133,6 +133,8 @@ def testFloatTypeSubclasses():
# CHECK: True
print(isinstance(Type.parse("f8E5M2FNUZ", ctx), FloatType))
# CHECK: True
+ print(isinstance(Type.parse("f8E8M0FNU", ctx), FloatType))
+ # CHECK: True
print(isinstance(Type.parse("f16", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("bf16", ctx), FloatType))
@@ -259,6 +261,8 @@ def testFloatType():
print("float:", Float8E4M3FNUZType.get())
# CHECK: float: f8E4M3B11FNUZ
print("float:", Float8E4M3B11FNUZType.get())
+ # CHECK: float: f8E8M0FNU
+ print("float:", Float8E8M0FNUType.get())
# CHECK: float: bf16
print("float:", BF16Type.get())
# CHECK: float: f16
@@ -631,6 +635,7 @@ def testTypeIDs():
(Float8E4M3FNUZType, Float8E4M3FNUZType.get()),
(Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()),
(Float8E5M2FNUZType, Float8E5M2FNUZType.get()),
+ (Float8E8M0FNUType, Float8E8M0FNUType.get()),
(BF16Type, BF16Type.get()),
(F16Type, F16Type.get()),
(F32Type, F32Type.get()),
@@ -659,6 +664,7 @@ def testTypeIDs():
# CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
# CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
# CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
+ # CHECK: Float8E8M0FNUType(f8E8M0FNU)
# CHECK: BF16Type(bf16)
# CHECK: F16Type(f16)
# CHECK: F32Type(f32)
@@ -761,6 +767,9 @@ def print_downcasted(typ):
# CHECK: Float8E5M2FNUZType
# CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
print_downcasted(Float8E5M2FNUZType.get())
+ # CHECK: Float8E8M0FNUType
+ # CHECK: Float8E8M0FNUType(f8E8M0FNU)
+ print_downcasted(Float8E8M0FNUType.get())
# CHECK: BF16Type
# CHECK: BF16Type(bf16)
print_downcasted(BF16Type.get())
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index 54d3d703640403..38e8278eefbbd3 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -60,6 +60,7 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
"mlir::Float8E4M3FNUZType": '"f8E4M3FNUZ"',
"mlir::Float8E4M3B11FNUZType": '"f8E4M3B11FNUZ"',
"mlir::Float8E3M4Type": '"f8E3M4"',
+ "mlir::Float8E8M0FNUType": '"f8E8M0FNU"',
"mlir::BFloat16Type": '"bf16"',
"mlir::Float16Type": '"f16"',
"mlir::FloatTF32Type": '"tf32"',
diff --git a/mlir/utils/tree-sitter-mlir/grammar.js b/mlir/utils/tree-sitter-mlir/grammar.js
index f7d916dfb57e2f..2dadd46c4760ca 1006...
[truncated]
|
Looks good to me. Follows the same template from the earlier PRs. In the commit message: Can we also add "no denorms" too? |
Added, thanks. |
Fly on the wall, but at what point, upon removing all features that traditionally make something a "floating point number" (mantissa, zero, denorms, infinities) does something no longer make any sense at all being part of a floating point hierarchy. It's just a bit-vector with a special error value. I'm not blocking this in any way or even asking seriously. Just kind of balking at the cargo cult mentality that is going into bundling these things together like this. |
I had the same gut reaction... We are now up to ~18 floating point types... That kind of points to a serious issue with the way things are scaling here, and I think we should really rethink what's being done (especially given that each one of these PRs are identical, add big chunks of code in the core library, etc). |
I also wondered what makes it a floating point number. For E8M0, mantissa is there but is implicit (has 1 bit which has value of one) - other FP types also have an implicit bit of data. The E8M0 is intended to be used as a scaling factor in block scaled formats like MXFP8, which is exactly why it doesn't have negatives, infinities or zeros - none of these makes sense for a scaling factor. |
Ok, I didn't want to unilaterally hit pause, but it seems like we've got a number of people with the same analysis. Should we at least discuss this a bit more? Or proceed with this patch? I agree that the path we're on is not very sustainable. |
The amount of boilerplate code is annoying, I believe this could (and should) be generalized. |
Yeah, it was never meant to scale beyond the primary fp8 types. Needs a rethink... if not in this case, certainly soon. |
For example, https://arxiv.org/html/2405.13938v1 mentions more esoterics like |
Just judging by the mood and temperament of the industry, we'll end up with just about every combination before too long. Might as well try to structure the code for that eventuality vs being the victim of it. |
Yeah, we’re definitely dealing with a combinatorial explosion. For the compiler internals, have we thought about using a single FP type that models different E/M sizes and supports the required feature flags for the specific semantics of these formats? I wonder, though, if we need verification at some level (bindings?) to ensure compliance with the micro-scaling standard... Or do we want to leave the door open to arbitrary formats? I'm not sure I fully understand the implications of the latter.
I would say they are still floating-point in spirit :). I’m not sure we could treat them differently. There are specific FP semantics that still apply to these formats, such as NaN propagation (e.g., minnum/maxnum/minimum/maximum), exceptions, overflow... |
Is there any hardware that supports arithmetic over these types? At least on the CDNA side, fp8 support is limited to conversion to/from other 'usual' fp types and dedicated matmul intrinsics. If we had IR that did any arithmetic on fp types, we wouldn't be able to compile it without software emulation anyway. Or in other words, I wonder if some level of general hw support (not just conversions + matmul) would be a good criterion for inclusion in core MLIR types. |
I can't defend my opinion rigorously, but I think that many of these "software only" types are leaking in from a well intentioned reading of the MX spec and an attempt to mirror that in MLIR. But the MX spec has very little to say about how the machines implement such things -- and the early state of implementations means that few of us fine purveyors of such machines are saying a lot of concrete details about the underlying implementations yet. My observation is that this has made it hard to plan the right level of support in MLIR, as the norm in this ecosystem is to work bottom up. But that is hard to do when the bottom is not defined in a way that everyone can see. And that is probably causing (again well-intentioned) premature generalization. As Jakub says, many of these things are going to get defined in terms of conversion intrinsics (or emulation/helper functions) and a very small number of ways that they can actually interact with the hardware. I'd have to go dig through the original discussion on FP8, but a key point of that was that those types had a concrete realization in hardware and that the compiler was better off fully modeling such a thing (vs type punning, etc). I'm not sure that the same line of reasoning necessarily holds for all of these software/emulation variants. But like I said, I'm not sure I can defend that position rigorously. In the absence of that, I think my core concern is that we may want a way in the code to better model software/emulated FP types (i.e. so that it doesn't require so much code, duplication, etc). |
Everyone has some degree of intrinsic support for the MX types, just like most things. But the nature of that support is pretty different from typical FP types. In any case, I'm not against adding these types or even landing this PR. But I think the approach needs adjustment before this gets further out of hand. |
This PR adds `f8E8M0FNU` type to MLIR. `f8E8M0FNU` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). It defines an 8-bit floating point number with bit layout S0E8M0. Unlike IEEE-754 types, there are no infinity, denormals, zeros or negative values. ```c f8E8M0FNU - Exponent bias: 127 - Maximum stored exponent value: 254 (binary 1111'1110) - Maximum unbiased exponent value: 254 - 127 = 127 - Minimum stored exponent value: 0 (binary 0000'0000) - Minimum unbiased exponent value: 0 - 127 = -127 - Doesn't have zero - Doesn't have infinity - NaN is encoded as binary 1111'1111 Additional details: - Zeros cannot be represented - Negative values cannot be represented - Mantissa is always 1 ``` Related PRs: - [PR-107127](llvm#107127) [APFloat] Add APFloat support for E8M0 type - [PR-105573](llvm#105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR - [PR-107999](llvm#107999) [MLIR] Add f6E2M3FN type - [PR-108877](llvm#108877) [MLIR] Add f4E2M1FN type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: this doesn't seem correctly roundtrip through bytecode. If checking IR/attribute.mlir
"test.float_attrs"() <{float_attr = 2.000000e+00 : f8E8M0FNU}> : () -> () loc(#loc12)
becomes
"test.float_attrs"() <{float_attr = 4.000000e+00 : f8E8M0FNU}> : () -> () loc(#loc12)
Can test with MLIR_OPT_CHECK_IR_ROUNDTRIP=1
in testing environment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can confirm the issue, looking into this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix in #113298
Fixes a bug in APFloat handling of E8M0 type (zero mantissa). Related PRs: - llvm#107127 - llvm#111028
…U) (#2581) This is a proposal to add MX (microscaling) floating point types to StableHLO. Related links: - StableHLO [PR#2582](#2582) Add MX floating point types (f4E2M1FN, f6E2M3FN, f6E3M2FN, f8E8M0FNU) - LLVM [PR#95392](llvm/llvm-project#95392) [APFloat] Add APFloat support for FP4 data type - LLVM [PR#94735](llvm/llvm-project#94735) [APFloat] Add APFloat support for FP6 data types - LLVM [PR#107127](llvm/llvm-project#107127) [APFloat] Add APFloat support for E8M0 type - LLVM [PR#108877](llvm/llvm-project#108877) [MLIR] Add f4E2M1FN type - LLVM [PR#107999](llvm/llvm-project#107999) [MLIR] Add f6E2M3FN type - LLVM [PR#105573](llvm/llvm-project#105573) [MLIR] Add f6E3M2FN type - LLVM [PR#111028](llvm/llvm-project#111028) [MLIR] Add f8E8M0FNU type - JAX-ML [PR#181](jax-ml/ml_dtypes#181) Add sub-byte data types: float4_e2m1fn, float6_e2m3fn, float6_e3m2fn - JAX-ML [PR#166](jax-ml/ml_dtypes#181) Add float8_e8m0_fnu (E8M0) OCP MX scale format
Fixes a bug in APFloat handling of E8M0 type (zero mantissa). Related PRs: - llvm#107127 - llvm#111028
This PR adds
f8E8M0FNU
type to MLIR. f8E8M0FNU
type is proposed in the OpenCompute MX Specification. It defines an 8-bit floating point number with bit layout S0E8M0. Unlike IEEE-754 types, there are no infinity, denormals, zeros or negative values.

Related PRs: