Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Add f6E2M3FN type #107999

Merged
merged 1 commit into from
Sep 16, 2024
Merged

[MLIR] Add f6E2M3FN type #107999

merged 1 commit into from
Sep 16, 2024

Conversation

sergey-kozub
Copy link
Contributor

This PR adds f6E2M3FN type to mlir.

f6E2M3FN type is proposed in OpenCompute MX Specification. It defines a 6-bit floating point number with bit layout S1E2M3. Unlike IEEE-754 types, there are no infinity or NaN values.

f6E2M3FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 11 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.000
- Max normal number: S.11.111 = ±2^(2) x (1 + 0.875) = ±7.5
- Min normal number: S.01.000 = ±2^(0) = ±1.0
- Max subnormal number: S.00.111 = ±2^(0) x 0.875 = ±0.875
- Min subnormal number: S.00.001 = ±2^(0) x 0.125 = ±0.125

Related PRs:

  • PR-94735 [APFloat] Add APFloat support for FP6 data types
  • PR-105573 [MLIR] Add f6E3M2FN type - was used as a template for this PR

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 10, 2024

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-ods

Author: Sergey Kozub (sergey-kozub)

Changes

This PR adds f6E2M3FN type to mlir.

f6E2M3FN type is proposed in OpenCompute MX Specification. It defines a 6-bit floating point number with bit layout S1E2M3. Unlike IEEE-754 types, there are no infinity or NaN values.

f6E2M3FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 11 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.000
- Max normal number: S.11.111 = ±2^(2) x (1 + 0.875) = ±7.5
- Min normal number: S.01.000 = ±2^(0) = ±1.0
- Max subnormal number: S.00.111 = ±2^(0) x 0.875 = ±0.875
- Min subnormal number: S.00.001 = ±2^(0) x 0.125 = ±0.125

Related PRs:

  • PR-94735 [APFloat] Add APFloat support for FP6 data types
  • PR-105573 [MLIR] Add f6E3M2FN type - was used as a template for this PR

Full diff: https://github.com/llvm/llvm-project/pull/107999.diff

24 Files Affected:

  • (modified) mlir/include/mlir-c/BuiltinTypes.h (+10)
  • (modified) mlir/include/mlir/IR/Builders.h (+1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+10-5)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+21)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+2)
  • (modified) mlir/include/mlir/IR/Types.h (+1)
  • (modified) mlir/lib/AsmParser/TokenKinds.def (+1)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+4)
  • (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+22)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+12)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+1)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+1)
  • (modified) mlir/lib/IR/Builders.cpp (+4)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+2)
  • (modified) mlir/lib/IR/MLIRContext.cpp (+5)
  • (modified) mlir/lib/IR/Types.cpp (+1)
  • (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+14)
  • (modified) mlir/python/mlir/extras/types.py (+2)
  • (modified) mlir/test/IR/attribute.mlir (+4)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+3)
  • (modified) mlir/test/python/ir/builtin_types.py (+9)
  • (modified) mlir/utils/lldb-scripts/mlirDataFormatters.py (+1)
  • (modified) mlir/utils/tree-sitter-mlir/grammar.js (+1-1)
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 24531baecaa353..cc6da482a1c369 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -79,6 +79,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type);
 /// Returns the bitwidth of a floating-point type.
 MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type);
 
+/// Returns the typeID of an Float6E2M3FN type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void);
+
+/// Checks whether the given type is an f6E2M3FN type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E2M3FN(MlirType type);
+
+/// Creates an f6E2M3FN type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx);
+
 /// Returns the typeID of an Float6E3M2FN type.
 MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void);
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 5ac3a04b1c26ba..196d34e12d9b28 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -60,6 +60,7 @@ class Builder {
                        Attribute metadata = Attribute());
 
   // Types.
+  FloatType getFloat6E2M3FNType();
   FloatType getFloat6E3M2FNType();
   FloatType getFloat8E5M2Type();
   FloatType getFloat8E4M3Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 87ccc041f19758..f2231e9507570e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -67,6 +67,7 @@ class FloatType : public Type {
   static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E3M4(MLIRContext *ctx);
+  static FloatType getFloat6E2M3FN(MLIRContext *ctx);
   static FloatType getFloat6E3M2FN(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
@@ -414,11 +415,15 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 }
 
 inline bool FloatType::classof(Type type) {
-  return llvm::isa<Float6E3M2FNType, Float8E5M2Type, Float8E4M3Type,
-                   Float8E4M3FNType, Float8E5M2FNUZType, Float8E4M3FNUZType,
-                   Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type,
-                   Float16Type, FloatTF32Type, Float32Type, Float64Type,
-                   Float80Type, Float128Type>(type);
+  return llvm::isa<Float6E2M3FNType, Float6E3M2FNType, Float8E5M2Type,
+                   Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                   Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+                   BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
+                   Float64Type, Float80Type, Float128Type>(type);
+}
+
+inline FloatType FloatType::getFloat6E2M3FN(MLIRContext *ctx) {
+  return Float6E2M3FNType::get(ctx);
 }
 
 inline FloatType FloatType::getFloat6E3M2FN(MLIRContext *ctx) {
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index b54d4ee4b7eb7a..09c2d34dc7dd1b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -233,6 +233,27 @@ def Builtin_Float8E3M4 : Builtin_FloatType<"Float8E3M4", "f8E3M4"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float6E2M3FNType
+
+def Builtin_Float6E2M3FN : Builtin_FloatType<"Float6E2M3FN", "f6E2M3FN"> {
+  let summary = "6-bit floating point with 3 bits exponent and 2 bit mantissa";
+  let description = [{
+    An 6-bit floating point type with 1 sign bit, 2 bits exponent and 3 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it
+    follows similar conventions with the following characteristics:
+
+      * bit encoding: S1E2M3
+      * exponent bias: 1
+      * infinities: Not supported
+      * NaNs: Not supported
+      * denormals when exponent is 0
+
+    Open Compute Project (OCP) microscaling formats (MX) specification:
+    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Float6E3M2FNType
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 09eab50f53a540..3cc1c95f1ed37a 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -344,6 +344,8 @@ def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
                  BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
 def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
              BuildableType<"$_builder.getFloat8E3M4Type()">;
+def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
+               BuildableType<"$_builder.getFloat6E2M3FNType()">;
 def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
                BuildableType<"$_builder.getFloat6E3M2FNType()">;
 
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index b6a307fd7cb0fe..8b6f365fbda02e 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -125,6 +125,7 @@ class Type {
   // Convenience predicates.  This is only for floating point types,
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
+  bool isFloat6E2M3FN() const;
   bool isFloat6E3M2FN() const;
   bool isFloat8E5M2() const;
   bool isFloat8E4M3() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index fa18cbe9e2b901..6ae64a17d1fadb 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -101,6 +101,7 @@ TOK_KEYWORD(f8E5M2FNUZ)
 TOK_KEYWORD(f8E4M3FNUZ)
 TOK_KEYWORD(f8E4M3B11FNUZ)
 TOK_KEYWORD(f8E3M4)
+TOK_KEYWORD(f6E2M3FN)
 TOK_KEYWORD(f6E3M2FN)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 05276031211fa9..a3798ca8d90b1b 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -39,6 +39,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_tuple:
   case Token::kw_vector:
   case Token::inttype:
+  case Token::kw_f6E2M3FN:
   case Token::kw_f6E3M2FN:
   case Token::kw_f8E5M2:
   case Token::kw_f8E4M3:
@@ -304,6 +305,9 @@ Type Parser::parseNonFunctionType() {
   }
 
   // float-type
+  case Token::kw_f6E2M3FN:
+    consumeToken(Token::kw_f6E2M3FN);
+    return builder.getFloat6E2M3FNType();
   case Token::kw_f6E3M2FN:
     consumeToken(Token::kw_f6E3M2FN);
     return builder.getFloat6E3M2FNType();
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 1cb429d9ca7b2d..6b64bc3c9d6f63 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -124,6 +124,27 @@ class PyFloatType : public PyConcreteType<PyFloatType> {
   }
 };
 
+/// Floating Point Type subclass - Float6E2M3FNType.
+class PyFloat6E2M3FNType
+    : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat6E2M3FNTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float6E2M3FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
+          return PyFloat6E2M3FNType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float6_e2m3fn type.");
+  }
+};
+
 /// Floating Point Type subclass - Float6E3M2FNType.
 class PyFloat6E3M2FNType
     : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
@@ -901,6 +922,7 @@ void mlir::python::populateIRTypes(py::module &m) {
   PyIntegerType::bind(m);
   PyFloatType::bind(m);
   PyIndexType::bind(m);
+  PyFloat6E2M3FNType::bind(m);
   PyFloat6E3M2FNType::bind(m);
   PyFloat8E4M3FNType::bind(m);
   PyFloat8E5M2Type::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 254650d66a67e6..f943bf726b172c 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -85,6 +85,18 @@ unsigned mlirFloatTypeGetWidth(MlirType type) {
   return llvm::cast<FloatType>(unwrap(type)).getWidth();
 }
 
+MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
+  return wrap(Float6E2M3FNType::getTypeID());
+}
+
+bool mlirTypeIsAFloat6E2M3FN(MlirType type) {
+  return unwrap(type).isFloat6E2M3FN();
+}
+
+MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat6E2M3FN(unwrap(ctx)));
+}
+
 MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
   return wrap(Float6E3M2FNType::getTypeID());
 }
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index b2c54bb3212edb..51a1b91338c6a0 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -250,7 +250,7 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
   if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
       type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
       type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
-      type.isFloat6E3M2FN())
+      type.isFloat6E2M3FN() || type.isFloat6E3M2FN())
     return IntegerType::get(&getContext(), type.getWidth());
   return type;
 }
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index a5ee6edc6320d5..5e5e10b1fa1c2b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -55,6 +55,7 @@ static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
                                                StringRef name) {
   Builder b(ctx);
   return llvm::StringSwitch<std::optional<FloatType>>(name)
+      .Case("f6E2M3FN", b.getFloat6E2M3FNType())
       .Case("f6E3M2FN", b.getFloat6E3M2FNType())
       .Case("f8E5M2", b.getFloat8E5M2Type())
       .Case("f8E4M3", b.getFloat8E4M3Type())
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5142b462820786..c7ed158aabb6e7 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2575,6 +2575,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
                            opaqueTy.getTypeData());
       })
       .Case<IndexType>([&](Type) { os << "index"; })
+      .Case<Float6E2M3FNType>([&](Type) { os << "f6E2M3FN"; })
       .Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; })
       .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
       .Case<Float8E4M3Type>([&](Type) { os << "f8E4M3"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 71f622b02adee0..144a13df2179b7 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -34,6 +34,10 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
 // Types.
 //===----------------------------------------------------------------------===//
 
+FloatType Builder::getFloat6E2M3FNType() {
+  return FloatType::getFloat6E2M3FN(context);
+}
+
 FloatType Builder::getFloat6E3M2FNType() {
   return FloatType::getFloat6E3M2FN(context);
 }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index e46b6a4a6bb693..702d98ec31427b 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -101,6 +101,8 @@ unsigned FloatType::getWidth() {
 
 /// Returns the floating semantics for the given type.
 const llvm::fltSemantics &FloatType::getFloatSemantics() {
+  if (llvm::isa<Float6E2M3FNType>(*this))
+    return APFloat::Float6E2M3FN();
   if (llvm::isa<Float6E3M2FNType>(*this))
     return APFloat::Float6E3M2FN();
   if (llvm::isa<Float8E5M2Type>(*this))
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 2851e6457ea3cb..1684566626886c 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -221,6 +221,7 @@ class MLIRContextImpl {
   llvm::DenseMap<StringRef, AbstractType *> nameToType;
 
   /// Cached Type Instances.
+  Float6E2M3FNType f6E2M3FNTy;
   Float6E3M2FNType f6E3M2FNTy;
   Float8E5M2Type f8E5M2Ty;
   Float8E4M3Type f8E4M3Ty;
@@ -314,6 +315,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
 
   //// Types.
   /// Floating-point Types.
+  impl->f6E2M3FNTy = TypeUniquer::get<Float6E2M3FNType>(this);
   impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
   impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
   impl->f8E4M3Ty = TypeUniquer::get<Float8E4M3Type>(this);
@@ -1015,6 +1017,9 @@ AbstractType::lookup(StringRef name, MLIRContext *context) {
 /// This should not be used directly.
 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
 
+Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) {
+  return context->getImpl().f6E2M3FNTy;
+}
 Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
   return context->getImpl().f6E3M2FNTy;
 }
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index fa093664cf77f1..c828fd3766eaa7 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -34,6 +34,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,
 
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
+bool Type::isFloat6E2M3FN() const { return llvm::isa<Float6E2M3FNType>(*this); }
 bool Type::isFloat6E3M2FN() const { return llvm::isa<Float6E3M2FNType>(*this); }
 bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
 bool Type::isFloat8E4M3() const { return llvm::isa<Float8E4M3Type>(*this); }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 7b4fac7275bfc6..17a02b0bd445a7 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -120,6 +120,7 @@ __all__ = [
     "F32Type",
     "F64Type",
     "FlatSymbolRefAttr",
+    "Float6E2M3FNType",
     "Float6E3M2FNType",
     "Float8E3M4Type",
     "Float8E4M3B11FNUZType",
@@ -1540,6 +1541,19 @@ class FlatSymbolRefAttr(Attribute):
         Returns the value of the FlatSymbolRef attribute as a string
         """
 
+class Float6E2M3FNType(FloatType):
+    static_typeid: ClassVar[TypeID]
+    @staticmethod
+    def get(context: Optional[Context] = None) -> Float6E2M3FNType:
+        """
+        Create a float6_e2m3fn type.
+        """
+    @staticmethod
+    def isinstance(other: Type) -> bool: ...
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @property
+    def typeid(self) -> TypeID: ...
+
 class Float6E3M2FNType(FloatType):
     static_typeid: ClassVar[TypeID]
     @staticmethod
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index 0c6ece91d8b94a..4be425f220c978 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -12,6 +12,7 @@
     F16Type,
     F32Type,
     F64Type,
+    Float6E2M3FNType,
     Float6E3M2FNType,
     Float8E3M4Type,
     Float8E4M3B11FNUZType,
@@ -75,6 +76,7 @@ def ui(width):
 f8E4M3FN = lambda: Float8E4M3FNType.get()
 f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
 f8E3M4 = lambda: Float8E3M4Type.get()
+f6E2M3FN = lambda: Float6E2M3FNType.get()
 f6E3M2FN = lambda: Float6E3M2FNType.get()
 
 none = lambda: NoneType.get()
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 38cbf9d5d2b579..23dbf0c292c2c3 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -36,6 +36,10 @@ func.func @any_attr_of_fail() {
 //===----------------------------------------------------------------------===//
 
 func.func @float_attrs_pass() {
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f6E2M3FN
+    float_attr = 2. : f6E2M3FN
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f6E3M2FN
     float_attr = 2. : f6E3M2FN
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 04be037978c8f6..7eca1a40373054 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -42,6 +42,9 @@ llvm.mlir.global internal @int_global_undef() : i64
 // CHECK: @externally_initialized_global = internal externally_initialized global i32 0
 llvm.mlir.global internal @externally_initialized_global(0 : i32) {externally_initialized} : i32
 
+// CHECK: @f6E2M3FN_global_as_i6 = internal global i6 12
+llvm.mlir.global internal @f6E2M3FN_global_as_i6(1.5 : f6E2M3FN) : i6
+
 // CHECK: @f6E3M2FN_global_as_i6 = internal global i6 14
 llvm.mlir.global internal @f6E3M2FN_global_as_i6(1.5 : f6E3M2FN) : i6
 
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index b72ef4de0bd6dd..bc3ba4cd0b1448 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -113,6 +113,8 @@ def testTypeIsInstance():
 def testFloatTypeSubclasses():
     ctx = Context()
     # CHECK: True
+    print(isinstance(Type.parse("f6E2M3FN", ctx), FloatType))
+    # CHECK: True
     print(isinstance(Type.parse("f6E3M2FN", ctx), FloatType))
     # CHECK: True
     print(isinstance(Type.parse("f8E3M4", ctx), FloatType))
@@ -235,6 +237,8 @@ def testIndexType():
 @run
 def testFloatType():
     with Context():
+        # CHECK: float: f6E2M3FN
+        print("float:", Float6E2M3FNType.get())
         # CHECK: float: f6E3M2FN
         print("float:", Float6E3M2FNType.get())
         # CHECK: float: f8E3M4
@@ -613,6 +617,7 @@ def testTypeIDs():
         types = [
             (IntegerType, IntegerType.get_signless(16)),
             (IndexType, IndexType.get()),
+            (Float6E2M3FNType, Float6E2M3FNType.get()),
             (Float6E3M2FNType, Float6E3M2FNType.get()),
             (Float8E3M4Type, Float8E3M4Type.get()),
             (Float8E4M3Type, Float8E4M3Type.get()),
@@ -639,6 +644,7 @@ def testTypeIDs():
 
         # CHECK: IntegerType(i16)
         # CHECK: IndexType(index)
+        # CHECK: Float6E2M3FNType(f6E2M3FN)
         # CHECK: Float6E3M2FNType(f6E3M2FN)
         # CHECK: Float8E3M4Type(f8E3M4)
         # CHECK: Float8E4M3Type(f8E4M3)
@@ -719,6 +725,9 @@ def print_downcasted(typ):
         # CHECK: F64Type
         # CHECK: F64Type(f64)
         print_downcasted(F64Type.get())
+        # CHECK: Float6E2M3FNType
+        # CHECK: Float6E2M3FNType(f6E2M3FN)
+        print_downcasted(Float6E2M3FNType.get())
         # CHECK: Float6E3M2FNType
         # CHECK: Float6E3M2FNType(f6E3M2FN)
         print_downcasted(Float6E3M2FNType.get())
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index fed149d03ecf31..350a0f7abea5a4 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -50,6 +50,7 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
     "mlir::CallSiteLoc": '"loc(callsite(...))"',
     "mlir::FusedLoc": '"loc(fused<...>[...])"',
     "mlir::UnknownLoc": '"loc(unknown)"',
+    "mlir::Float6E2M3FNType": '"f6E2M3FN"',
     "mlir::Float6E3M2FNType": '"f6E3M2FN"',
     "mlir::Float8E5M2Type": '"f8E5M2"',
     "mlir::Float8E4M3Type": '"f8E4M3"',
diff --git a/mlir/utils/tree-sitter-mlir/grammar.js b/mlir/utils/tree-sitter-mlir/grammar.js
index d2c66714b4b118..9df1944f6255d9 100644
--- a/mlir/utils/tree-sitter-mlir/grammar.js
+++ b/mlir/utils/tree-sitter-mlir/grammar.js
@@ -231,7 +231,7 @@ const common = {
       token(seq(choice('si', 'ui', 'i'), /[1-9]/, repeat(/[0-9]/))),
   float_type : $ => token(
       choice('f16', 'f32', 'f64', 'f80', 'f128', 'bf16', 'f8E3M4', 'f8E4M3FN',
-             'f8E4M3', 'f8E5M2', 'f6E3M2FN')),
+             'f8E4M3', 'f8E5M2', 'f6E2M3FN', 'f6E3M2FN')),
   index_type : $ => token('index'),
   none_type : $ => token('none'),
   complex_type : $ => seq(token('complex'), '<', $._prim_type, '>'),

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 10, 2024

@llvm/pr-subscribers-mlir-core

Author: Sergey Kozub (sergey-kozub)

Changes

This PR adds f6E2M3FN type to mlir.

f6E2M3FN type is proposed in OpenCompute MX Specification. It defines a 6-bit floating point number with bit layout S1E2M3. Unlike IEEE-754 types, there are no infinity or NaN values.

f6E2M3FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 11 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.000
- Max normal number: S.11.111 = ±2^(2) x (1 + 0.875) = ±7.5
- Min normal number: S.01.000 = ±2^(0) = ±1.0
- Max subnormal number: S.00.111 = ±2^(0) x 0.875 = ±0.875
- Min subnormal number: S.00.001 = ±2^(0) x 0.125 = ±0.125

Related PRs:

  • PR-94735 [APFloat] Add APFloat support for FP6 data types
  • PR-105573 [MLIR] Add f6E3M2FN type - was used as a template for this PR

Full diff: https://github.com/llvm/llvm-project/pull/107999.diff

24 Files Affected:

  • (modified) mlir/include/mlir-c/BuiltinTypes.h (+10)
  • (modified) mlir/include/mlir/IR/Builders.h (+1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+10-5)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+21)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+2)
  • (modified) mlir/include/mlir/IR/Types.h (+1)
  • (modified) mlir/lib/AsmParser/TokenKinds.def (+1)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+4)
  • (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+22)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+12)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+1)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+1)
  • (modified) mlir/lib/IR/Builders.cpp (+4)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+2)
  • (modified) mlir/lib/IR/MLIRContext.cpp (+5)
  • (modified) mlir/lib/IR/Types.cpp (+1)
  • (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+14)
  • (modified) mlir/python/mlir/extras/types.py (+2)
  • (modified) mlir/test/IR/attribute.mlir (+4)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+3)
  • (modified) mlir/test/python/ir/builtin_types.py (+9)
  • (modified) mlir/utils/lldb-scripts/mlirDataFormatters.py (+1)
  • (modified) mlir/utils/tree-sitter-mlir/grammar.js (+1-1)
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 24531baecaa353..cc6da482a1c369 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -79,6 +79,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type);
 /// Returns the bitwidth of a floating-point type.
 MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type);
 
+/// Returns the typeID of an Float6E2M3FN type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void);
+
+/// Checks whether the given type is an f6E2M3FN type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E2M3FN(MlirType type);
+
+/// Creates an f6E2M3FN type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx);
+
 /// Returns the typeID of an Float6E3M2FN type.
 MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void);
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 5ac3a04b1c26ba..196d34e12d9b28 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -60,6 +60,7 @@ class Builder {
                        Attribute metadata = Attribute());
 
   // Types.
+  FloatType getFloat6E2M3FNType();
   FloatType getFloat6E3M2FNType();
   FloatType getFloat8E5M2Type();
   FloatType getFloat8E4M3Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 87ccc041f19758..f2231e9507570e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -67,6 +67,7 @@ class FloatType : public Type {
   static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E3M4(MLIRContext *ctx);
+  static FloatType getFloat6E2M3FN(MLIRContext *ctx);
   static FloatType getFloat6E3M2FN(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
@@ -414,11 +415,15 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 }
 
 inline bool FloatType::classof(Type type) {
-  return llvm::isa<Float6E3M2FNType, Float8E5M2Type, Float8E4M3Type,
-                   Float8E4M3FNType, Float8E5M2FNUZType, Float8E4M3FNUZType,
-                   Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type,
-                   Float16Type, FloatTF32Type, Float32Type, Float64Type,
-                   Float80Type, Float128Type>(type);
+  return llvm::isa<Float6E2M3FNType, Float6E3M2FNType, Float8E5M2Type,
+                   Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                   Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+                   BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
+                   Float64Type, Float80Type, Float128Type>(type);
+}
+
+inline FloatType FloatType::getFloat6E2M3FN(MLIRContext *ctx) {
+  return Float6E2M3FNType::get(ctx);
 }
 
 inline FloatType FloatType::getFloat6E3M2FN(MLIRContext *ctx) {
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index b54d4ee4b7eb7a..09c2d34dc7dd1b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -233,6 +233,27 @@ def Builtin_Float8E3M4 : Builtin_FloatType<"Float8E3M4", "f8E3M4"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float6E2M3FNType
+
+def Builtin_Float6E2M3FN : Builtin_FloatType<"Float6E2M3FN", "f6E2M3FN"> {
+  let summary = "6-bit floating point with 3 bits exponent and 2 bit mantissa";
+  let description = [{
+    An 6-bit floating point type with 1 sign bit, 2 bits exponent and 3 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it
+    follows similar conventions with the following characteristics:
+
+      * bit encoding: S1E2M3
+      * exponent bias: 1
+      * infinities: Not supported
+      * NaNs: Not supported
+      * denormals when exponent is 0
+
+    Open Compute Project (OCP) microscaling formats (MX) specification:
+    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Float6E3M2FNType
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 09eab50f53a540..3cc1c95f1ed37a 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -344,6 +344,8 @@ def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
                  BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
 def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
              BuildableType<"$_builder.getFloat8E3M4Type()">;
+def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
+               BuildableType<"$_builder.getFloat6E2M3FNType()">;
 def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
                BuildableType<"$_builder.getFloat6E3M2FNType()">;
 
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index b6a307fd7cb0fe..8b6f365fbda02e 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -125,6 +125,7 @@ class Type {
   // Convenience predicates.  This is only for floating point types,
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
+  bool isFloat6E2M3FN() const;
   bool isFloat6E3M2FN() const;
   bool isFloat8E5M2() const;
   bool isFloat8E4M3() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index fa18cbe9e2b901..6ae64a17d1fadb 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -101,6 +101,7 @@ TOK_KEYWORD(f8E5M2FNUZ)
 TOK_KEYWORD(f8E4M3FNUZ)
 TOK_KEYWORD(f8E4M3B11FNUZ)
 TOK_KEYWORD(f8E3M4)
+TOK_KEYWORD(f6E2M3FN)
 TOK_KEYWORD(f6E3M2FN)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 05276031211fa9..a3798ca8d90b1b 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -39,6 +39,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_tuple:
   case Token::kw_vector:
   case Token::inttype:
+  case Token::kw_f6E2M3FN:
   case Token::kw_f6E3M2FN:
   case Token::kw_f8E5M2:
   case Token::kw_f8E4M3:
@@ -304,6 +305,9 @@ Type Parser::parseNonFunctionType() {
   }
 
   // float-type
+  case Token::kw_f6E2M3FN:
+    consumeToken(Token::kw_f6E2M3FN);
+    return builder.getFloat6E2M3FNType();
   case Token::kw_f6E3M2FN:
     consumeToken(Token::kw_f6E3M2FN);
     return builder.getFloat6E3M2FNType();
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 1cb429d9ca7b2d..6b64bc3c9d6f63 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -124,6 +124,27 @@ class PyFloatType : public PyConcreteType<PyFloatType> {
   }
 };
 
+/// Floating Point Type subclass - Float6E2M3FNType.
+class PyFloat6E2M3FNType
+    : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat6E2M3FNTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float6E2M3FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
+          return PyFloat6E2M3FNType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float6_e2m3fn type.");
+  }
+};
+
 /// Floating Point Type subclass - Float6E3M2FNType.
 class PyFloat6E3M2FNType
     : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
@@ -901,6 +922,7 @@ void mlir::python::populateIRTypes(py::module &m) {
   PyIntegerType::bind(m);
   PyFloatType::bind(m);
   PyIndexType::bind(m);
+  PyFloat6E2M3FNType::bind(m);
   PyFloat6E3M2FNType::bind(m);
   PyFloat8E4M3FNType::bind(m);
   PyFloat8E5M2Type::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 254650d66a67e6..f943bf726b172c 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -85,6 +85,18 @@ unsigned mlirFloatTypeGetWidth(MlirType type) {
   return llvm::cast<FloatType>(unwrap(type)).getWidth();
 }
 
+MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
+  return wrap(Float6E2M3FNType::getTypeID());
+}
+
+bool mlirTypeIsAFloat6E2M3FN(MlirType type) {
+  return unwrap(type).isFloat6E2M3FN();
+}
+
+MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat6E2M3FN(unwrap(ctx)));
+}
+
 MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
   return wrap(Float6E3M2FNType::getTypeID());
 }
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index b2c54bb3212edb..51a1b91338c6a0 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -250,7 +250,7 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
   if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
       type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
       type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
-      type.isFloat6E3M2FN())
+      type.isFloat6E2M3FN() || type.isFloat6E3M2FN())
     return IntegerType::get(&getContext(), type.getWidth());
   return type;
 }
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index a5ee6edc6320d5..5e5e10b1fa1c2b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -55,6 +55,7 @@ static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
                                                StringRef name) {
   Builder b(ctx);
   return llvm::StringSwitch<std::optional<FloatType>>(name)
+      .Case("f6E2M3FN", b.getFloat6E2M3FNType())
       .Case("f6E3M2FN", b.getFloat6E3M2FNType())
       .Case("f8E5M2", b.getFloat8E5M2Type())
       .Case("f8E4M3", b.getFloat8E4M3Type())
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5142b462820786..c7ed158aabb6e7 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2575,6 +2575,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
                            opaqueTy.getTypeData());
       })
       .Case<IndexType>([&](Type) { os << "index"; })
+      .Case<Float6E2M3FNType>([&](Type) { os << "f6E2M3FN"; })
       .Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; })
       .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
       .Case<Float8E4M3Type>([&](Type) { os << "f8E4M3"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 71f622b02adee0..144a13df2179b7 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -34,6 +34,10 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
 // Types.
 //===----------------------------------------------------------------------===//
 
+FloatType Builder::getFloat6E2M3FNType() {
+  return FloatType::getFloat6E2M3FN(context);
+}
+
 FloatType Builder::getFloat6E3M2FNType() {
   return FloatType::getFloat6E3M2FN(context);
 }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index e46b6a4a6bb693..702d98ec31427b 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -101,6 +101,8 @@ unsigned FloatType::getWidth() {
 
 /// Returns the floating semantics for the given type.
 const llvm::fltSemantics &FloatType::getFloatSemantics() {
+  if (llvm::isa<Float6E2M3FNType>(*this))
+    return APFloat::Float6E2M3FN();
   if (llvm::isa<Float6E3M2FNType>(*this))
     return APFloat::Float6E3M2FN();
   if (llvm::isa<Float8E5M2Type>(*this))
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 2851e6457ea3cb..1684566626886c 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -221,6 +221,7 @@ class MLIRContextImpl {
   llvm::DenseMap<StringRef, AbstractType *> nameToType;
 
   /// Cached Type Instances.
+  Float6E2M3FNType f6E2M3FNTy;
   Float6E3M2FNType f6E3M2FNTy;
   Float8E5M2Type f8E5M2Ty;
   Float8E4M3Type f8E4M3Ty;
@@ -314,6 +315,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
 
   //// Types.
   /// Floating-point Types.
+  impl->f6E2M3FNTy = TypeUniquer::get<Float6E2M3FNType>(this);
   impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
   impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
   impl->f8E4M3Ty = TypeUniquer::get<Float8E4M3Type>(this);
@@ -1015,6 +1017,9 @@ AbstractType::lookup(StringRef name, MLIRContext *context) {
 /// This should not be used directly.
 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
 
+Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) {
+  return context->getImpl().f6E2M3FNTy;
+}
 Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
   return context->getImpl().f6E3M2FNTy;
 }
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index fa093664cf77f1..c828fd3766eaa7 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -34,6 +34,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,
 
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
+bool Type::isFloat6E2M3FN() const { return llvm::isa<Float6E2M3FNType>(*this); }
 bool Type::isFloat6E3M2FN() const { return llvm::isa<Float6E3M2FNType>(*this); }
 bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
 bool Type::isFloat8E4M3() const { return llvm::isa<Float8E4M3Type>(*this); }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 7b4fac7275bfc6..17a02b0bd445a7 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -120,6 +120,7 @@ __all__ = [
     "F32Type",
     "F64Type",
     "FlatSymbolRefAttr",
+    "Float6E2M3FNType",
     "Float6E3M2FNType",
     "Float8E3M4Type",
     "Float8E4M3B11FNUZType",
@@ -1540,6 +1541,19 @@ class FlatSymbolRefAttr(Attribute):
         Returns the value of the FlatSymbolRef attribute as a string
         """
 
+class Float6E2M3FNType(FloatType):
+    static_typeid: ClassVar[TypeID]
+    @staticmethod
+    def get(context: Optional[Context] = None) -> Float6E2M3FNType:
+        """
+        Create a float6_e2m3fn type.
+        """
+    @staticmethod
+    def isinstance(other: Type) -> bool: ...
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @property
+    def typeid(self) -> TypeID: ...
+
 class Float6E3M2FNType(FloatType):
     static_typeid: ClassVar[TypeID]
     @staticmethod
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index 0c6ece91d8b94a..4be425f220c978 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -12,6 +12,7 @@
     F16Type,
     F32Type,
     F64Type,
+    Float6E2M3FNType,
     Float6E3M2FNType,
     Float8E3M4Type,
     Float8E4M3B11FNUZType,
@@ -75,6 +76,7 @@ def ui(width):
 f8E4M3FN = lambda: Float8E4M3FNType.get()
 f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
 f8E3M4 = lambda: Float8E3M4Type.get()
+f6E2M3FN = lambda: Float6E2M3FNType.get()
 f6E3M2FN = lambda: Float6E3M2FNType.get()
 
 none = lambda: NoneType.get()
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 38cbf9d5d2b579..23dbf0c292c2c3 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -36,6 +36,10 @@ func.func @any_attr_of_fail() {
 //===----------------------------------------------------------------------===//
 
 func.func @float_attrs_pass() {
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f6E2M3FN
+    float_attr = 2. : f6E2M3FN
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f6E3M2FN
     float_attr = 2. : f6E3M2FN
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 04be037978c8f6..7eca1a40373054 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -42,6 +42,9 @@ llvm.mlir.global internal @int_global_undef() : i64
 // CHECK: @externally_initialized_global = internal externally_initialized global i32 0
 llvm.mlir.global internal @externally_initialized_global(0 : i32) {externally_initialized} : i32
 
+// CHECK: @f6E2M3FN_global_as_i6 = internal global i6 12
+llvm.mlir.global internal @f6E2M3FN_global_as_i6(1.5 : f6E2M3FN) : i6
+
 // CHECK: @f6E3M2FN_global_as_i6 = internal global i6 14
 llvm.mlir.global internal @f6E3M2FN_global_as_i6(1.5 : f6E3M2FN) : i6
 
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index b72ef4de0bd6dd..bc3ba4cd0b1448 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -113,6 +113,8 @@ def testTypeIsInstance():
 def testFloatTypeSubclasses():
     ctx = Context()
     # CHECK: True
+    print(isinstance(Type.parse("f6E2M3FN", ctx), FloatType))
+    # CHECK: True
     print(isinstance(Type.parse("f6E3M2FN", ctx), FloatType))
     # CHECK: True
     print(isinstance(Type.parse("f8E3M4", ctx), FloatType))
@@ -235,6 +237,8 @@ def testIndexType():
 @run
 def testFloatType():
     with Context():
+        # CHECK: float: f6E2M3FN
+        print("float:", Float6E2M3FNType.get())
         # CHECK: float: f6E3M2FN
         print("float:", Float6E3M2FNType.get())
         # CHECK: float: f8E3M4
@@ -613,6 +617,7 @@ def testTypeIDs():
         types = [
             (IntegerType, IntegerType.get_signless(16)),
             (IndexType, IndexType.get()),
+            (Float6E2M3FNType, Float6E2M3FNType.get()),
             (Float6E3M2FNType, Float6E3M2FNType.get()),
             (Float8E3M4Type, Float8E3M4Type.get()),
             (Float8E4M3Type, Float8E4M3Type.get()),
@@ -639,6 +644,7 @@ def testTypeIDs():
 
         # CHECK: IntegerType(i16)
         # CHECK: IndexType(index)
+        # CHECK: Float6E2M3FNType(f6E2M3FN)
         # CHECK: Float6E3M2FNType(f6E3M2FN)
         # CHECK: Float8E3M4Type(f8E3M4)
         # CHECK: Float8E4M3Type(f8E4M3)
@@ -719,6 +725,9 @@ def print_downcasted(typ):
         # CHECK: F64Type
         # CHECK: F64Type(f64)
         print_downcasted(F64Type.get())
+        # CHECK: Float6E2M3FNType
+        # CHECK: Float6E2M3FNType(f6E2M3FN)
+        print_downcasted(Float6E2M3FNType.get())
         # CHECK: Float6E3M2FNType
         # CHECK: Float6E3M2FNType(f6E3M2FN)
         print_downcasted(Float6E3M2FNType.get())
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index fed149d03ecf31..350a0f7abea5a4 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -50,6 +50,7 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
     "mlir::CallSiteLoc": '"loc(callsite(...))"',
     "mlir::FusedLoc": '"loc(fused<...>[...])"',
     "mlir::UnknownLoc": '"loc(unknown)"',
+    "mlir::Float6E2M3FNType": '"f6E2M3FN"',
     "mlir::Float6E3M2FNType": '"f6E3M2FN"',
     "mlir::Float8E5M2Type": '"f8E5M2"',
     "mlir::Float8E4M3Type": '"f8E4M3"',
diff --git a/mlir/utils/tree-sitter-mlir/grammar.js b/mlir/utils/tree-sitter-mlir/grammar.js
index d2c66714b4b118..9df1944f6255d9 100644
--- a/mlir/utils/tree-sitter-mlir/grammar.js
+++ b/mlir/utils/tree-sitter-mlir/grammar.js
@@ -231,7 +231,7 @@ const common = {
       token(seq(choice('si', 'ui', 'i'), /[1-9]/, repeat(/[0-9]/))),
   float_type : $ => token(
       choice('f16', 'f32', 'f64', 'f80', 'f128', 'bf16', 'f8E3M4', 'f8E4M3FN',
-             'f8E4M3', 'f8E5M2', 'f6E3M2FN')),
+             'f8E4M3', 'f8E5M2', 'f6E2M3FN', 'f6E3M2FN')),
   index_type : $ => token('index'),
   none_type : $ => token('none'),
   complex_type : $ => seq(token('complex'), '<', $._prim_type, '>'),

// Float6E2M3FNType

def Builtin_Float6E2M3FN : Builtin_FloatType<"Float6E2M3FN", "f6E2M3FN"> {
let summary = "6-bit floating point with 3 bits exponent and 2 bit mantissa";
Copy link
Member

@apivovarov apivovarov Sep 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2-bit exponent and 3-bit mantissa.

Seems that correct spelling for adjectives is 2-bit , 3-bit, etc

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

This PR adds `f6E2M3FN` type to mlir.

`f6E2M3FN` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 6-bit floating point number with bit layout S1E2M3. Unlike IEEE-754 types, there are no infinity or NaN values.

```c
f6E2M3FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.000
- Max normal number: S.11.111 = ±2^(2) x (1 + 0.875) = ±7.5
- Min normal number: S.01.000 = ±2^(0) = ±1.0
- Max subnormal number: S.00.111 = ±2^(0) x 0.875 = ±0.875
- Min subnormal number: S.00.001 = ±2^(0) x 0.125 = ±0.125
```

Related PRs:
- [PR-94735](llvm#94735) [APFloat] Add APFloat support for FP6 data types
- [PR-105573](llvm#105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR
@sergey-kozub sergey-kozub merged commit 73d83f2 into llvm:main Sep 16, 2024
8 checks passed
sergey-kozub added a commit that referenced this pull request Sep 24, 2024
This PR adds `f4E2M1FN` type to mlir.

`f4E2M1FN` type is proposed in [OpenCompute MX
Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 4-bit floating point number with bit layout S1E2M1. Unlike
IEEE-754 types, there are no infinity or NaN values.

```c
f4E2M1FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.0
- Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0
- Min normal number: S.01.0 = ±2^(0) = ±1.0
- Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5
```

Related PRs:
- [PR-95392](#95392) [APFloat]
Add APFloat support for FP4 data type
- [PR-105573](#105573) [MLIR]
Add f6E3M2FN type - was used as a template for this PR
- [PR-107999](#107999) [MLIR]
Add f6E2M3FN type
augusto2112 pushed a commit to augusto2112/llvm-project that referenced this pull request Sep 26, 2024
This PR adds `f4E2M1FN` type to mlir.

`f4E2M1FN` type is proposed in [OpenCompute MX
Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 4-bit floating point number with bit layout S1E2M1. Unlike
IEEE-754 types, there are no infinity or NaN values.

```c
f4E2M1FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.0
- Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0
- Min normal number: S.01.0 = ±2^(0) = ±1.0
- Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5
```

Related PRs:
- [PR-95392](llvm#95392) [APFloat]
Add APFloat support for FP4 data type
- [PR-105573](llvm#105573) [MLIR]
Add f6E3M2FN type - was used as a template for this PR
- [PR-107999](llvm#107999) [MLIR]
Add f6E2M3FN type
sergey-kozub added a commit that referenced this pull request Oct 4, 2024
This PR adds `f8E8M0FNU` type to MLIR.

`f8E8M0FNU` type is proposed in [OpenCompute MX
Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 8-bit floating point number with bit layout S0E8M0. Unlike
IEEE-754 types, there are no infinity, denormals, zeros or negative
values.

```c
f8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0 − 127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111

Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1
```

Related PRs:
- [PR-107127](#107127)
[APFloat] Add APFloat support for E8M0 type
- [PR-105573](#105573) [MLIR]
Add f6E3M2FN type - was used as a template for this PR
- [PR-107999](#107999) [MLIR]
Add f6E2M3FN type
- [PR-108877](#108877) [MLIR]
Add f4E2M1FN type
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Oct 4, 2024
This PR adds `f4E2M1FN` type to mlir.

`f4E2M1FN` type is proposed in [OpenCompute MX
Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 4-bit floating point number with bit layout S1E2M1. Unlike
IEEE-754 types, there are no infinity or NaN values.

```c
f4E2M1FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.0
- Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0
- Min normal number: S.01.0 = ±2^(0) = ±1.0
- Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5
```

Related PRs:
- [PR-95392](llvm#95392) [APFloat]
Add APFloat support for FP4 data type
- [PR-105573](llvm#105573) [MLIR]
Add f6E3M2FN type - was used as a template for this PR
- [PR-107999](llvm#107999) [MLIR]
Add f6E2M3FN type
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Oct 4, 2024
This PR adds `f8E8M0FNU` type to MLIR.

`f8E8M0FNU` type is proposed in [OpenCompute MX
Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 8-bit floating point number with bit layout S0E8M0. Unlike
IEEE-754 types, there are no infinity, denormals, zeros or negative
values.

```c
f8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0 − 127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111

Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1
```

Related PRs:
- [PR-107127](llvm#107127)
[APFloat] Add APFloat support for E8M0 type
- [PR-105573](llvm#105573) [MLIR]
Add f6E3M2FN type - was used as a template for this PR
- [PR-107999](llvm#107999) [MLIR]
Add f6E2M3FN type
- [PR-108877](llvm#108877) [MLIR]
Add f4E2M1FN type
GleasonK pushed a commit to openxla/stablehlo that referenced this pull request Oct 23, 2024
…U) (#2581)

This is a proposal to add MX (microscaling) floating point types to
StableHLO.

Related links:
- StableHLO [PR#2582](#2582)
Add MX floating point types (f4E2M1FN, f6E2M3FN, f6E3M2FN, f8E8M0FNU)
- LLVM [PR#95392](llvm/llvm-project#95392)
[APFloat] Add APFloat support for FP4 data type
- LLVM [PR#94735](llvm/llvm-project#94735)
[APFloat] Add APFloat support for FP6 data types
- LLVM [PR#107127](llvm/llvm-project#107127)
[APFloat] Add APFloat support for E8M0 type
- LLVM [PR#108877](llvm/llvm-project#108877)
[MLIR] Add f4E2M1FN type
- LLVM [PR#107999](llvm/llvm-project#107999)
[MLIR] Add f6E2M3FN type
- LLVM [PR#105573](llvm/llvm-project#105573)
[MLIR] Add f6E3M2FN type
- LLVM [PR#111028](llvm/llvm-project#111028)
[MLIR] Add f8E8M0FNU type
- JAX-ML [PR#181](jax-ml/ml_dtypes#181) Add
sub-byte data types: float4_e2m1fn, float6_e2m3fn, float6_e3m2fn
- JAX-ML [PR#166](jax-ml/ml_dtypes#181) Add
float8_e8m0_fnu (E8M0) OCP MX scale format
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants