
Add f8e5m2 support #7924

Merged: 1 commit, Aug 29, 2024
30 changes: 24 additions & 6 deletions test/test_fp8.py
@@ -1,17 +1,24 @@
 import sys
 import os
+import re
 
 import torch
 import torch_xla
 import unittest
+from absl.testing import parameterized
 import torch_xla.core.xla_model as xm
 
 device = xm.xla_device()
 
+dtype_parameters = [
+    torch.float8_e5m2,
+    torch.float8_e4m3fn,
+]
+
-class Fp8Test(unittest.TestCase):
+class Fp8Test(parameterized.TestCase):
 
-  def test_fp8(self):
-    dtype = torch.float8_e4m3fn
+  @parameterized.parameters(*dtype_parameters)
+  def test_fp8(self, dtype):
     t = torch.rand(2, 2).to(dtype)
     xla_t = t.to(device)
     torch_t = xla_t.cpu()
@@ -21,8 +28,8 @@ def test_fp8(self):
     self.assertTrue(
         torch.allclose(t.to(torch.float32), torch_t.to(torch.float32)))
 
-  def test_fp8_matmul(self):
-    dtype = torch.float8_e4m3fn
+  @parameterized.parameters(*dtype_parameters)
+  def test_fp8_matmul(self, dtype):
     t = torch.rand(3, 2).to(dtype)
     w = torch.rand(2, 5).to(dtype)
     torch_matmul = torch.matmul(t, w)
@@ -35,6 +42,17 @@ def test_fp8_matmul(self):
         torch.allclose(
             xla_matmul.to(torch.float32), torch_matmul.to(torch.float32)))
 
+  @parameterized.parameters(*dtype_parameters)
+  def test_fp8_hlo(self, dtype):
+    x = torch.randn((3, 5)).to(dtype).to(device)
+    w = torch.randn((5, 8)).to(dtype).to(device)
+    output = torch.matmul(x, w)
+    hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
+    exmy_str = str(dtype).split('_')[-1]
+    self.assertTrue(
+        re.search(rf'f8{exmy_str}.*dot.*f8{exmy_str}.*f8{exmy_str}', hlo)
+        is not None)
+
 
 if __name__ == '__main__':
   test = unittest.main()
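The parameterized tests above reduce to the following user-level flow. This is a minimal sketch of the same round trip and HLO check, assuming a torch_xla build that includes this change and an available XLA device; it is not part of the PR itself.

import re

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Host fp8 tensor moved to the XLA device and back; values should survive the
# round trip (compare after upcasting to float32, as the tests do).
t = torch.rand(2, 2).to(torch.float8_e5m2)
xla_t = t.to(device)
assert torch.allclose(t.to(torch.float32), xla_t.cpu().to(torch.float32))

# A matmul on fp8 inputs should lower to an HLO dot on f8e5m2 operands.
x = torch.randn((3, 5)).to(torch.float8_e5m2).to(device)
w = torch.randn((5, 8)).to(torch.float8_e5m2).to(device)
out = torch.matmul(x, w)
hlo = torch_xla._XLAC._get_xla_tensors_hlo([out])
assert re.search(r'f8e5m2.*dot.*f8e5m2.*f8e5m2', hlo) is not None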
57 changes: 57 additions & 0 deletions torch_xla/csrc/tensor_util.cpp
@@ -145,6 +145,8 @@ xla::PrimitiveType XlaTypeFromTensorType(
       return xla::PrimitiveType::C128;
     case at::ScalarType::Float8_e4m3fn:
       return xla::PrimitiveType::F8E4M3FN;
+    case at::ScalarType::Float8_e5m2:
+      return xla::PrimitiveType::F8E5M2;
     default:
       XLA_ERROR() << "Type not supported: " << scalar_type;
   }
@@ -172,6 +174,21 @@ struct Caster<tsl::bfloat16> {
   }
 };
 template <>
+struct Caster<at::Float8_e5m2> {
+  template <typename D>
+  D cast(const at::Float8_e5m2& value) const {
+    return static_cast<D>(static_cast<float>(value));
+  }
+};
+
+template <>
+struct Caster<tsl::float8_e5m2> {
+  template <typename D>
+  D cast(const tsl::float8_e5m2& value) const {
+    return static_cast<D>(static_cast<float>(value));
+  }
+};
+template <>
 struct Caster<at::Float8_e4m3fn> {
   template <typename D>
   D cast(const at::Float8_e4m3fn& value) const {
@@ -309,6 +326,14 @@ struct NeedCast<at::BFloat16> {
   static constexpr bool value = true;
 };
 template <>
+struct NeedCast<tsl::float8_e5m2> {
+  static constexpr bool value = true;
+};
+template <>
+struct NeedCast<at::Float8_e5m2> {
+  static constexpr bool value = true;
+};
+template <>
 struct NeedCast<tsl::float8_e4m3fn> {
   static constexpr bool value = true;
 };
@@ -381,6 +406,18 @@ void CopyData<tsl::bfloat16, at::BFloat16>(tsl::bfloat16* dest,
   CheckedMemcpy<tsl::bfloat16, at::BFloat16>(dest, source, n);
 }
 template <>
+void CopyData<at::Float8_e5m2, tsl::float8_e5m2>(at::Float8_e5m2* dest,
+                                                 const tsl::float8_e5m2* source,
+                                                 int64_t n, const CopyCasted&) {
+  CheckedMemcpy<at::Float8_e5m2, tsl::float8_e5m2>(dest, source, n);
+}
+template <>
+void CopyData<tsl::float8_e5m2, at::Float8_e5m2>(tsl::float8_e5m2* dest,
+                                                 const at::Float8_e5m2* source,
+                                                 int64_t n, const CopyCasted&) {
+  CheckedMemcpy<tsl::float8_e5m2, at::Float8_e5m2>(dest, source, n);
+}
+template <>
 void CopyData<at::Float8_e4m3fn, tsl::float8_e4m3fn>(
     at::Float8_e4m3fn* dest, const tsl::float8_e4m3fn* source, int64_t n,
     const CopyCasted&) {
@@ -605,6 +642,10 @@ void TensorToBufferSType(const at::Tensor& tensor, const xla::Shape& dest_shape,
       TensorToBuffer<SType, tsl::float8_e4m3fn>(tensor, dest_shape, dest_buffer,
                                                 dest_buffer_size, device);
       break;
+    case xla::PrimitiveType::F8E5M2:
+      TensorToBuffer<SType, tsl::float8_e5m2>(tensor, dest_shape, dest_buffer,
+                                              dest_buffer_size, device);
+      break;
     default:
       XLA_ERROR() << "Destination shape type not supported: " << dest_shape;
   }
@@ -756,6 +797,9 @@ at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,
     case at::ScalarType::Float8_e4m3fn:
       return XlaLiteralToTensor<SType, at::Float8_e4m3fn>(literal,
                                                           dest_element_type);
+    case at::ScalarType::Float8_e5m2:
+      return XlaLiteralToTensor<SType, at::Float8_e5m2>(literal,
+                                                        dest_element_type);
 
     default:
       XLA_ERROR() << "Unsupported scalar type: " << dest_element_type;
@@ -821,6 +865,10 @@ void PopulateTensorBuffer(const at::Tensor& tensor,
       TensorToBufferSType<at::Float8_e4m3fn>(tensor, dest_shape, dest_buffer,
                                              dest_buffer_size, device);
      break;
+    case at::ScalarType::Float8_e5m2:
+      TensorToBufferSType<at::Float8_e5m2>(tensor, dest_shape, dest_buffer,
+                                           dest_buffer_size, device);
+      break;
     default:
       XLA_ERROR() << "Tensor type not supported: " << tensor.type();
   }
@@ -875,6 +923,9 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
     case xla::PrimitiveType::F8E4M3FN:
       return XlaLiteralToTensorHelper<tsl::float8_e4m3fn>(literal,
                                                           dest_element_type);
+    case xla::PrimitiveType::F8E5M2:
+      return XlaLiteralToTensorHelper<tsl::float8_e5m2>(literal,
+                                                        dest_element_type);
     default:
       XLA_ERROR() << "Unsupported literal type: " << literal.shape();
   }
@@ -1159,6 +1210,8 @@ at::ScalarType TensorTypeFromXlaType(xla::PrimitiveType xla_type) {
       return at::ScalarType::ComplexDouble;
     case xla::PrimitiveType::F8E4M3FN:
       return at::ScalarType::Float8_e4m3fn;
+    case xla::PrimitiveType::F8E5M2:
+      return at::ScalarType::Float8_e5m2;
     default:
       XLA_ERROR() << "XLA type not supported: " << xla_type;
   }
@@ -1192,6 +1245,8 @@ xla::PrimitiveType TensorTypeToRawXlaType(at::ScalarType scalar_type) {
       return xla::PrimitiveType::C128;
     case at::ScalarType::Float8_e4m3fn:
       return xla::PrimitiveType::F8E4M3FN;
+    case at::ScalarType::Float8_e5m2:
+      return xla::PrimitiveType::F8E5M2;
     default:
       XLA_ERROR() << "Type not supported: " << scalar_type;
   }
@@ -1270,6 +1325,8 @@ xla::PrimitiveType MakeXlaPrimitiveType(
       return GetDevicePrimitiveType(xla::PrimitiveType::C128, device);
     case at::ScalarType::Float8_e4m3fn:
       return GetDevicePrimitiveType(xla::PrimitiveType::F8E4M3FN, device);
+    case at::ScalarType::Float8_e5m2:
+      return GetDevicePrimitiveType(xla::PrimitiveType::F8E5M2, device);
     default:
       XLA_ERROR() << "Type not supported: " << scalar_type;
   }
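Taken together, the tensor_util.cpp additions plug Float8_e5m2 into both directions of the host/device conversion path. The sketch below notes, roughly, which user-visible call exercises which piece; the mapping is my reading of the usual torch_xla transfer flow, not something stated in the PR, and the names in the comments refer to the functions patched above.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t = torch.rand(4, 4).to(torch.float8_e5m2)

# Host -> device: PopulateTensorBuffer / TensorToBufferSType select the
# tsl::float8_e5m2 buffer layout for an at::Float8_e5m2 source tensor.
xla_t = t.to(device)

# Device -> host: MakeTensorFromXlaLiteral / XlaLiteralToTensorHelper map an
# F8E5M2 literal back to an at::Float8_e5m2 tensor; the CopyData and Caster
# specializations handle the element-wise copy (a plain memcpy when source and
# destination are both e5m2).
host_t = xla_t.cpu()
assert host_t.dtype == torch.float8_e5m2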