Skip to content

Commit

Permalink
Merge pull request #171 from apivovarov:e3m4
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668214541
  • Loading branch information
The ml_dtypes Authors committed Aug 28, 2024
2 parents f053b3c + 4a03c71 commit 82f3a61
Show file tree
Hide file tree
Showing 9 changed files with 377 additions and 17 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

* Added new 8-bit float type following IEEE 754 convention:
`ml_dtypes.float8_e4m3`.
* Added new 8-bit float types following IEEE 754 convention:
`ml_dtypes.float8_e4m3` and `ml_dtypes.float8_e3m4`.
* Fix outputs of float `divmod` and `floor_divide` when denominator is zero.

## [0.4.0] - 2024-04-1
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
an alternative to the standard [`float16`](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) format
- `float8_*`: several experimental 8-bit floating point representations
including:
* `float8_e3m4`
* `float8_e4m3`
* `float8_e4m3b11fnuz`
* `float8_e4m3fn`
Expand Down Expand Up @@ -65,6 +66,10 @@ A `bfloat16` number is a single-precision float truncated at 16 bits.

Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.

### `float8_e3m4`

Exponent: 3, Mantissa: 4, bias: 3. IEEE 754, with NaN and inf.

### `float8_e4m3`

Exponent: 4, Mantissa: 3, bias: 7. IEEE 754, with NaN and inf.
Expand Down
3 changes: 3 additions & 0 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"__version__",
"bfloat16",
"finfo",
"float8_e3m4",
"float8_e4m3",
"float8_e4m3b11fnuz",
"float8_e4m3fn",
Expand All @@ -35,6 +36,7 @@
from ml_dtypes._finfo import finfo
from ml_dtypes._iinfo import iinfo
from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float8_e3m4
from ml_dtypes._ml_dtypes_ext import float8_e4m3
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
from ml_dtypes._ml_dtypes_ext import float8_e4m3fn
Expand All @@ -48,6 +50,7 @@
import numpy as np

bfloat16: Type[np.generic]
float8_e3m4: Type[np.generic]
float8_e4m3: Type[np.generic]
float8_e4m3b11fnuz: Type[np.generic]
float8_e4m3fn: Type[np.generic]
Expand Down
66 changes: 65 additions & 1 deletion ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Dict

from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float8_e3m4
from ml_dtypes._ml_dtypes_ext import float8_e4m3
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
from ml_dtypes._ml_dtypes_ext import float8_e4m3fn
Expand All @@ -26,6 +27,7 @@
import numpy as np

_bfloat16_dtype = np.dtype(bfloat16)
_float8_e3m4_dtype = np.dtype(float8_e3m4)
_float8_e4m3_dtype = np.dtype(float8_e4m3)
_float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz)
_float8_e4m3fn_dtype = np.dtype(float8_e4m3fn)
Expand All @@ -43,12 +45,21 @@ def __init__(self):
self.smallest_subnormal = bfloat16(smallest_subnormal)


class _Float8E3m4MachArLike:

def __init__(self):
smallest_normal = float.fromhex("0x1p-2")
self.smallest_normal = float8_e3m4(smallest_normal)
smallest_subnormal = float.fromhex("0x0.1p-2")
self.smallest_subnormal = float8_e3m4(smallest_subnormal)


class _Float8E4m3MachArLike:

def __init__(self):
smallest_normal = float.fromhex("0x1p-6")
self.smallest_normal = float8_e4m3(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-9")
smallest_subnormal = float.fromhex("0x0.2p-6")
self.smallest_subnormal = float8_e4m3(smallest_subnormal)


Expand Down Expand Up @@ -146,6 +157,51 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj

@staticmethod
def _float8_e3m4_finfo():
def float_to_str(f):
return "%6.2e" % float(f)

tiny = float.fromhex("0x1p-2") # 1/4 min normal
resolution = 0.1
eps = float.fromhex("0x1p-4") # 1/16
epsneg = float.fromhex("0x1p-5") # 1/32
max_ = float.fromhex("0x1.Fp3") # 15.5 max normal

obj = object.__new__(np.finfo)
obj.dtype = _float8_e3m4_dtype
obj.bits = 8
obj.eps = float8_e3m4(eps)
obj.epsneg = float8_e3m4(epsneg)
obj.machep = -4
obj.negep = -5
obj.max = float8_e3m4(max_)
obj.min = float8_e3m4(-max_)
obj.nexp = 3
obj.nmant = 4
obj.iexp = obj.nexp
obj.maxexp = 4
obj.minexp = -2
obj.precision = 1
obj.resolution = float8_e3m4(resolution)
# pylint: disable=protected-access
obj._machar = _Float8E3m4MachArLike()
if not hasattr(obj, "tiny"):
obj.tiny = float8_e3m4(tiny)
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = obj._machar.smallest_normal
obj.smallest_subnormal = obj._machar.smallest_subnormal

obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(max_)
obj._str_epsneg = float_to_str(epsneg)
obj._str_eps = float_to_str(eps)
obj._str_resolution = float_to_str(resolution)
# pylint: enable=protected-access
return obj

@staticmethod
def _float8_e4m3_finfo():
def float_to_str(f):
Expand Down Expand Up @@ -425,6 +481,14 @@ def __new__(cls, dtype):
if _bfloat16_dtype not in cls._finfo_cache:
cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
return cls._finfo_cache[_bfloat16_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e3m4"
or dtype == _float8_e3m4_dtype
):
if _float8_e3m4_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e3m4_dtype] = cls._float8_e3m4_finfo()
return cls._finfo_cache[_float8_e3m4_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3"
Expand Down
29 changes: 29 additions & 0 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ struct TypeDescriptor<bfloat16> : CustomFloatType<bfloat16> {
static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor<float8_e3m4> : CustomFloatType<float8_e3m4> {
typedef float8_e3m4 T;
static constexpr bool is_floating = true;
static constexpr bool is_integral = false;
static constexpr const char* kTypeName = "float8_e3m4";
static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e3m4";
static constexpr const char* kTpDoc = "float8_e3m4 floating-point values";
// Set e3m4 kind as Void since kind=f (float) with itemsize=1 is used by e5m2
static constexpr char kNpyDescrKind = 'V'; // Void
static constexpr char kNpyDescrType = '3';
static constexpr char kNpyDescrByteorder = '='; // Native byte order
};

template <>
struct TypeDescriptor<float8_e4m3> : CustomFloatType<float8_e4m3> {
typedef float8_e4m3 T;
Expand Down Expand Up @@ -283,6 +297,9 @@ bool Initialize() {
if (!RegisterFloatDtype<bfloat16>(numpy.get())) {
return false;
}
if (!RegisterFloatDtype<float8_e3m4>(numpy.get())) {
return false;
}
if (!RegisterFloatDtype<float8_e4m3>(numpy.get())) {
return false;
}
Expand Down Expand Up @@ -342,6 +359,13 @@ bool Initialize() {
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3fnuz, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3fn, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e5m2, float>();
success &= RegisterTwoWayCustomCast<float8_e3m4, bfloat16, float>();
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e4m3b11fnuz, float>();
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e5m2fnuz, float>();
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e4m3fnuz, float>();
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e4m3fn, float>();
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e5m2, float>();
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e4m3, float>();
success &= RegisterOneWayCustomCast<int2, int4, int8_t>();
success &= RegisterOneWayCustomCast<uint2, uint4, uint8_t>();
return success;
Expand Down Expand Up @@ -372,6 +396,11 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() {
return nullptr;
}

if (PyObject_SetAttrString(m.get(), "float8_e3m4",
reinterpret_cast<PyObject*>(
TypeDescriptor<float8_e3m4>::type_ptr)) < 0) {
return nullptr;
}
if (PyObject_SetAttrString(m.get(), "float8_e4m3",
reinterpret_cast<PyObject*>(
TypeDescriptor<float8_e4m3>::type_ptr)) < 0) {
Expand Down
Loading

0 comments on commit 82f3a61

Please sign in to comment.