[MLIR] Add f4E2M1FN type #108877

Merged 1 commit into llvm:main on Sep 24, 2024

sergey-kozub
Contributor

This PR adds the f4E2M1FN type to MLIR.

The f4E2M1FN type is proposed in the Open Compute Project (OCP) Microscaling Formats (MX) specification. It defines a 4-bit floating-point number with bit layout S1E2M1. Unlike IEEE-754 types, it has no infinity or NaN values.

f4E2M1FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 - 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.0
- Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0
- Min normal number: S.01.0 = ±2^(0) = ±1.0
- Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5
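
The encoding rules listed above can be checked with a small decoder. The following is a minimal Python sketch and is not part of the patch: the function name `decode_f4e2m1fn` is hypothetical, and the logic assumes only the S1E2M1 layout, the exponent bias of 1, and the subnormal handling described in this PR.

```python
def decode_f4e2m1fn(bits: int) -> float:
    """Decode a 4-bit S1E2M1 pattern (0..15) into a Python float."""
    if not 0 <= bits <= 0xF:
        raise ValueError("expected a 4-bit value")
    sign = -1.0 if bits & 0b1000 else 1.0
    exponent = (bits >> 1) & 0b11  # stored (biased) exponent
    mantissa = bits & 0b1          # single mantissa bit, weight 0.5
    bias = 1
    if exponent == 0:
        # Zero or subnormal: no implicit leading 1, scale is 2^(1 - bias) = 2^0.
        return sign * (mantissa * 0.5)
    # Normal number: implicit leading 1, unbiased exponent is exponent - bias.
    return sign * 2.0 ** (exponent - bias) * (1.0 + mantissa * 0.5)


# All non-negative encodings, from S.00.0 up to S.11.1.
print([decode_f4e2m1fn(b) for b in range(8)])
```

Enumerating `range(8)` gives the non-negative values 0, 0.5, 1, 1.5, 2, 3, 4 and 6; setting the sign bit mirrors them, including a negative zero.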

Related PRs:

  • PR-95392 [APFloat] Add APFloat support for FP4 data type
  • PR-105573 [MLIR] Add f6E3M2FN type - was used as a template for this PR
  • PR-107999 [MLIR] Add f6E2M3FN type

@llvmbot
Collaborator

llvmbot commented Sep 17, 2024

@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-ods

Author: Sergey Kozub (sergey-kozub)

Patch is 20.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108877.diff

24 Files Affected:

  • (modified) mlir/include/mlir-c/BuiltinTypes.h (+10)
  • (modified) mlir/include/mlir/IR/Builders.h (+1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+10-5)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+21)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+2)
  • (modified) mlir/include/mlir/IR/Types.h (+1)
  • (modified) mlir/lib/AsmParser/TokenKinds.def (+1)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+4)
  • (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+22)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+12)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+1)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+1)
  • (modified) mlir/lib/IR/Builders.cpp (+4)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+2)
  • (modified) mlir/lib/IR/MLIRContext.cpp (+5)
  • (modified) mlir/lib/IR/Types.cpp (+1)
  • (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+14)
  • (modified) mlir/python/mlir/extras/types.py (+2)
  • (modified) mlir/test/IR/attribute.mlir (+4)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+3)
  • (modified) mlir/test/python/ir/builtin_types.py (+9)
  • (modified) mlir/utils/lldb-scripts/mlirDataFormatters.py (+1)
  • (modified) mlir/utils/tree-sitter-mlir/grammar.js (+1-1)
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index cc6da482a1c369..6dc25a56b8e614 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -79,6 +79,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type);
 /// Returns the bitwidth of a floating-point type.
 MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type);
 
+/// Returns the typeID of an Float4E2M1FN type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat4E2M1FNTypeGetTypeID(void);
+
+/// Checks whether the given type is an f4E2M1FN type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat4E2M1FN(MlirType type);
+
+/// Creates an f4E2M1FN type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx);
+
 /// Returns the typeID of an Float6E2M3FN type.
 MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void);
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 196d34e12d9b28..ee5d7879625309 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -60,6 +60,7 @@ class Builder {
                        Attribute metadata = Attribute());
 
   // Types.
+  FloatType getFloat4E2M1FNType();
   FloatType getFloat6E2M3FNType();
   FloatType getFloat6E3M2FNType();
   FloatType getFloat8E5M2Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f2231e9507570e..91e68b4066dd67 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -67,6 +67,7 @@ class FloatType : public Type {
   static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E3M4(MLIRContext *ctx);
+  static FloatType getFloat4E2M1FN(MLIRContext *ctx);
   static FloatType getFloat6E2M3FN(MLIRContext *ctx);
   static FloatType getFloat6E3M2FN(MLIRContext *ctx);
 
@@ -415,11 +416,15 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 }
 
 inline bool FloatType::classof(Type type) {
-  return llvm::isa<Float6E2M3FNType, Float6E3M2FNType, Float8E5M2Type,
-                   Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
-                   Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
-                   BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
-                   Float64Type, Float80Type, Float128Type>(type);
+  return llvm::isa<
+      Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType, Float8E5M2Type,
+      Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType, Float8E4M3FNUZType,
+      Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type, Float16Type,
+      FloatTF32Type, Float32Type, Float64Type, Float80Type, Float128Type>(type);
+}
+
+inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {
+  return Float4E2M1FNType::get(ctx);
 }
 
 inline FloatType FloatType::getFloat6E2M3FN(MLIRContext *ctx) {
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index c283c20f36e91e..c738a8a3becc16 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -233,6 +233,27 @@ def Builtin_Float8E3M4 : Builtin_FloatType<"Float8E3M4", "f8E3M4"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float4E2M1FNType
+
+def Builtin_Float4E2M1FN : Builtin_FloatType<"Float4E2M1FN", "f4E2M1FN"> {
+  let summary = "4-bit floating point with 2-bit exponent and 1-bit mantissa";
+  let description = [{
+    An 4-bit floating point type with 1 sign bit, 2 bits exponent and 1 bit
+    mantissa. This is not a standard type as defined by IEEE-754, but it
+    follows similar conventions with the following characteristics:
+
+      * bit encoding: S1E2M1
+      * exponent bias: 1
+      * infinities: Not supported
+      * NaNs: Not supported
+      * denormals when exponent is 0
+
+    Open Compute Project (OCP) microscaling formats (MX) specification:
+    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Float6E2M3FNType
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index c852d2cfa730d9..211385245555ad 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -347,6 +347,8 @@ def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
                  BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
 def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
              BuildableType<"$_builder.getFloat8E3M4Type()">;
+def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
+               BuildableType<"$_builder.getFloat4E2M1FNType()">;
 def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
                BuildableType<"$_builder.getFloat6E2M3FNType()">;
 def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 8b6f365fbda02e..1b52b97f29b5f5 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -125,6 +125,7 @@ class Type {
   // Convenience predicates.  This is only for floating point types,
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
+  bool isFloat4E2M1FN() const;
   bool isFloat6E2M3FN() const;
   bool isFloat6E3M2FN() const;
   bool isFloat8E5M2() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 6ae64a17d1fadb..2b29177b7dff0f 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -101,6 +101,7 @@ TOK_KEYWORD(f8E5M2FNUZ)
 TOK_KEYWORD(f8E4M3FNUZ)
 TOK_KEYWORD(f8E4M3B11FNUZ)
 TOK_KEYWORD(f8E3M4)
+TOK_KEYWORD(f4E2M1FN)
 TOK_KEYWORD(f6E2M3FN)
 TOK_KEYWORD(f6E3M2FN)
 TOK_KEYWORD(f128)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index a3798ca8d90b1b..60903a86ff8ce1 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -39,6 +39,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_tuple:
   case Token::kw_vector:
   case Token::inttype:
+  case Token::kw_f4E2M1FN:
   case Token::kw_f6E2M3FN:
   case Token::kw_f6E3M2FN:
   case Token::kw_f8E5M2:
@@ -305,6 +306,9 @@ Type Parser::parseNonFunctionType() {
   }
 
   // float-type
+  case Token::kw_f4E2M1FN:
+    consumeToken(Token::kw_f4E2M1FN);
+    return builder.getFloat4E2M1FNType();
   case Token::kw_f6E2M3FN:
     consumeToken(Token::kw_f6E2M3FN);
     return builder.getFloat6E2M3FNType();
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 6b64bc3c9d6f63..5a369b5d4938cb 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -124,6 +124,27 @@ class PyFloatType : public PyConcreteType<PyFloatType> {
   }
 };
 
+/// Floating Point Type subclass - Float4E2M1FNType.
+class PyFloat4E2M1FNType
+    : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat4E2M1FNTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float4E2M1FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
+          return PyFloat4E2M1FNType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float4_e2m1fn type.");
+  }
+};
+
 /// Floating Point Type subclass - Float6E2M3FNType.
 class PyFloat6E2M3FNType
     : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
@@ -922,6 +943,7 @@ void mlir::python::populateIRTypes(py::module &m) {
   PyIntegerType::bind(m);
   PyFloatType::bind(m);
   PyIndexType::bind(m);
+  PyFloat4E2M1FNType::bind(m);
   PyFloat6E2M3FNType::bind(m);
   PyFloat6E3M2FNType::bind(m);
   PyFloat8E4M3FNType::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index f943bf726b172c..efc1e857a39c7a 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -85,6 +85,18 @@ unsigned mlirFloatTypeGetWidth(MlirType type) {
   return llvm::cast<FloatType>(unwrap(type)).getWidth();
 }
 
+MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() {
+  return wrap(Float4E2M1FNType::getTypeID());
+}
+
+bool mlirTypeIsAFloat4E2M1FN(MlirType type) {
+  return unwrap(type).isFloat4E2M1FN();
+}
+
+MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat4E2M1FN(unwrap(ctx)));
+}
+
 MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
   return wrap(Float6E2M3FNType::getTypeID());
 }
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 51a1b91338c6a0..fd6369b5bb4ee5 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -250,7 +250,7 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
   if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
       type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
       type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
-      type.isFloat6E2M3FN() || type.isFloat6E3M2FN())
+      type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN())
     return IntegerType::get(&getContext(), type.getWidth());
   return type;
 }
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 5e5e10b1fa1c2b..0bf8c8942885e6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -55,6 +55,7 @@ static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
                                                StringRef name) {
   Builder b(ctx);
   return llvm::StringSwitch<std::optional<FloatType>>(name)
+      .Case("f4E2M1FN", b.getFloat4E2M1FNType())
       .Case("f6E2M3FN", b.getFloat6E2M3FNType())
       .Case("f6E3M2FN", b.getFloat6E3M2FNType())
       .Case("f8E5M2", b.getFloat8E5M2Type())
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index c7ed158aabb6e7..4cdc2f64fbd8b5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2575,6 +2575,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
                            opaqueTy.getTypeData());
       })
       .Case<IndexType>([&](Type) { os << "index"; })
+      .Case<Float4E2M1FNType>([&](Type) { os << "f4E2M1FN"; })
       .Case<Float6E2M3FNType>([&](Type) { os << "f6E2M3FN"; })
       .Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; })
       .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 144a13df2179b7..7aed415343e551 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -34,6 +34,10 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
 // Types.
 //===----------------------------------------------------------------------===//
 
+FloatType Builder::getFloat4E2M1FNType() {
+  return FloatType::getFloat4E2M1FN(context);
+}
+
 FloatType Builder::getFloat6E2M3FNType() {
   return FloatType::getFloat6E2M3FN(context);
 }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 702d98ec31427b..782a32b3074680 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -101,6 +101,8 @@ unsigned FloatType::getWidth() {
 
 /// Returns the floating semantics for the given type.
 const llvm::fltSemantics &FloatType::getFloatSemantics() {
+  if (llvm::isa<Float4E2M1FNType>(*this))
+    return APFloat::Float4E2M1FN();
   if (llvm::isa<Float6E2M3FNType>(*this))
     return APFloat::Float6E2M3FN();
   if (llvm::isa<Float6E3M2FNType>(*this))
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 1684566626886c..f45de17dd24910 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -221,6 +221,7 @@ class MLIRContextImpl {
   llvm::DenseMap<StringRef, AbstractType *> nameToType;
 
   /// Cached Type Instances.
+  Float4E2M1FNType f4E2M1FNTy;
   Float6E2M3FNType f6E2M3FNTy;
   Float6E3M2FNType f6E3M2FNTy;
   Float8E5M2Type f8E5M2Ty;
@@ -315,6 +316,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
 
   //// Types.
   /// Floating-point Types.
+  impl->f4E2M1FNTy = TypeUniquer::get<Float4E2M1FNType>(this);
   impl->f6E2M3FNTy = TypeUniquer::get<Float6E2M3FNType>(this);
   impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
   impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
@@ -1017,6 +1019,9 @@ AbstractType::lookup(StringRef name, MLIRContext *context) {
 /// This should not be used directly.
 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
 
+Float4E2M1FNType Float4E2M1FNType::get(MLIRContext *context) {
+  return context->getImpl().f4E2M1FNTy;
+}
 Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) {
   return context->getImpl().f6E2M3FNTy;
 }
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index c828fd3766eaa7..efefbc299a91f3 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -34,6 +34,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,
 
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
+bool Type::isFloat4E2M1FN() const { return llvm::isa<Float4E2M1FNType>(*this); }
 bool Type::isFloat6E2M3FN() const { return llvm::isa<Float6E2M3FNType>(*this); }
 bool Type::isFloat6E3M2FN() const { return llvm::isa<Float6E3M2FNType>(*this); }
 bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index ea5c96dcbc6c11..4d5b4cef9d8aa8 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -120,6 +120,7 @@ __all__ = [
     "F32Type",
     "F64Type",
     "FlatSymbolRefAttr",
+    "Float4E2M1FNType",
     "Float6E2M3FNType",
     "Float6E3M2FNType",
     "Float8E3M4Type",
@@ -1542,6 +1543,19 @@ class FlatSymbolRefAttr(Attribute):
         Returns the value of the FlatSymbolRef attribute as a string
         """
 
+class Float4E2M1FNType(FloatType):
+    static_typeid: ClassVar[TypeID]
+    @staticmethod
+    def get(context: Optional[Context] = None) -> Float4E2M1FNType:
+        """
+        Create a float4_e2m1fn type.
+        """
+    @staticmethod
+    def isinstance(other: Type) -> bool: ...
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @property
+    def typeid(self) -> TypeID: ...
+
 class Float6E2M3FNType(FloatType):
     static_typeid: ClassVar[TypeID]
     @staticmethod
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index 4be425f220c978..5b24a6d526f2f8 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -12,6 +12,7 @@
     F16Type,
     F32Type,
     F64Type,
+    Float4E2M1FNType,
     Float6E2M3FNType,
     Float6E3M2FNType,
     Float8E3M4Type,
@@ -76,6 +77,7 @@ def ui(width):
 f8E4M3FN = lambda: Float8E4M3FNType.get()
 f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
 f8E3M4 = lambda: Float8E3M4Type.get()
+f4E2M1FN = lambda: Float4E2M1FNType.get()
 f6E2M3FN = lambda: Float6E2M3FNType.get()
 f6E3M2FN = lambda: Float6E3M2FNType.get()
 
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 23dbf0c292c2c3..31a4663f72e6e9 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -36,6 +36,10 @@ func.func @any_attr_of_fail() {
 //===----------------------------------------------------------------------===//
 
 func.func @float_attrs_pass() {
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f4E2M1FN
+    float_attr = 2. : f4E2M1FN
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f6E2M3FN
     float_attr = 2. : f6E2M3FN
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 7eca1a40373054..ea91efa469088a 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -42,6 +42,9 @@ llvm.mlir.global internal @int_global_undef() : i64
 // CHECK: @externally_initialized_global = internal externally_initialized global i32 0
 llvm.mlir.global internal @externally_initialized_global(0 : i32) {externally_initialized} : i32
 
+// CHECK: @f4E2M1FN_global_as_i4 = internal global i4 3
+llvm.mlir.global internal @f4E2M1FN_global_as_i4(1.5 : f4E2M1FN) : i4
+
 // CHECK: @f6E2M3FN_global_as_i6 = internal global i6 12
 llvm.mlir.global internal @f6E2M3FN_global_as_i6(1.5 : f6E2M3FN) : i6
 
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index bc3ba4cd0b1448..6154a6ff9e9aed 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -113,6 +113,8 @@ def testTypeIsInstance():
 def testFloatTypeSubclasses():
     ctx = Context()
     # CHECK: True
+    print(isinstance(Type.parse("f4E2M1FN", ctx), FloatType))
+    # CHECK: True
     print(isinstance(Type.parse("f6E2M3FN", ctx), FloatType))
     # CHECK: True
     print(isinstance(Type.parse("f6E3M2FN", ctx), FloatType))
@@ -237,6 +239,8 @@ def testIndexType():
 @run
 def testFloatType():
     with Context():
+        # CHECK: float: f4E2M1FN
+        print("float:", Float4E2M1FNType.get())
         # CHECK: float: f6E2M3FN
         print("float:", Float6E2M3FNType.get())
         # CHECK: float: f6E3M2FN
@@ -617,6 +621,7 @@ def testTypeIDs():
         types = [
             (IntegerType, IntegerType.get_signless(16)),
             (IndexType, IndexType.get()),
+            (Float4E2M1FNType, Float4E2M1FNType.get()),
             (Float6E2M3FNType, Float6E2M3FNType.get()),
             (Float6E3M2FNType, Float6E3M2FNType.get()),
             (Float8E3M4Type, Float8E3M4Type.get()),
@@ -644,6 +649,7 @@ def testTypeIDs():
 
         # CHECK: IntegerType(i16)
         # CHECK: IndexType(index)
+        # CHECK: Float4E2M1FNType(f4E2M1FN)
         # CHECK: Float6E2M3FNType(f6E2M3FN)
         # CHECK: Float6E3M2FNType(f6E3M2FN)
         # CHECK: Float8E3M4Type(f8E3M4)
@@ -725,6 +731,9 @@ def print_downcasted(typ):
         # CHECK: F64Type
         # CHECK: F64Type(f64)
         print_downcasted(F64Type.get())
+        # CHECK: Float4E2M1FNType
+        # CHECK: Float4E2M1FNType(f4E2M1FN)
+        print_downcasted(Float4E2M1FNType.get())
         # CHECK: Float6E2M3FNType
         # CHECK: Float6E2M3FNType(f6E2M3FN)
         print_downcasted(Float6E2M3FNType.get())
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index 350a0f7abea5a4..54d3d703640403 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -50,6 +50,7 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
     "mlir::CallSiteLoc": '"loc(callsite(...))"',
     "mlir::FusedLoc": '"loc(fused<...>[...])"',
     "mlir::UnknownLoc": '"loc(unknown)"',
+    "mlir::Float4E2M1FNType": '"f4E2M1FN"',
     "mlir::Float6E2M3FNType": '"f6E2M3FN"',
     "mlir::Float6E3M2FNType": '"f6E3M2FN"',
     "mlir::Float8E5M2Type": '"f8E5M2"',
diff --git a/mlir/utils/tree-sitter-mlir/grammar.js b/mlir/utils/tree-sitter-mlir/grammar.js
index 9df1944f6255d9..f7d916dfb57e2f 100644
--- a/mlir/utils/tree-sitter-mlir/grammar.js
+++ b/mlir/utils/tree-sitter-mlir/grammar.js
@@ -231,7 +231,7 @@ const common = {
       token(seq(choice('si', 'ui', 'i'), /[1-9]/, repeat(/[0-9]/))),
   float_type : $ => token(
       choice('f16', 'f32', 'f64', 'f80', 'f128', 'bf16', 'f8E3M4', 'f8E4M3FN',
-             'f8E4M3', 'f8E5M2', 'f6E2M3FN', 'f6E3M2FN')),
+             'f8E4M3', 'f8E5M2', 'f4E2M1FN', 'f6E2M3FN', 'f6E3M2FN')),
   index_type : $ => token('index'),
   none_typ...
[truncated]
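
The `llvmir.mlir` test in the diff expects the constant `1.5 : f4E2M1FN` to lower to `i4 3`, next to the existing `1.5 : f6E2M3FN` case that lowers to `i6 12`. Both numbers can be reproduced by hand. Below is a brute-force Python sketch, not code from the patch: the helper name `encode_small_float` is hypothetical, and it assumes a format with no infinity or NaN encodings and an input that is exactly representable.

```python
import math


def encode_small_float(value: float, exp_bits: int, man_bits: int, bias: int) -> int:
    """Return the sign/exponent/mantissa bit pattern of an exactly representable value.

    Assumption: every bit pattern is a finite number (no Inf/NaN encodings).
    """
    sign = 1 if math.copysign(1.0, value) < 0 else 0
    magnitude = abs(value)
    for exponent in range(1 << exp_bits):
        for mantissa in range(1 << man_bits):
            fraction = mantissa / (1 << man_bits)
            if exponent == 0:
                # Zero or subnormal: no implicit leading 1.
                decoded = 2.0 ** (1 - bias) * fraction
            else:
                decoded = 2.0 ** (exponent - bias) * (1.0 + fraction)
            if decoded == magnitude:
                return (sign << (exp_bits + man_bits)) | (exponent << man_bits) | mantissa
    raise ValueError("value is not exactly representable in this format")


print(encode_small_float(1.5, exp_bits=2, man_bits=1, bias=1))  # f4E2M1FN
print(encode_small_float(1.5, exp_bits=2, man_bits=3, bias=1))  # f6E2M3FN
```

For f4E2M1FN, 1.5 is the pattern 0.01.1, which is the integer 3; for f6E2M3FN it is 0.01.100, which is 12, matching the two CHECK lines.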

Contributor

@durga4github durga4github left a comment

Changes look good to me.

Contributor

@dcaballe dcaballe left a comment

Thanks!

@sergey-kozub sergey-kozub merged commit 2c58063 into llvm:main Sep 24, 2024
17 checks passed
augusto2112 pushed a commit to augusto2112/llvm-project that referenced this pull request Sep 26, 2024
This PR adds `f4E2M1FN` type to mlir.

`f4E2M1FN` type is proposed in [OpenCompute MX
Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 4-bit floating point number with bit layout S1E2M1. Unlike
IEEE-754 types, there are no infinity or NaN values.

```c
f4E2M1FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.0
- Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0
- Min normal number: S.01.0 = ±2^(0) = ±1.0
- Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5
```

Related PRs:
- [PR-95392](llvm#95392) [APFloat]
Add APFloat support for FP4 data type
- [PR-105573](llvm#105573) [MLIR]
Add f6E3M2FN type - was used as a template for this PR
- [PR-107999](llvm#107999) [MLIR]
Add f6E2M3FN type
sergey-kozub added a commit that referenced this pull request Oct 4, 2024
This PR adds `f8E8M0FNU` type to MLIR.

`f8E8M0FNU` type is proposed in [OpenCompute MX
Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines an 8-bit floating point number with bit layout S0E8M0. Unlike
IEEE-754 types, there are no infinity, denormals, zeros or negative
values.

```c
f8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0 − 127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111

Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1
```
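
The `f8E8M0FNU` rules above reduce to a pure power-of-two scale. As a minimal Python sketch under those stated rules (the function name `decode_f8e8m0fnu` is hypothetical and does not come from any of the referenced PRs):

```python
import math


def decode_f8e8m0fnu(bits: int) -> float:
    """Decode an 8-bit E8M0 pattern (0..255): exponent only, no sign, no mantissa."""
    if not 0 <= bits <= 0xFF:
        raise ValueError("expected an 8-bit value")
    if bits == 0xFF:
        # The all-ones pattern is the single NaN encoding.
        return math.nan
    # Mantissa is implicitly 1, so the value is 2^(stored exponent - bias).
    return 2.0 ** (bits - 127)


print(decode_f8e8m0fnu(127))  # stored exponent equal to the bias
```

Every non-NaN pattern is a positive power of two between 2^-127 and 2^127, which is why zeros and negative values cannot be represented.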

Related PRs:
- [PR-107127](#107127)
[APFloat] Add APFloat support for E8M0 type
- [PR-105573](#105573) [MLIR]
Add f6E3M2FN type - was used as a template for this PR
- [PR-107999](#107999) [MLIR]
Add f6E2M3FN type
- [PR-108877](#108877) [MLIR]
Add f4E2M1FN type
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Oct 4, 2024
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Oct 4, 2024
GleasonK pushed a commit to openxla/stablehlo that referenced this pull request Oct 23, 2024
…U) (#2581)

This is a proposal to add MX (microscaling) floating point types to
StableHLO.

Related links:
- StableHLO [PR#2582](#2582)
Add MX floating point types (f4E2M1FN, f6E2M3FN, f6E3M2FN, f8E8M0FNU)
- LLVM [PR#95392](llvm/llvm-project#95392)
[APFloat] Add APFloat support for FP4 data type
- LLVM [PR#94735](llvm/llvm-project#94735)
[APFloat] Add APFloat support for FP6 data types
- LLVM [PR#107127](llvm/llvm-project#107127)
[APFloat] Add APFloat support for E8M0 type
- LLVM [PR#108877](llvm/llvm-project#108877)
[MLIR] Add f4E2M1FN type
- LLVM [PR#107999](llvm/llvm-project#107999)
[MLIR] Add f6E2M3FN type
- LLVM [PR#105573](llvm/llvm-project#105573)
[MLIR] Add f6E3M2FN type
- LLVM [PR#111028](llvm/llvm-project#111028)
[MLIR] Add f8E8M0FNU type
- JAX-ML [PR#181](jax-ml/ml_dtypes#181) Add
sub-byte data types: float4_e2m1fn, float6_e2m3fn, float6_e3m2fn
- JAX-ML [PR#166](jax-ml/ml_dtypes#166) Add
float8_e8m0_fnu (E8M0) OCP MX scale format