Skip to content

Commit

Permalink
Add sub-byte data types: float4_e2m1fn, float6_e2m3fn, float6_e3m2fn
Browse files Browse the repository at this point in the history
  • Loading branch information
sergey-kozub committed Sep 10, 2024
1 parent 82f3a61 commit 9bdf962
Show file tree
Hide file tree
Showing 11 changed files with 1,137 additions and 144 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

* Added new 8-bit float types following IEEE 754 convention:
`ml_dtypes.float8_e4m3` and `ml_dtypes.float8_e3m4`.
* Added new 4-bit and 6-bit float types:
`ml_dtypes.float4_e2m1fn`, `ml_dtypes.float6_e2m3fn` and `ml_dtypes.float6_e3m2fn`.
* Fix outputs of float `divmod` and `floor_divide` when denominator is zero.

## [0.4.0] - 2024-04-1
Expand Down
37 changes: 37 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
* `float8_e4m3fnuz`
* `float8_e5m2`
* `float8_e5m2fnuz`
- Microscaling (MX) sub-byte floating point representations including:
* `float4_e2m1fn`
* `float6_e2m3fn`
* `float6_e3m2fn`
- `int2`, `int4`, `uint2` and `uint4`: low precision integer types.

See below for specifications of these number formats.
Expand Down Expand Up @@ -66,6 +70,39 @@ A `bfloat16` number is a single-precision float truncated at 16 bits.

Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.

### `float4_e2m1`

Exponent: 2, Mantissa: 1, bias: 1.

Extended range: no inf, no NaN.

Microscaling format, 4 bits (encoding: `0bSEEM`) using byte storage (higher 4
bits are unused). NaN representation is undefined.

Possible values: [0, 0.5, 1, 1.5, 2, 3, 4, 6]

### `float6_e2m3`

Exponent: 2, Mantissa: 3, bias: 1.

Extended range: no inf, no NaN.

Microscaling format, 6 bits (encoding: `0bSEEMMM`) using byte storage (higher 2
bits are unused). NaN representation is undefined.

Possible values range: [-7.5; 7.5]

### `float6_e3m2`

Exponent: 3, Mantissa: 2, bias: 3.

Extended range: no inf, no NaN.

Microscaling format, 4 bits (encoding: `0bSEEEMM`) using byte storage (higher 2
bits are unused). NaN representation is undefined.

Possible values range: [-28; 28]

### `float8_e3m4`

Exponent: 3, Mantissa: 4, bias: 3. IEEE 754, with NaN and inf.
Expand Down
9 changes: 9 additions & 0 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
"__version__",
"bfloat16",
"finfo",
"float4_e2m1fn",
"float6_e2m3fn",
"float6_e3m2fn",
"float8_e3m4",
"float8_e4m3",
"float8_e4m3b11fnuz",
Expand All @@ -36,6 +39,9 @@
from ml_dtypes._finfo import finfo
from ml_dtypes._iinfo import iinfo
from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float4_e2m1fn
from ml_dtypes._ml_dtypes_ext import float6_e2m3fn
from ml_dtypes._ml_dtypes_ext import float6_e3m2fn
from ml_dtypes._ml_dtypes_ext import float8_e3m4
from ml_dtypes._ml_dtypes_ext import float8_e4m3
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
Expand All @@ -50,6 +56,9 @@
import numpy as np

bfloat16: Type[np.generic]
float4_e2m1fn: Type[np.generic]
float6_e2m3fn: Type[np.generic]
float6_e3m2fn: Type[np.generic]
float8_e3m4: Type[np.generic]
float8_e4m3: Type[np.generic]
float8_e4m3b11fnuz: Type[np.generic]
Expand Down
245 changes: 178 additions & 67 deletions ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from typing import Dict

from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float4_e2m1fn
from ml_dtypes._ml_dtypes_ext import float6_e2m3fn
from ml_dtypes._ml_dtypes_ext import float6_e3m2fn
from ml_dtypes._ml_dtypes_ext import float8_e3m4
from ml_dtypes._ml_dtypes_ext import float8_e4m3
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
Expand All @@ -27,6 +30,9 @@
import numpy as np

_bfloat16_dtype = np.dtype(bfloat16)
_float4_e2m1fn_dtype = np.dtype(float4_e2m1fn)
_float6_e2m3fn_dtype = np.dtype(float6_e2m3fn)
_float6_e3m2fn_dtype = np.dtype(float6_e3m2fn)
_float8_e3m4_dtype = np.dtype(float8_e3m4)
_float8_e4m3_dtype = np.dtype(float8_e4m3)
_float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz)
Expand All @@ -45,6 +51,33 @@ def __init__(self):
self.smallest_subnormal = bfloat16(smallest_subnormal)


class _Float4E2m1fnMachArLike:

def __init__(self):
smallest_normal = float.fromhex("0x1p0")
self.smallest_normal = float4_e2m1fn(smallest_normal)
smallest_subnormal = float.fromhex("0x0.8p0")
self.smallest_subnormal = float4_e2m1fn(smallest_subnormal)


class _Float6E2m3fnMachArLike:

def __init__(self):
smallest_normal = float.fromhex("0x1p0")
self.smallest_normal = float6_e2m3fn(smallest_normal)
smallest_subnormal = float.fromhex("0x0.2p0")
self.smallest_subnormal = float6_e2m3fn(smallest_subnormal)


class _Float6E3m2fnMachArLike:

def __init__(self):
smallest_normal = float.fromhex("0x1p-2")
self.smallest_normal = float6_e3m2fn(smallest_normal)
smallest_subnormal = float.fromhex("0x0.4p-2")
self.smallest_subnormal = float6_e3m2fn(smallest_subnormal)


class _Float8E3m4MachArLike:

def __init__(self):
Expand Down Expand Up @@ -110,7 +143,7 @@ def __init__(self):

class finfo(np.finfo): # pylint: disable=invalid-name,missing-class-docstring
__doc__ = np.finfo.__doc__
_finfo_cache: Dict[np.dtype, np.finfo] = {}
_finfo_cache: Dict[type, np.finfo] = {}

@staticmethod
def _bfloat16_finfo():
Expand Down Expand Up @@ -157,6 +190,120 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj

@staticmethod
def _float4_e2m1fn_finfo():
obj = object.__new__(np.finfo)
obj.dtype = _float4_e2m1fn_dtype
obj.bits = 4
obj.eps = 0.5
obj.epsneg = 0.5
obj.machep = -1
obj.negep = -1
obj.max = float4_e2m1fn(6.0)
obj.min = float4_e2m1fn(-6.0)
obj.nexp = 2
obj.nmant = 1
obj.iexp = obj.nexp
obj.maxexp = 3
obj.minexp = 0
obj.precision = 0
obj.resolution = float4_e2m1fn(1.0)
# pylint: disable=protected-access
obj._machar = _Float4E2m1fnMachArLike()
tiny = obj._machar.smallest_normal
if not hasattr(obj, "tiny"):
obj.tiny = tiny
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = tiny
obj.smallest_subnormal = obj._machar.smallest_subnormal

float_to_str = str
obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(obj.max)
obj._str_epsneg = float_to_str(obj.epsneg)
obj._str_eps = float_to_str(obj.eps)
obj._str_resolution = float_to_str(obj.resolution)
# pylint: enable=protected-access
return obj

@staticmethod
def _float6_e2m3fn_finfo():
obj = object.__new__(np.finfo)
obj.dtype = _float6_e2m3fn_dtype
obj.bits = 6
obj.eps = 0.125
obj.epsneg = 0.125
obj.machep = -3
obj.negep = -3
obj.max = float6_e2m3fn(7.5)
obj.min = float6_e2m3fn(-7.5)
obj.nexp = 2
obj.nmant = 3
obj.iexp = obj.nexp
obj.maxexp = 3
obj.minexp = 0
obj.precision = 0
obj.resolution = float4_e2m1fn(1.0)
# pylint: disable=protected-access
obj._machar = _Float6E2m3fnMachArLike()
tiny = obj._machar.smallest_normal
if not hasattr(obj, "tiny"):
obj.tiny = tiny
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = tiny
obj.smallest_subnormal = obj._machar.smallest_subnormal

float_to_str = str
obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(obj.max)
obj._str_epsneg = float_to_str(obj.epsneg)
obj._str_eps = float_to_str(obj.eps)
obj._str_resolution = float_to_str(obj.resolution)
# pylint: enable=protected-access
return obj

@staticmethod
def _float6_e3m2fn_finfo():
obj = object.__new__(np.finfo)
obj.dtype = _float6_e3m2fn_dtype
obj.bits = 6
obj.eps = 0.25
obj.epsneg = 0.125
obj.machep = -2
obj.negep = -3
obj.max = float6_e3m2fn(28.0)
obj.min = float6_e3m2fn(-28.0)
obj.nexp = 3
obj.nmant = 2
obj.iexp = obj.nexp
obj.maxexp = 5
obj.minexp = -2
obj.precision = 0
obj.resolution = float6_e3m2fn(1.0)
# pylint: disable=protected-access
obj._machar = _Float6E3m2fnMachArLike()
tiny = obj._machar.smallest_normal
if not hasattr(obj, "tiny"):
obj.tiny = tiny
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = tiny
obj.smallest_subnormal = obj._machar.smallest_subnormal

float_to_str = str
obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(obj.max)
obj._str_epsneg = float_to_str(obj.epsneg)
obj._str_eps = float_to_str(obj.eps)
obj._str_resolution = float_to_str(obj.resolution)
# pylint: enable=protected-access
return obj

@staticmethod
def _float8_e3m4_finfo():
def float_to_str(f):
Expand Down Expand Up @@ -472,71 +619,35 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj

_finfo_type_map = {
bfloat16: _bfloat16_finfo,
float4_e2m1fn: _float4_e2m1fn_finfo,
float6_e2m3fn: _float6_e2m3fn_finfo,
float6_e3m2fn: _float6_e3m2fn_finfo,
float8_e3m4: _float8_e3m4_finfo,
float8_e4m3: _float8_e4m3_finfo,
float8_e4m3fn: _float8_e4m3fn_finfo,
float8_e4m3fnuz: _float8_e4m3fnuz_finfo,
float8_e4m3b11fnuz: _float8_e4m3b11fnuz_finfo,
float8_e5m2: _float8_e5m2_finfo,
float8_e5m2fnuz: _float8_e5m2fnuz_finfo,
}
_finfo_name_map = {t.__name__: t for t in _finfo_type_map}

def __new__(cls, dtype):
if (
isinstance(dtype, str)
and dtype == "bfloat16"
or dtype == _bfloat16_dtype
):
if _bfloat16_dtype not in cls._finfo_cache:
cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
return cls._finfo_cache[_bfloat16_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e3m4"
or dtype == _float8_e3m4_dtype
):
if _float8_e3m4_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e3m4_dtype] = cls._float8_e3m4_finfo()
return cls._finfo_cache[_float8_e3m4_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3"
or dtype == _float8_e4m3_dtype
):
if _float8_e4m3_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e4m3_dtype] = cls._float8_e4m3_finfo()
return cls._finfo_cache[_float8_e4m3_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3b11fnuz"
or dtype == _float8_e4m3b11fnuz_dtype
):
if _float8_e4m3b11fnuz_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e4m3b11fnuz_dtype] = (
cls._float8_e4m3b11fnuz_finfo()
)
return cls._finfo_cache[_float8_e4m3b11fnuz_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3fn"
or dtype == _float8_e4m3fn_dtype
):
if _float8_e4m3fn_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e4m3fn_dtype] = cls._float8_e4m3fn_finfo()
return cls._finfo_cache[_float8_e4m3fn_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3fnuz"
or dtype == _float8_e4m3fnuz_dtype
):
if _float8_e4m3fnuz_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e4m3fnuz_dtype] = cls._float8_e4m3fnuz_finfo()
return cls._finfo_cache[_float8_e4m3fnuz_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e5m2"
or dtype == _float8_e5m2_dtype
):
if _float8_e5m2_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e5m2_dtype] = cls._float8_e5m2_finfo()
return cls._finfo_cache[_float8_e5m2_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e5m2fnuz"
or dtype == _float8_e5m2fnuz_dtype
):
if _float8_e5m2fnuz_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e5m2fnuz_dtype] = cls._float8_e5m2fnuz_finfo()
return cls._finfo_cache[_float8_e5m2fnuz_dtype]
key = (
cls._finfo_name_map.get(dtype)
if isinstance(dtype, str)
else dtype.type
if isinstance(dtype, np.dtype)
else dtype
)
finfo = cls._finfo_cache.get(key)
if finfo is not None:
return finfo

init = cls._finfo_type_map.get(key)
if init is not None:
cls._finfo_cache[dtype] = init.__func__()
return cls._finfo_cache[dtype]
return super().__new__(cls, dtype)
Loading

0 comments on commit 9bdf962

Please sign in to comment.