From 40bcd6f1c6dc11e7f1c6ba66c074ae97876cc39e Mon Sep 17 00:00:00 2001
From: Satyajith Chilappagari
Date: Wed, 28 Aug 2024 23:09:37 -0700
Subject: [PATCH] Add f8e5m2 support

---
 test/test_fp8.py               | 30 ++++++++++++++----
 torch_xla/csrc/tensor_util.cpp | 57 ++++++++++++++++++++++++++++++++++
 2 files changed, 81 insertions(+), 6 deletions(-)

diff --git a/test/test_fp8.py b/test/test_fp8.py
index af8086ff1086..dbd5c37bfb69 100644
--- a/test/test_fp8.py
+++ b/test/test_fp8.py
@@ -1,17 +1,24 @@
-import sys
+import os
+import re
 
 import torch
 import torch_xla
 import unittest
+from absl.testing import parameterized
 import torch_xla.core.xla_model as xm
 
 device = xm.xla_device()
+dtype_parameters = [
+    torch.float8_e5m2,
+    torch.float8_e4m3fn,
+]
 
 
 
-class Fp8Test(unittest.TestCase):
-  def test_fp8(self):
-    dtype = torch.float8_e4m3fn
+class Fp8Test(parameterized.TestCase):
+
+  @parameterized.parameters(*dtype_parameters)
+  def test_fp8(self, dtype):
     t = torch.rand(2, 2).to(dtype)
     xla_t = t.to(device)
     torch_t = xla_t.cpu()
@@ -21,8 +28,8 @@ def test_fp8(self):
     self.assertTrue(
         torch.allclose(t.to(torch.float32), torch_t.to(torch.float32)))
 
-  def test_fp8_matmul(self):
-    dtype = torch.float8_e4m3fn
+  @parameterized.parameters(*dtype_parameters)
+  def test_fp8_matmul(self, dtype):
     t = torch.rand(3, 2).to(dtype)
     w = torch.rand(2, 5).to(dtype)
     torch_matmul = torch.matmul(t, w)
@@ -35,6 +42,17 @@ def test_fp8_matmul(self):
         torch.allclose(
             xla_matmul.to(torch.float32), torch_matmul.to(torch.float32)))
 
+  @parameterized.parameters(*dtype_parameters)
+  def test_fp8_hlo(self, dtype):
+    x = torch.randn((3, 5)).to(dtype).to(device)
+    w = torch.randn((5, 8)).to(dtype).to(device)
+    output = torch.matmul(x, w)
+    hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
+    exmy_str = str(dtype).split('_')[-1]
+    self.assertTrue(
+        re.search(rf'f8{exmy_str}.*dot.*f8{exmy_str}.*f8{exmy_str}', hlo)
+        is not None)
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index 6fe883cc36b4..3731c065b759 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -145,6 +145,8 @@ xla::PrimitiveType XlaTypeFromTensorType(
       return xla::PrimitiveType::C128;
     case at::ScalarType::Float8_e4m3fn:
       return xla::PrimitiveType::F8E4M3FN;
+    case at::ScalarType::Float8_e5m2:
+      return xla::PrimitiveType::F8E5M2;
     default:
       XLA_ERROR() << "Type not supported: " << scalar_type;
   }
@@ -172,6 +174,21 @@ struct Caster {
   }
 };
 template <>
+struct Caster<at::Float8_e5m2> {
+  template <typename D>
+  D cast(const at::Float8_e5m2& value) const {
+    return static_cast<D>(static_cast<float>(value));
+  }
+};
+
+template <>
+struct Caster<tsl::float8_e5m2> {
+  template <typename D>
+  D cast(const tsl::float8_e5m2& value) const {
+    return static_cast<D>(static_cast<float>(value));
+  }
+};
+template <>
 struct Caster<at::Float8_e4m3fn> {
   template <typename D>
   D cast(const at::Float8_e4m3fn& value) const {
@@ -309,6 +326,14 @@ struct NeedCast {
   static constexpr bool value = true;
 };
 template <>
+struct NeedCast<tsl::float8_e5m2> {
+  static constexpr bool value = true;
+};
+template <>
+struct NeedCast<at::Float8_e5m2> {
+  static constexpr bool value = true;
+};
+template <>
 struct NeedCast<tsl::float8_e4m3fn> {
   static constexpr bool value = true;
 };
@@ -381,6 +406,18 @@ void CopyData<tsl::bfloat16, at::BFloat16>(tsl::bfloat16* dest,
   CheckedMemcpy<tsl::bfloat16, at::BFloat16>(dest, source, n);
 }
 template <>
+void CopyData<at::Float8_e5m2, tsl::float8_e5m2>(at::Float8_e5m2* dest,
+                                                 const tsl::float8_e5m2* source,
+                                                 int64_t n, const CopyCasted&) {
+  CheckedMemcpy<at::Float8_e5m2, tsl::float8_e5m2>(dest, source, n);
+}
+template <>
+void CopyData<tsl::float8_e5m2, at::Float8_e5m2>(tsl::float8_e5m2* dest,
+                                                 const at::Float8_e5m2* source,
+                                                 int64_t n, const CopyCasted&) {
+  CheckedMemcpy<tsl::float8_e5m2, at::Float8_e5m2>(dest, source, n);
+}
+template <>
 void CopyData<at::Float8_e4m3fn, tsl::float8_e4m3fn>(
     at::Float8_e4m3fn* dest, const tsl::float8_e4m3fn* source, int64_t n,
     const CopyCasted&) {
@@ -605,6 +642,10 @@ void TensorToBufferSType(const at::Tensor& tensor, const xla::Shape& dest_shape,
       TensorToBuffer<SType, tsl::float8_e4m3fn>(tensor, dest_shape, dest_buffer,
                                                 dest_buffer_size, device);
       break;
+    case xla::PrimitiveType::F8E5M2:
+      TensorToBuffer<SType, tsl::float8_e5m2>(tensor, dest_shape, dest_buffer,
+                                              dest_buffer_size, device);
+      break;
     default:
       XLA_ERROR() << "Destination shape type not supported: " << dest_shape;
   }
@@ -756,6 +797,9 @@ at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,
     case at::ScalarType::Float8_e4m3fn:
       return XlaLiteralToTensor<SType, at::Float8_e4m3fn>(literal,
                                                           dest_element_type);
+    case at::ScalarType::Float8_e5m2:
+      return XlaLiteralToTensor<SType, at::Float8_e5m2>(literal,
+                                                        dest_element_type);
     default:
       XLA_ERROR() << "Unsupported scalar type: "
                   << dest_element_type;
@@ -821,6 +865,10 @@ void PopulateTensorBuffer(const at::Tensor& tensor,
       TensorToBufferSType<at::Float8_e4m3fn>(tensor, dest_shape, dest_buffer,
                                              dest_buffer_size, device);
       break;
+    case at::ScalarType::Float8_e5m2:
+      TensorToBufferSType<at::Float8_e5m2>(tensor, dest_shape, dest_buffer,
+                                           dest_buffer_size, device);
+      break;
     default:
       XLA_ERROR() << "Tensor type not supported: " << tensor.type();
   }
@@ -875,6 +923,9 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
     case xla::PrimitiveType::F8E4M3FN:
      return XlaLiteralToTensorHelper<tsl::float8_e4m3fn>(literal,
                                                          dest_element_type);
+    case xla::PrimitiveType::F8E5M2:
+      return XlaLiteralToTensorHelper<tsl::float8_e5m2>(literal,
+                                                        dest_element_type);
     default:
       XLA_ERROR() << "Unsupported literal type: " << literal.shape();
   }
@@ -1159,6 +1210,8 @@ at::ScalarType TensorTypeFromXlaType(xla::PrimitiveType xla_type) {
       return at::ScalarType::ComplexDouble;
     case xla::PrimitiveType::F8E4M3FN:
       return at::ScalarType::Float8_e4m3fn;
+    case xla::PrimitiveType::F8E5M2:
+      return at::ScalarType::Float8_e5m2;
     default:
       XLA_ERROR() << "XLA type not supported: " << xla_type;
   }
@@ -1192,6 +1245,8 @@ xla::PrimitiveType TensorTypeToRawXlaType(at::ScalarType scalar_type) {
       return xla::PrimitiveType::C128;
     case at::ScalarType::Float8_e4m3fn:
       return xla::PrimitiveType::F8E4M3FN;
+    case at::ScalarType::Float8_e5m2:
+      return xla::PrimitiveType::F8E5M2;
     default:
       XLA_ERROR() << "Type not supported: " << scalar_type;
   }
@@ -1270,6 +1325,8 @@ xla::PrimitiveType MakeXlaPrimitiveType(
       return GetDevicePrimitiveType(xla::PrimitiveType::C128, device);
     case at::ScalarType::Float8_e4m3fn:
       return GetDevicePrimitiveType(xla::PrimitiveType::F8E4M3FN, device);
+    case at::ScalarType::Float8_e5m2:
+      return GetDevicePrimitiveType(xla::PrimitiveType::F8E5M2, device);
     default:
       XLA_ERROR() << "Type not supported: " << scalar_type;
   }
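
As a quick end-to-end check, the new f8e5m2 path can be exercised the same way
the tests above do (a minimal sketch, assuming a torch_xla build that includes
this patch and a reachable XLA device):

    import re

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    # Round-trip host -> XLA device -> host; the new Caster/CopyData/NeedCast
    # specializations in tensor_util.cpp handle the fp8 buffer copies.
    t = torch.rand(2, 2).to(torch.float8_e5m2)
    xla_t = t.to(device)
    assert xla_t.dtype == torch.float8_e5m2
    assert torch.allclose(t.to(torch.float32), xla_t.cpu().to(torch.float32))

    # The lowered HLO should carry the F8E5M2 primitive type on the dot
    # operands and output, matching what test_fp8_hlo asserts.
    w = torch.rand(2, 2).to(torch.float8_e5m2).to(device)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([torch.matmul(xla_t, w)])
    assert re.search(r'f8e5m2.*dot.*f8e5m2.*f8e5m2', hlo) is not None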