diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 423490dcd0..90c5e499f3 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -22,3 +22,5 @@ pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py +pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py +pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py new file mode 100644 index 0000000000..9aab3b2702 --- /dev/null +++ b/tests/pytorch/test_fusible_ops.py @@ -0,0 +1,953 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +import math + +import pytest +import torch + +import transformer_engine +import transformer_engine.pytorch as te +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch.ops._common import is_float8_tensor +from transformer_engine.pytorch.ops.fused_forward import ( + ForwardLinearBiasActivation, +) +from transformer_engine.pytorch.utils import is_bf16_compatible +import transformer_engine_torch as tex + +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +# Supported data types +_dtypes: list[torch.dtype] = [torch.float32, torch.float16] +if is_bf16_compatible(): # bf16 requires sm_80 or higher + _dtypes.append(torch.bfloat16) + +# Supported devices +_devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")] + + +def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: + """Estimated numerical error for a datatype + + Based on tolerances for torch.testing.assert_close. + + """ + + # Transformer Engine dtypes + if isinstance(dtype, tex.DType): + if dtype == tex.DType.kFloat8E4M3: + return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 + if dtype == tex.DType.kFloat8E5M2: + return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 + dtype = { + tex.DType.kByte: torch.uint8, + tex.DType.kInt32: torch.int32, + tex.DType.kFloat32: torch.float32, + tex.DType.kFloat16: torch.half, + tex.DType.kBFloat16: torch.bfloat16, + }[dtype] + + # PyTorch dtypes + if dtype == torch.float16: + return dict(rtol=1e-3, atol=1e-5) + if dtype == torch.bfloat16: + return dict(rtol=1.6e-2, atol=1e-5) + if dtype == torch.float32: + return dict(rtol=1.3e-6, atol=1e-5) + if dtype == torch.float64: + return dict(rtol=1e-7, atol=1e-7) + raise ValueError(f"Unsupported dtype ({dtype})") + + +@torch.no_grad() +def make_reference_and_test_tensors( + shape: int | Iterable[int], + ref_dtype: torch.dtype = torch.float64, + ref_device: torch.device = "cpu", + test_dtype: torch.dtype = torch.float32, + test_device: torch.device = "cuda", + test_is_fp8: bool = False, + requires_grad: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """Construct tensors with the same values + + The reference tensor is intended for use in plain PyTorch + operations in high precision. The test tensor is intended for use + in Transformer Engine operations. 
+ + """ + ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + if test_is_fp8: + test = Float8Tensor.to_float8(ref) + test._transpose = test._data.reshape(-1, test.size(-1)).transpose(0, 1) + test._transpose = test._transpose.contiguous() + test._transpose_invalid = False + else: + test = ref.to(device=test_device, dtype=test_dtype) + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + ref.copy_(test) + ref.requires_grad_(requires_grad) + test.requires_grad_(requires_grad) + return ref, test + + +class TestSequential: + """Tests for sequential container""" + + def test_modules(self) -> None: + """Check that list of modules can be manipulated as expected""" + + # Construct sequential container + modules = [ + te_ops.Identity(), + te_ops.Identity(), + torch.nn.Identity(), + te_ops.Identity(), + ] + model = te_ops.Sequential(*modules) + + # Length + assert len(model) == len(modules) + + # Iterator + for module1, module2 in zip(model, modules): + assert module1 is module2 + + # Index by int + for i, module in enumerate(modules): + assert model[i] is module + assert model[i - len(modules)] is module + + # Index by slice + model_subset = model[1:-1] + modules_subset = modules[1:-1] + assert isinstance(model_subset, te_ops.Sequential) + for module1, module2 in zip(model_subset, modules_subset): + assert module1 is module2 + + # Set element + new_module = torch.nn.Identity() + idx = 1 + modules[idx] = new_module + model[idx] = new_module + for module1, module2 in zip(model, modules): + assert module1 is module2 + + # Delete element + idx = 1 + del modules[idx] + del model[idx] + for module1, module2 in zip(model, modules): + assert module1 is module2 + + # Append + new_module = torch.nn.Identity() + modules.append(new_module) + model.append(new_module) + for module1, module2 in zip(model, modules): + assert module1 is module2 + + # Extend + new_modules = [te_ops.Identity(), te_ops.Identity()] + modules.extend(new_modules) + model.extend(new_modules) + for module1, module2 in zip(model, modules): + assert module1 is module2 + + # Insert + new_module = te_ops.Identity() + idx = 2 + modules.insert(idx, new_module) + model.insert(idx, new_module) + for module1, module2 in zip(model, modules): + assert module1 is module2 + + # Pop + idx = 2 + assert model.pop(idx) is modules.pop(idx) + for module1, module2 in zip(model, modules): + assert module1 is module2 + + # Out-of-place add + new_modules = [torch.nn.Identity(), te_ops.Identity()] + added_modules = modules + new_modules + added_model = model + te_ops.Sequential(*new_modules) + for module1, module2 in zip(model, modules): + assert module1 is module2 + for module1, module2 in zip(added_model, added_modules): + assert module1 is module2 + + # In-place add + new_modules = [te_ops.Identity(), torch.nn.Identity()] + modules += new_modules + model += te_ops.Sequential(*new_modules) + for module1, module2 in zip(model, modules): + assert module1 is module2 + + def test_module_groups(self) -> None: + """Check that modules are grouped together correctly""" + model = te_ops.Sequential( + te_ops.Identity(), + te_ops.Identity(), + torch.nn.Identity(), + torch.nn.Identity(), + te_ops.Identity(), + torch.nn.Identity(), + te_ops.Identity(), + te_ops.Identity(), + te_ops.Identity(), + ) + model(torch.zeros(1)) + assert len(model._module_groups) == 6 + + +class TestFuser: + """Tests for operation fusion infrastructure""" + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + 
torch.cuda.manual_seed(seed) + + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + def test_fp8_scale_update( + self, + size: int = 16, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + ): + """Test FP8 scaling factors with delayed scaling recipe""" + + # FP8 recipe + margin = 2 + fp8_format = transformer_engine.common.recipe.Format.HYBRID + recipe = transformer_engine.common.recipe.DelayedScaling( + margin=margin, + interval=1, + fp8_format=fp8_format, + amax_history_len=8, + amax_compute_algo="max", + ) + + # Construct model + with te.fp8_model_init(): + model = te_ops.basic.BasicLinear( + size, + size, + device=device, + dtype=dtype, + ) + + # Training steps + w_vals = [2, 5, 3, 11] + x_vals = [7, 3, 5] + dy_vals = [1, 2, 1] + with torch.no_grad(): + model.weight.fill_(w_vals[0]) + for step in range(3): + + # Data tensors + x = torch.full( + (size, size), + x_vals[step], + dtype=dtype, + device=device, + requires_grad=True, + ) + dy = torch.full( + (size, size), + dy_vals[step], + dtype=dtype, + device=device, + ) + + # Training step + with te.fp8_autocast(fp8_recipe=recipe): + y = model(x) + y.backward(dy) + with torch.no_grad(): + model.weight.fill_(w_vals[step + 1]) + + # Check that output tensors match expected + tols = dict(rtol=0, atol=0) + y_val_ref = w_vals[step] * x_vals[step] * size + dx_val_ref = w_vals[step] * dy_vals[step] * size + torch.testing.assert_close( + y, + torch.full_like(y, y_val_ref), + **dtype_tols(tex.DType.kFloat8E4M3), + ) + torch.testing.assert_close( + x.grad, + torch.full_like(x.grad, dx_val_ref), + **dtype_tols(tex.DType.kFloat8E5M2), + ) + + # Check that scaling factors match expected + w_amax_ref = max(w_vals[: step + 2]) + x_amax_ref = max(x_vals[: step + 1]) + dy_amax_ref = max(dy_vals[: step + 1]) + w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin) + x_scale_ref = (fp8_format.value.max_fwd / x_amax_ref) / (2**margin) + dy_scale_ref = (fp8_format.value.max_bwd / dy_amax_ref) / (2**margin) + forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) + w_scale = model.get_fp8_meta("param")[forward_key].scale + x_scale = model.get_fp8_meta("input")[forward_key].scale + dy_scale = model.get_fp8_meta("grad_output")[backward_key].scale + torch.testing.assert_close(w_scale, torch.full_like(w_scale, w_scale_ref)) + torch.testing.assert_close(x_scale, torch.full_like(x_scale, x_scale_ref)) + torch.testing.assert_close(dy_scale, torch.full_like(dy_scale, dy_scale_ref)) + + +class TestBasicOps: + """Tests for individual operations""" + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize("in_shape", ((1,),)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("device", ("cuda", "cpu")) + @pytest.mark.parametrize("fp8", (False, True)) + def test_identity( + self, + *, + in_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device, + fp8: bool, + ) -> None: + + # Skip invalid configurations + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8 and torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, 
+ test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = x_ref + dx_ref = dy_ref + + # Implementation with fusible operation + op = te_ops.Identity() + y_test = op(x_test) + y_test.backward(dy_test) + + # Check results + tols = dict(rtol=0, atol=0) # Identity is exact + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, dx_ref, **tols) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(y_test, -y_ref, **tols) + with pytest.raises(AssertionError): + torch.testing.assert_close(dx_test, -dx_ref, **tols) + + @pytest.mark.parametrize( + "shapes", + ( + ((1, 2, 3, 4), (2, 12)), + ((5, 4, 3, 2), (-1, 6)), + ((30,), (2, 3, -1)), + ((6, 7), (3, -1, 7)), + ), + ) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("device", ("cuda", "cpu")) + @pytest.mark.parametrize( + "memory_format", + (torch.contiguous_format, torch.channels_last), + ) + @pytest.mark.parametrize("fp8", (False, True)) + def test_reshape( + self, + *, + shapes: tuple[Iterable[int], Iterable[int]], + dtype: torch.dtype, + device: torch.device, + memory_format: torch.memory_format, + fp8: bool, + ) -> None: + in_shape, out_shape = shapes + + # Skip invalid configurations + if memory_format == torch.channels_last and len(in_shape) != 4: + pytest.skip("torch.channels_last only supports 4D tensors") + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8 and torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8, + ) + x_test = x_test.contiguous(memory_format=memory_format) + x_test = x_test.detach().requires_grad_() + dy_ref, dy_test = make_reference_and_test_tensors( + x_ref.reshape(out_shape).size(), + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = x_ref.reshape(out_shape) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.Reshape(out_shape) + y_test = op(x_test) + y_test.backward(dy_test) + + # Check results + tols = dict(rtol=0, atol=0) # Reshape is exact + y_test = y_test.to( + dtype=torch.float64, + device="cpu", + memory_format=torch.contiguous_format, + ) + dx_test = x_test.grad.to( + dtype=torch.float64, + device="cpu", + memory_format=torch.contiguous_format, + ) + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + + @pytest.mark.parametrize("size", (1, 7, 32)) + @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("device", _devices) + @pytest.mark.parametrize("fp8", (False, True)) + def test_bias( + self, + *, + size: int, + in_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device, + fp8: bool, + ) -> None: + + # Make input and bias shapes consistent + in_shape = list(in_shape)[:-1] + [size] + + # Skip invalid configurations + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8 and torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + 
test_device=device, + test_is_fp8=fp8, + ) + b_ref, b_test = make_reference_and_test_tensors( + size, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = x_ref + b_ref.reshape([1] * (len(in_shape) - 1) + [size]) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.Bias(size, device=device, dtype=dtype) + with torch.no_grad(): + op.bias.copy_(b_test) + del b_test + y_test = op(x_test) + y_test.backward(dy_test) + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(db_test, b_ref.grad, **tols) + + @pytest.mark.parametrize("weight_shape", ((48, 16), (3, 5))) + @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (2, 2, 4, -1))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("fp8_input", (False, True)) + @pytest.mark.parametrize("fp8_weight", (False, True)) + @pytest.mark.parametrize("fp8_grad_output", (False, True)) + @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) + def test_basic_linear( + self, + *, + weight_shape: tuple[int, int], + in_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_compute: bool, + fp8_input: bool, + fp8_weight: bool, + fp8_grad_output: bool, + accumulate_into_main_grad: bool, + ) -> None: + """GEMM""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + if fp8_compute or fp8_input or fp8_weight or fp8_grad_output: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + if fp8_compute: + if ( + math.prod(in_shape[:-1]) % 16 != 0 + or in_features % 16 != 0 + or out_features % 16 != 0 + ): + pytest.skip("FP8 GEMMs require dims that are divisible by 16") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_input), + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_weight), + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_grad_output), + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + with te.fp8_model_init(enabled=fp8_weight): + op = te_ops.BasicLinear( + in_features, + out_features, + device=device, + dtype=dtype, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + with torch.no_grad(): + op.weight.copy_(w_test) + del w_test + op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32) + with te.fp8_autocast(enabled=fp8_compute): + y_test = op(x_test) + y_test.backward(dy_test) + + # Expected numerical 
error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if fp8_compute: + tols = dtype_tols( + op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3 + ) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + if accumulate_into_main_grad: + if op.weight.grad is not None: + torch.testing.assert_close( + op.weight.grad, + torch.zeros_like(op.weight.grad), + rtol=0, + atol=0, + ) + dw_test = op.weight.main_grad.to(dtype=torch.float64, device="cpu") - 0.5 + else: + dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close( + op.weight.main_grad, + torch.full_like(op.weight.main_grad, 0.5), + rtol=0, + atol=0, + ) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("fp8_weight", (False, True)) + def test_linear( + self, + *, + bias: bool, + weight_shape: tuple[int, int] = (16, 16), + in_shape: Iterable[int] = (16, -1), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + fp8_compute: bool, + fp8_input: bool = False, + fp8_weight: bool, + ) -> None: + """GEMM + bias""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + if fp8_input or fp8_weight or fp8_compute: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + if fp8_compute: + if ( + math.prod(in_shape[:-1]) % 16 != 0 + or in_features % 16 != 0 + or out_features % 16 != 0 + ): + pytest.skip("FP8 GEMMs require dims that are divisible by 16") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_input), + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_weight), + ) + b_ref, b_test = None, None + if bias: + b_ref, b_test = make_reference_and_test_tensors( + out_features, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + with te.fp8_model_init(enabled=fp8_weight): + op = te_ops.Linear( + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + ) + with torch.no_grad(): + op.weight.copy_(w_test) + if bias: + op.bias.copy_(b_test) + del w_test + del b_test + with te.fp8_autocast(enabled=fp8_compute): + y_test = op(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if fp8_compute: + tols = dtype_tols( + op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3 + ) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = 
x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + if bias: + db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, b_ref.grad, **tols) + + +class TestFusedOps: + """Tests for fused operations""" + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize("weight_shape", ((32, 48), (3, 5))) + @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (4, 2, 10, -1))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("fp8_input", (False, True)) + @pytest.mark.parametrize("fp8_weight", (False, True)) + def test_linear_bias_activation( + self, + *, + bias: bool = True, + weight_shape: tuple[int, int], + in_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_compute: bool, + fp8_input: bool, + fp8_weight: bool, + ) -> None: + """GEMM + bias + activation""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + if fp8_input or fp8_weight or fp8_compute: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + if fp8_compute: + if ( + math.prod(in_shape[:-1]) % 16 != 0 + or in_features % 16 != 0 + or out_features % 16 != 0 + ): + pytest.skip("FP8 GEMMs require dims that are divisible by 16") + if dtype not in (torch.float16, torch.bfloat16): + pytest.skip( + "FP8 fused linear-bias-activation is only supported with FP16 or BF16 output" + ) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_input), + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_weight), + ) + b_ref, b_test = None, None + if bias: + b_ref, b_test = make_reference_and_test_tensors( + out_features, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref) + y_ref.backward(dy_ref) + + # Implementation with fusible operations + with te.fp8_model_init(enabled=fp8_weight): + model = te_ops.Sequential( + te_ops.Linear( + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + ), + ) + with torch.no_grad(): + model[0].weight.copy_(w_test) + if bias: + model[0].bias.copy_(b_test) + del w_test + del b_test + with te.fp8_autocast(enabled=fp8_compute): + y_test = model(x_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + forward_ops = model._module_groups[0]._forward_ops + assert len(forward_ops) == 1 + assert isinstance(forward_ops[0][0], ForwardLinearBiasActivation) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if 
fp8_compute: + tols = dtype_tols( + model[0].weight._fp8_dtype + if is_float8_tensor(model[0].weight) + else tex.DType.kFloat8E4M3 + ) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + if bias: + db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, b_ref.grad, **tols) + + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + def test_fp8_linear( + self, + *, + in_shape: Iterable[int] = (16, 16), + dtype: torch.dtype = torch.bfloat16, + device: torch.device = "cuda", + ) -> None: + """Adjacent linear ops with FP8 enabled""" + + # Make input and weight shapes consistent + in_shape = tuple(in_shape) + weight_shape = (in_shape[-1], in_shape[-1]) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=True, + ) + w0_ref, w0_test = make_reference_and_test_tensors( + weight_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=True, + ) + w1_ref, w1_test = make_reference_and_test_tensors( + weight_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=True, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w0_ref) + y_ref = torch.nn.functional.linear(y_ref, w1_ref) + y_ref.backward(dy_ref) + + # Implementation with fusible operations + with te.fp8_model_init(enabled=True): + model = te_ops.Sequential( + te_ops.BasicLinear( + in_shape[-1], + in_shape[-1], + device=device, + dtype=dtype, + ), + te_ops.BasicLinear( + in_shape[-1], + in_shape[-1], + device=device, + dtype=dtype, + ), + ) + with torch.no_grad(): + model[0].weight.copy_(w0_test) + model[1].weight.copy_(w1_test) + del w0_test, w1_test + with te.fp8_autocast(enabled=True): + y_test = model(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(model[0].weight._fp8_dtype) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw0_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") + dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw0_test, w0_ref.grad, **tols) + torch.testing.assert_close(dw1_test, w1_ref.grad, **tols) diff --git a/tests/pytorch/test_fusible_ops_distributed.py b/tests/pytorch/test_fusible_ops_distributed.py new file mode 100644 index 0000000000..d8a018761b --- /dev/null +++ b/tests/pytorch/test_fusible_ops_distributed.py @@ -0,0 +1,836 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
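# ---------------------------------------------------------------------------
# Editor's illustration (not part of the patch): a minimal sketch of the
# operation-based API that the tests above exercise, mirroring
# test_linear_bias_activation. It assumes a CUDA device with BF16 and FP8
# support; the sizes are hypothetical but divisible by 16, as the FP8 GEMMs
# in these tests require.
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

model = te_ops.Sequential(
    te_ops.BasicLinear(16, 16, device="cuda", dtype=torch.bfloat16),
    te_ops.Bias(16, device="cuda", dtype=torch.bfloat16),
)
x = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
with te.fp8_autocast(enabled=True):  # adjacent fusible ops may be fused
    y = model(x)
y.backward(torch.ones_like(y))  # backward runs outside the autocast, as in the tests
# ---------------------------------------------------------------------------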
+ +from __future__ import annotations + +import argparse +import functools +import itertools +import os +import pathlib +import subprocess +import sys + +import pytest +import torch + +import transformer_engine +import transformer_engine.pytorch as te +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch.ops._common import is_float8_tensor +from transformer_engine.pytorch.utils import is_bf16_compatible +import transformer_engine_torch as tex + +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + +@functools.cache +def world_group() -> torch.distributed.ProcessGroup: + """Get NCCL process group, initializing if needed""" + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(rank) + group = torch.distributed.init_process_group( + "nccl", + init_method="file:///tmp/rdzv", + world_size=world_size, + rank=rank, + ) + return group + + +def reset_rng(seed: int = 1234) -> None: + """Reset random number generators""" + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +@torch.no_grad() +def make_reference_and_test_tensors( + shape: int | Iterable[int], + ref_dtype: torch.dtype = torch.float64, + ref_device: torch.device = "cpu", + test_dtype: torch.dtype = torch.float32, + test_device: torch.device = "cuda", + test_is_fp8: bool = False, + requires_grad: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """Construct tensors with the same values + + The reference tensor is intended for use in plain PyTorch + operations in high precision. The test tensor is intended for use + in Transformer Engine operations. + + """ + + # Random data + ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + + # Make copy of tensor + if test_is_fp8: + test = Float8Tensor.to_float8(ref) + else: + test = ref.to(device=test_device, dtype=test_dtype) + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + + # Make sure reference and test tensors represent exact same values + ref.copy_(test) + + # Return reference and test tensors + ref.requires_grad_(requires_grad) + test.requires_grad_(requires_grad) + return ref, test + + +def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: + """Estimated numerical error for a datatype + + Based on tolerances for torch.testing.assert_close. 
+ + """ + + # Transformer Engine dtypes + if isinstance(dtype, tex.DType): + if dtype == tex.DType.kFloat8E4M3: + return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 + if dtype == tex.DType.kFloat8E5M2: + return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 + dtype = { + tex.DType.kByte: torch.uint8, + tex.DType.kInt32: torch.int32, + tex.DType.kFloat32: torch.float32, + tex.DType.kFloat16: torch.half, + tex.DType.kBFloat16: torch.bfloat16, + }[dtype] + + # PyTorch dtypes + if dtype == torch.float16: + return dict(rtol=1e-3, atol=1e-5) + if dtype == torch.bfloat16: + return dict(rtol=1.6e-2, atol=1e-5) + if dtype == torch.float32: + return dict(rtol=1.3e-6, atol=1e-5) + if dtype == torch.float64: + return dict(rtol=1e-7, atol=1e-7) + raise ValueError(f"Unsupported dtype ({dtype})") + + +def _test_all_reduce( + *, + local_size: int = 17, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + fp8: bool = False, +) -> None: + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Tensor dimensions + in_shape = [world_size, local_size] + out_shape = [local_size] + + # Random data + reset_rng() + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8, + ) + + # Plain PyTorch implementation + y_ref = x_ref.sum(0) + y_ref.backward(dy_ref) + + # Convert to distributed tensors + with torch.no_grad(): + dx_ref = x_ref.grad[rank] + x_ref = x_ref[rank] + x_test = x_test[rank].clone() + x_test.requires_grad_() + + # Implementation with fusible operation + op = te_ops.AllReduce(process_group=process_group) + y_test = op(x_test) + y_test.backward(dy_test) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **dtype_tols(dtype)) + torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0) + + +def _test_all_gather( + *, + local_size: int = 13, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + fp8: bool = False, +) -> None: + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Tensor dimensions + in_shape = [world_size, local_size] + out_shape = [world_size, world_size * local_size] + + # Random data + reset_rng() + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8, + ) + + # Plain PyTorch implementation + y_ref = x_ref.tile((world_size, 1)).reshape(out_shape) + y_ref.backward(dy_ref) + + # Convert to distributed tensors + with torch.no_grad(): + dx_ref = x_ref.grad[rank] + x_ref = x_ref[rank] + x_test = x_test[rank].clone() + y_ref = y_ref[rank] + dy_ref = dy_ref[rank] + dy_test = dy_test[rank].clone() + x_test.requires_grad_() + + # Implementation with fusible operation + op = te_ops.AllGather(process_group=process_group) + y_test = op(x_test) + y_test.backward(dy_test) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = 
x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, rtol=0, atol=0) + torch.testing.assert_close(dx_test, dx_ref, **dtype_tols(dtype)) + + +def _test_reduce_scatter( + *, + local_size: int = 11, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + fp8: bool = False, +) -> None: + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Tensor dimensions + in_shape = [world_size, world_size * local_size] + out_shape = [world_size, local_size] + + # Random data + reset_rng() + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8, + ) + + # Plain PyTorch implementation + y_ref = x_ref.sum(0).reshape(out_shape) + y_ref.backward(dy_ref) + + # Convert to distributed tensors + with torch.no_grad(): + dx_ref = x_ref.grad[rank] + x_ref = x_ref[rank] + x_test = x_test[rank].clone() + y_ref = y_ref[rank] + dy_ref = dy_ref[rank] + dy_test = dy_test[rank].clone() + x_test.requires_grad_() + + # Implementation with fusible operation + op = te_ops.ReduceScatter(process_group=process_group) + y_test = op(x_test) + y_test.backward(dy_test) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **dtype_tols(dtype)) + torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0) + + +def _test_basic_linear( + *, + local_weight_shape: tuple[int, int] = (16, 16), + batch_size: int = 16, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + fp8_compute: bool = False, + fp8_input: bool = False, + fp8_weight: bool = False, + fp8_grad_output: bool = False, + tensor_parallel_mode: str = "column", + sequence_parallel: bool = False, +) -> None: + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Tensor dimensions + local_out_features, local_in_features = local_weight_shape + out_features, in_features = local_out_features, local_in_features + if tensor_parallel_mode == "column": + out_features *= world_size + elif tensor_parallel_mode == "row": + in_features *= world_size + in_shape = [batch_size, in_features] + out_shape = [batch_size, out_features] + + # Random data + reset_rng() + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_input), + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_weight), + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_grad_output), + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref) + y_ref.backward(dy_ref) + + # Convert to distributed tensors + with torch.no_grad(): + dw_ref = w_ref.grad + dx_ref = x_ref.grad + if tensor_parallel_mode == "column": + local_out_features = out_features // world_size + local_slice = slice( + rank * local_out_features, + (rank + 1) * local_out_features, + ) + 
w_ref = w_ref[local_slice, :] + dw_ref = dw_ref[local_slice, :] + w_test = w_test[local_slice, :] + y_ref = y_ref[..., local_slice] + dy_ref = dy_ref[..., local_slice] + dy_test = dy_test[..., local_slice].clone() + elif tensor_parallel_mode == "row": + local_in_features = in_features // world_size + local_slice = slice( + rank * local_in_features, + (rank + 1) * local_in_features, + ) + w_ref = w_ref[:, local_slice] + dw_ref = dw_ref[:, local_slice] + w_test = w_test[:, local_slice] + x_ref = x_ref[..., local_slice] + dx_ref = dx_ref[..., local_slice] + x_test = x_test[..., local_slice].clone() + if sequence_parallel: + local_batch_size = batch_size // world_size + local_slice = slice( + rank * local_batch_size, + (rank + 1) * local_batch_size, + ) + if tensor_parallel_mode == "column": + x_ref = x_ref[local_slice, ...] + dx_ref = dx_ref[local_slice, ...] + x_test = x_test[local_slice, ...].clone() + elif tensor_parallel_mode == "row": + y_ref = y_ref[local_slice, ...] + dy_ref = dy_ref[local_slice, ...] + dy_test = dy_test[local_slice, ...].clone() + x_test.requires_grad_() + + # Implementation with fusible operation + with te.fp8_model_init(enabled=fp8_weight): + op = te_ops.BasicLinear( + in_features, + out_features, + device=device, + dtype=dtype, + tensor_parallel_mode=tensor_parallel_mode, + tensor_parallel_group=process_group, + sequence_parallel=sequence_parallel, + ) + with torch.no_grad(): + op.weight.copy_(w_test) + del w_test + with te.fp8_autocast(enabled=fp8_compute): + y_test = op(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if fp8_compute: + tols = dtype_tols( + op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3 + ) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, dx_ref, **tols) + torch.testing.assert_close(dw_test, dw_ref, **tols) + + +def _test_linear( + *, + bias: bool = True, + local_weight_shape: tuple[int, int] = (16, 16), + batch_size: int = 16, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + fp8_compute: bool = False, + fp8_input: bool = False, + fp8_weight: bool = False, + fp8_grad_output: bool = False, + tensor_parallel_mode: str = "column", + sequence_parallel: bool = False, +) -> None: + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Tensor dimensions + local_out_features, local_in_features = local_weight_shape + out_features, in_features = local_out_features, local_in_features + if tensor_parallel_mode == "column": + out_features *= world_size + elif tensor_parallel_mode == "row": + in_features *= world_size + in_shape = [batch_size, in_features] + out_shape = [batch_size, out_features] + + # Random data + reset_rng() + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_input), + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_weight), + ) + b_ref, b_test = None, None + if bias: + if tensor_parallel_mode == "row": + 
bias_shape = [world_size, out_features] + else: + bias_shape = [out_features] + b_ref, b_test = make_reference_and_test_tensors( + bias_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_grad_output), + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref) + if bias: + if tensor_parallel_mode == "row": + y_ref += b_ref.sum(dim=0) + else: + y_ref += b_ref + y_ref.backward(dy_ref) + + # Convert to distributed tensors + with torch.no_grad(): + dw_ref = w_ref.grad + db_ref = b_ref.grad if bias else None + dx_ref = x_ref.grad + if tensor_parallel_mode == "column": + local_out_features = out_features // world_size + local_slice = slice( + rank * local_out_features, + (rank + 1) * local_out_features, + ) + w_ref = w_ref[local_slice, :] + dw_ref = dw_ref[local_slice, :] + w_test = w_test[local_slice, :] + if bias: + b_ref = b_ref[local_slice] + db_ref = db_ref[local_slice] + b_test = b_test[local_slice] + y_ref = y_ref[..., local_slice] + dy_ref = dy_ref[..., local_slice] + dy_test = dy_test[..., local_slice].clone() + elif tensor_parallel_mode == "row": + local_in_features = in_features // world_size + local_slice = slice( + rank * local_in_features, + (rank + 1) * local_in_features, + ) + w_ref = w_ref[:, local_slice] + dw_ref = dw_ref[:, local_slice] + w_test = w_test[:, local_slice] + if bias: + b_ref = b_ref[rank, :] + db_ref = db_ref[rank, :] + b_test = b_test[rank, :] + x_ref = x_ref[..., local_slice] + dx_ref = dx_ref[..., local_slice] + x_test = x_test[..., local_slice].clone() + if sequence_parallel: + local_batch_size = batch_size // world_size + local_slice = slice( + rank * local_batch_size, + (rank + 1) * local_batch_size, + ) + if tensor_parallel_mode == "column": + x_ref = x_ref[local_slice, ...] + dx_ref = dx_ref[local_slice, ...] + x_test = x_test[local_slice, ...].clone() + elif tensor_parallel_mode == "row": + y_ref = y_ref[local_slice, ...] + dy_ref = dy_ref[local_slice, ...] 
+ dy_test = dy_test[local_slice, ...].clone() + x_test.requires_grad_() + + # Implementation with fusible operation + with te.fp8_model_init(enabled=fp8_weight): + model = te_ops.Sequential( + te_ops.Linear( + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + tensor_parallel_mode=tensor_parallel_mode, + tensor_parallel_group=process_group, + sequence_parallel=sequence_parallel, + ), + ) + with torch.no_grad(): + model[0].weight.copy_(w_test) + if bias: + model[0].bias.copy_(b_test) + del w_test + del b_test + with te.fp8_autocast(enabled=fp8_compute): + y_test = model(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if fp8_compute: + tols = dtype_tols( + model[0].weight._fp8_dtype + if is_float8_tensor(model[0].weight) + else tex.DType.kFloat8E4M3 + ) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, dx_ref, **tols) + torch.testing.assert_close(dw_test, dw_ref, **tols) + if bias: + db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, db_ref, **tols) + + +def _test_fp8_scale_update( + *, + amax_history_len: int = 31, + amax_compute_algo: str = "max", + margin: float = 2, + local_weight_shape: tuple[int, int] = (16, 16), + batch_size: int = 16, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + tensor_parallel_mode: str = "column", +) -> None: + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Tensor dimensions + local_out_features, local_in_features = local_weight_shape + out_features, in_features = local_out_features, local_in_features + if tensor_parallel_mode == "column": + out_features *= world_size + elif tensor_parallel_mode == "row": + in_features *= world_size + in_shape = [batch_size, in_features] + out_shape = [batch_size, out_features] + + # Random data + reset_rng() + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + def ref_amax_and_scale( + ref: torch.Tensor, + stage: str, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Expected absmax and FP8 scale""" + amax = ref.abs().amax() + max_val = { + "forward": 448.0, + "backward": 57344.0, + }[stage] + scale = (max_val / amax) / (2**margin) + amax = amax.to(dtype=torch.float32, device="cpu") + scale = scale.to(dtype=torch.float32, device="cpu") + return amax, scale + + # Compute expected amaxes and FP8 scales + x_amax_ref, x_scale_ref = ref_amax_and_scale(x_ref, "forward") + w_amax_ref, w_scale_ref = ref_amax_and_scale(w_ref, "forward") + dy_amax_ref, dy_scale_ref = ref_amax_and_scale(dy_ref, "backward") + + # Convert to distributed tensors + with torch.no_grad(): + if tensor_parallel_mode == "column": + local_out_features = out_features // world_size + local_slice = slice( + rank * local_out_features, + (rank + 1) * 
local_out_features, + ) + w_ref = w_ref[local_slice, :] + w_test = w_test[local_slice, :] + dy_ref = dy_ref[..., local_slice] + dy_test = dy_test[..., local_slice].clone() + elif tensor_parallel_mode == "row": + local_in_features = in_features // world_size + local_slice = slice( + rank * local_in_features, + (rank + 1) * local_in_features, + ) + w_ref = w_ref[:, local_slice] + w_test = w_test[:, local_slice] + x_ref = x_ref[..., local_slice] + x_test = x_test[..., local_slice].clone() + x_test.requires_grad_() + + # Initialize fusible operation + op = te_ops.BasicLinear( + in_features, + out_features, + device=device, + dtype=dtype, + tensor_parallel_mode=tensor_parallel_mode, + tensor_parallel_group=process_group, + ) + with torch.no_grad(): + op.weight.copy_(w_test) + del w_test + + # Forward and backward pass + fp8_format = transformer_engine.common.recipe.Format.HYBRID + recipe = transformer_engine.common.recipe.DelayedScaling( + margin=margin, + interval=1, + fp8_format=fp8_format, + amax_history_len=amax_history_len, + amax_compute_algo=amax_compute_algo, + ) + with te.fp8_autocast(fp8_recipe=recipe): + y_test = op(x_test) + y_test.backward(dy_test) + + # Check results + forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) + x_fp8_meta = op.get_fp8_meta("input")[forward_key] + w_fp8_meta = op.get_fp8_meta("param")[forward_key] + dy_fp8_meta = op.get_fp8_meta("grad_output")[backward_key] + x_amax_test = x_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu") + w_amax_test = w_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu") + dy_amax_test = dy_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu") + x_scale_test = x_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu") + w_scale_test = w_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu") + dy_scale_test = dy_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu") + torch.testing.assert_close(x_amax_test, x_amax_ref) + torch.testing.assert_close(w_amax_test, w_amax_ref) + torch.testing.assert_close(dy_amax_test, dy_amax_ref) + torch.testing.assert_close(x_scale_test, x_scale_ref) + torch.testing.assert_close(w_scale_test, w_scale_ref) + torch.testing.assert_close(dy_scale_test, dy_scale_ref) + + +def run_parallel_tests() -> None: + """Run parallel tests""" + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Collective communication ops + if rank == 0: + print(f"Running _test_all_reduce") + _test_all_reduce() + if rank == 0: + print(f"Running _test_all_gather") + _test_all_gather() + if rank == 0: + print(f"Running _test_reduce_scatter") + _test_reduce_scatter() + + # Basic linear op + for config in itertools.product( + (False, True) if fp8_available else (False,), + ("column", "row"), + (False, True), + ): + if rank == 0: + print(f"Running _test_basic_linear with {config=}") + fp8, tensor_parallel_mode, sequence_parallel = config + _test_basic_linear( + fp8_compute=fp8, + fp8_input=fp8, + fp8_weight=fp8, + fp8_grad_output=fp8, + tensor_parallel_mode=tensor_parallel_mode, + sequence_parallel=sequence_parallel, + ) + + # Linear op + for config in itertools.product( + (False, True) if fp8_available else (False,), + ("column", "row"), + ): + if rank == 0: + print(f"Running _test_linear with {config=}") + fp8, tensor_parallel_mode = config + dtype = torch.bfloat16 if 
is_bf16_compatible() else torch.float32 + _test_linear( + bias=True, # bias=False is tested in _test_basic_linear + dtype=dtype, + fp8_compute=fp8, + fp8_input=fp8, + fp8_weight=fp8, + fp8_grad_output=fp8, + tensor_parallel_mode=tensor_parallel_mode, + ) + + # FP8 scale update + if fp8_available: + if rank == 0: + print(f"Running _test_fp8_scale_update") + _test_fp8_scale_update() + + +# Parallel job sizes +_world_sizes = [torch.cuda.device_count()] +if 1 not in _world_sizes: + _world_sizes.append(1) +if torch.cuda.device_count() >= 2 and 2 not in _world_sizes: + _world_sizes.append(2) + + +@pytest.mark.parametrize("world_size", _world_sizes) +def test_distributed_fuser_ops(world_size: int) -> None: + """Launch parallel job that runs parallel tests""" + python_exe = pathlib.Path(sys.executable).resolve() + current_file = pathlib.Path(__file__).resolve() + command = [ + python_exe, + "-m", + "torch.distributed.run", + f"--nproc_per_node={world_size}", + current_file, + "--parallel", + ] + result = subprocess.run( + command, + check=True, + ) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--parallel", action="store_true", help="Run parallel tests") + args = parser.parse_args() + if args.parallel: + run_parallel_tests() + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 8c854b65fb..0c2118718c 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -from typing import Optional +from typing import Iterable, Optional import pytest import torch @@ -15,6 +15,8 @@ _amax_and_scale_update, get_default_fp8_recipe, ) +import transformer_engine.pytorch.ops as te_ops +import transformer_engine_torch as tex # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -33,7 +35,7 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("amax_history_len", [31, 1024]) @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"]) @pytest.mark.parametrize("is_first_microbatch", [None, True, False]) - def test_amax_and_scale_update( + def test_fp8_scale_update_with_linear_module( self, amax_history_len: int, amax_compute_algo: str, @@ -49,7 +51,7 @@ def test_amax_and_scale_update( amax_history_len=amax_history_len, amax_compute_algo=amax_compute_algo, ) - with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + with te.fp8_autocast(fp8_recipe=recipe): module = te.Linear(16, 16) y = module( torch.randn([16, 16], device="cuda"), @@ -162,6 +164,130 @@ def test_amax_and_scale_update( ref_scale_inv_backward[0], ) + @pytest.mark.parametrize("amax_history_len", [31, 1024]) + @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"]) + def test_fp8_scale_update_with_linear_fuser_op( + self, + amax_history_len: int, + amax_compute_algo: str, + margin: float = 2, + num_steps: int = 4, + in_shape: tuple[int] = (16, 16), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + ): + + # Construct linear op + op = te_ops.BasicLinear(in_shape[-1], in_shape[-1]) + + # Get FP8 meta tensors + forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) + x_fp8_meta = op.get_fp8_meta("input")[forward_key] + w_fp8_meta = op.get_fp8_meta("param")[forward_key] + dy_fp8_meta = op.get_fp8_meta("grad_output")[backward_key] + + # Perform training steps + x_history = [] + w_history = [] + 
dy_history = [] + for step in range(num_steps): + + # Fill tensors with known values + x_history.append(step + 0.25) + w_history.append(step + 0.5) + dy_history.append(step + 0.75) + x = torch.full( + in_shape, + x_history[-1], + dtype=dtype, + device=device, + requires_grad=True, + ) + dy = torch.full( + in_shape, + dy_history[-1], + dtype=dtype, + device=device, + ) + with torch.no_grad(): + op.weight.fill_(w_history[-1]) + + # Forward and backward pass + fp8_format = transformer_engine.common.recipe.Format.HYBRID + recipe = transformer_engine.common.recipe.DelayedScaling( + margin=margin, + interval=1, + fp8_format=fp8_format, + amax_history_len=amax_history_len, + amax_compute_algo=amax_compute_algo, + ) + with te.fp8_autocast(fp8_recipe=recipe): + y = op(x) + y.backward(dy) + + def check_amax_history( + fp8_meta: dict, + ref_amax_history: Iterable[float], + ) -> None: + """Check that amax history matches expected values""" + if len(ref_amax_history) > amax_history_len: + ref_amax_history = ref_amax_history[-amax_history_len:] + ref_amax_history = torch.tensor( + ref_amax_history, + dtype=torch.float32, + device=device, + ) + test_amax_history = fp8_meta.amax_history[:, 0] + tols = dict(rtol=0, atol=0) + torch.testing.assert_close( + test_amax_history[-(step + 1) :], + ref_amax_history[: (step + 1)], + **tols, + ) + + def check_scale( + fp8_meta: dict, + ref_amax_history: Iterable[float], + stage: str, + ): + """Check that scale and scale reciprocal match expected values""" + + # Compute amax + if len(ref_amax_history) > amax_history_len: + ref_amax_history = ref_amax_history[-(amax_history_len + 1) :] + if amax_compute_algo == "max": + ref_amax = max(ref_amax_history) + elif amax_compute_algo == "most_recent": + ref_amax = ref_amax_history[-1] + else: + raise RuntimeError(f"{amax_compute_algo=} is not supported") + + # Compute scale + max_val = { + "forward": 448.0, + "backward": 57344.0, + }[stage] + ref_scale = (max_val / ref_amax) / (2**margin) + + # Check values in FP8 meta tensors + torch.testing.assert_close( + fp8_meta.scale.item(), + ref_scale, + ) + torch.testing.assert_close( + fp8_meta.scale_inv.item(), + 1 / ref_scale, + ) + + # Check that results match expected values + check_amax_history(x_fp8_meta, x_history) + check_amax_history(w_fp8_meta, w_history) + check_amax_history(dy_fp8_meta, dy_history) + check_scale(x_fp8_meta, x_history, "forward") + check_scale(w_fp8_meta, w_history, "forward") + check_scale(dy_fp8_meta, dy_history, "backward") + @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"]) @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"]) @pytest.mark.parametrize( @@ -191,7 +317,7 @@ def test_scale_update_numeric_scenarios(self, amax_case, fused_update, fp8_dtype # Setup fp8_meta dictionary def setup_fp8_meta(): - with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + with te.fp8_autocast(fp8_recipe=recipe): module = te.Linear(16, 16) y = module(torch.zeros([16, 16], device="cuda")) y.backward(torch.zeros_like(y)) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index a2c5620f36..fdf65db21e 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -3,9 +3,11 @@ # See LICENSE for license information. 
"""Methods needed for distributed training (DP/TP).""" -import warnings +from __future__ import annotations + from contextlib import contextmanager, AbstractContextManager, ContextDecorator -from typing import Any, Dict, Union, Optional, Callable, Tuple, List +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import warnings import torch from torch.cuda import _lazy_call, _lazy_init @@ -829,23 +831,48 @@ def reduce_scatter_along_first_dim( def gather_along_first_dim( - input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False -) -> Tuple[torch.Tensor, torch.Tensor]: - """Gather tensors and concatinate along the first dimension.""" + input_: torch.Tensor, + process_group: dist_group_type, + async_op: bool = False, +) -> tuple[torch.Tensor, Any]: + """All-gather tensors and concatenate along first dimension.""" - world_size = get_distributed_world_size(tp_group) - # Bypass the function if we are using only 1 GPU. + # Return immediately if no communication is required + world_size = get_distributed_world_size(process_group) if world_size == 1: return input_, None - dim_size = list(input_.size()) - dim_size[0] = dim_size[0] * world_size + # Allocate output tensor + output_shape = list(input_.size()) + output_shape[0] *= world_size + if isinstance(input_, Float8Tensor): + output = Float8Tensor.make_like( + input_, + data=torch.empty( + output_shape, + dtype=torch.uint8, + device=input_.device, + ), + ) + src = input_._data.contiguous() + dst = output._data + else: + output = torch.empty( + output_shape, + dtype=input_.dtype, + device=input_.device, + memory_format=torch.contiguous_format, + ) + src = input_.contiguous() + dst = output - output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) + # Launch all-gather handle = torch.distributed.all_gather_into_tensor( - output, input_.contiguous(), group=tp_group, async_op=async_op + dst, + src, + group=process_group, + async_op=async_op, ) - return output, handle diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index a38c88cf31..b7f87ad397 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -563,6 +563,23 @@ def expand_as(self, other: torch.Tensor): return _IdentityFunc.apply(self) return super().expand_as(other) + def contiguous( + self, + *, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> Float8Tensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. 
+ + """ + if self._data.is_contiguous(memory_format=memory_format): + return self + return _IdentityFunc.apply( + self, + {"data": self._data.detach().contiguous(memory_format=memory_format)}, + ) + def transpose_2d( self, *, @@ -885,6 +902,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_attrs=args[0]._fp8_attrs, ) + # View op + if func == aten.view.default: + tensor = args[0] + data = tensor._data + data_view = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like( + tensor, + data=data_view, + fp8_attrs=tensor._fp8_attrs, + ) + def maybe_unwrap(t): if isinstance(t, Float8Tensor): return t.from_float8() diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py new file mode 100644 index 0000000000..ec3d4fd315 --- /dev/null +++ b/transformer_engine/pytorch/ops/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operations. + +This operation-based API is experimental and subject to change. + +""" + +from transformer_engine.pytorch.ops.basic import ( + AllGather, + AllReduce, + BasicLinear, + Bias, + Identity, + ReduceScatter, + Reshape, +) +from transformer_engine.pytorch.ops.linear import Linear +from transformer_engine.pytorch.ops.op import FusibleOperation +from transformer_engine.pytorch.ops.sequential import Sequential diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py new file mode 100644 index 0000000000..77efef4ab6 --- /dev/null +++ b/transformer_engine/pytorch/ops/_common.py @@ -0,0 +1,152 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Helper functions used in fusible operations.""" + +from __future__ import annotations +from typing import Any, Iterable, Optional + +import torch + +from transformer_engine.pytorch.float8_tensor import Float8Tensor + + +def canonicalize_device(device: Optional[torch.device | str]) -> torch.device: + """Canonicalize PyTorch device + + If `None`, then returns the default CUDA device. + + """ + if device is None: + # Use default CUDA device + device = torch.get_default_device() + if device.type != "cuda": + device = torch.device("cuda", torch.cuda.current_device()) + elif not isinstance(device, torch.device): + device = torch.device(device) + if device.type == "cuda" and device.index is None: + device = torch.device("cuda", torch.cuda.current_device()) + return device + + +def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype: + """Canonicalize PyTorch datatype + + If `None`, then returns the default PyTorch datatype. 
+ + """ + if dtype is None: + # Use default dtype + dtype = torch.get_default_dtype() + return dtype + + +def devices_match(device1: torch.device, device2: torch.device) -> bool: + """Whether two devices are the same""" + device1 = torch.device(device1) + device2 = torch.device(device2) + if device1.type != device2.type: + return False + if device1.type == "cuda": + index1 = device1.index + index2 = device2.index + if index1 is None: + index1 = torch.cuda.current_device() + if index2 is None: + index2 = torch.cuda.current_device() + return index1 == index2 + return device1 == device2 + + +def is_float8_tensor(tensor: Any) -> bool: + """Check if object is a `Float8Tensor`""" + return isinstance(tensor, Float8Tensor) + + +def convert_tensor( + tensor: torch.Tensor | Float8Tensor, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + memory_format: torch.memory_format = torch.preserve_format, +) -> torch.Tensor | Float8Tensor: + """Convert tensor attributes, keeping same data if possible""" + + # Default kwargs + if device is None: + device = tensor.device + device = canonicalize_device(device) + if dtype is None: + dtype = tensor.dtype + dtype = canonicalize_dtype(dtype) + + # Make sure output is detached from autograd graph + tensor = tensor.detach() + + # Return immediately if tensor already has desired attributes + if devices_match(device, tensor.device) and dtype == tensor.dtype: + if memory_format == torch.preserve_format or tensor.is_contiguous( + memory_format=memory_format + ): + return tensor + + # Convert FP8 tensor + if is_float8_tensor(tensor): + data = tensor._data.to(device=device, memory_format=memory_format) + return Float8Tensor.make_like( + tensor, + data=data, + fp8_attrs=tensor._fp8_attrs, + dtype=dtype, + ) + + # Convert standard PyTorch tensor + return tensor.to(device=device, dtype=dtype, memory_format=memory_format) + + +def reshape( + tensor: torch.Tensor | Float8Tensor, + shape: Iterable[int], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor | Float8Tensor: + """Reshape tensor, keeping same data if possible + + If the input is a Float8Tensor, this function attempts to preserve + the cached transpose if available and valid. If a cached transpose + is present, it is interpreted as the transpose of a 2D matrix + where the width matches the innermost tensor dimension. + + """ + + # Make sure tensor is in expected format + tensor = convert_tensor( + tensor, + device=device, + dtype=dtype, + memory_format=torch.contiguous_format, + ) + + # Return immediately if tensor already has desired shape + shape = list(shape) + if len(shape) == tensor.dim(): + if sum(1 for d in shape if d == -1) > 1: + raise ValueError( + "Attempted to reshape tensor with " + f"shape={tuple(tensor.size())} into shape={tuple(shape)}" + ) + if all(d1 == d2 for d1, d2 in zip(shape, tensor.size()) if d1 != -1): + return tensor + + # Reshape FP8 tensor + # Note: Preserve cached transpose if possible + if is_float8_tensor(tensor): + out = Float8Tensor.make_like( + tensor, + data=tensor._data.view(shape), + fp8_attrs=tensor._fp8_attrs, + ) + return out + + # Reshape standard PyTorch tensor + return tensor.view(shape) diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py new file mode 100644 index 0000000000..3621910c8b --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# +# See LICENSE for license information. + +"""Single tensor operations supported by the operation fuser.""" + +from .all_gather import AllGather +from .all_reduce import AllReduce +from .basic_linear import BasicLinear +from .bias import Bias +from .identity import Identity +from .reduce_scatter import ReduceScatter +from .reshape import Reshape diff --git a/transformer_engine/pytorch/ops/basic/all_gather.py b/transformer_engine/pytorch/ops/basic/all_gather.py new file mode 100644 index 0000000000..b914d1dc6f --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/all_gather.py @@ -0,0 +1,124 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for all-gather.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) +from .._common import convert_tensor, is_float8_tensor + + +class AllGather(BasicOperation): + """All-gather tensor along outer dimension + + Equivalent to gathering tensors from all processes and + concatenating along the first dimension. + + Parameters + ---------- + process_group: torch.distributed.ProcessGroup, default = world group + Process group for communication + + """ + + def __init__( + self, + process_group: Optional[torch.distributed.ProcessGroup] = None, + ) -> None: + super().__init__() + self.process_group: Optional[torch.distributed.ProcessGroup] = process_group + self.process_group_size: int = torch.distributed.get_world_size(process_group) + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Trivial case + if self.process_group_size == 1: + return input_ + + # Tensor dimensions + input_dims = input_.size() + if not input_dims: + raise RuntimeError( + "Attempted to all-gather a tensor " + f"with shape={list(input_dims)} " + f"over {self.process_group_size} processes" + ) + output_dims = list(input_dims) + output_dims[0] *= self.process_group_size + + # Perform all-gather + x = convert_tensor(input_, memory_format=torch.contiguous_format) + y = None + if is_float8_tensor(x): + y = Float8Tensor.make_like( + x, + data=torch.empty( + output_dims, + dtype=torch.uint8, + device=x.device, + ), + ) + torch.distributed.all_gather_into_tensor( + y._data, + x._data, + group=self.process_group, + ) + else: + y = torch.empty(output_dims, dtype=x.dtype, device=x.device) + torch.distributed.all_gather_into_tensor( + y, + x, + group=self.process_group, + ) + return y + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Trivial case + if self.process_group_size == 1: + return grad_output, () + + # Tensor dimensions + output_dims = grad_output.size() + if not output_dims or output_dims[0] % self.process_group_size != 0: + raise RuntimeError( + "Attempted to reduce-scatter a tensor " + f"with shape={list(output_dims)} " + f"over {self.process_group_size} processes" + ) + input_dims = list(output_dims) + input_dims[0] //= self.process_group_size + + # Check output gradient tensor + dy = grad_output + if is_float8_tensor(dy): + dy = dy.from_float8() + dy = dy.contiguous() + + # Perform reduce-scatter + dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device) 
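# (Editorial note) The backward of an all-gather is a reduce-scatter: each
# rank's local input shard appears in every rank's gathered output, so the
# gradient w.r.t. the local shard is the sum over ranks of the matching slice
# of grad_output, which reduce_scatter_tensor computes directly into dx below.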
+ torch.distributed.reduce_scatter_tensor( + dx, + dy, + group=self.process_group, + ) + return dx, () diff --git a/transformer_engine/pytorch/ops/basic/all_reduce.py b/transformer_engine/pytorch/ops/basic/all_reduce.py new file mode 100644 index 0000000000..622346b1c5 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/all_reduce.py @@ -0,0 +1,68 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for all-reduce.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) +from .._common import is_float8_tensor + + +class AllReduce(BasicOperation): + """All-reduce tensor + + Equivalent to summing tensors from all processes. It is assumed + that the output is used in operations that are redundantly + computed on all processes, and hence that gradients are identical + between processes. + + Parameters + ---------- + process_group: torch.distributed.ProcessGroup, default = world group + Process group for communication + + """ + + def __init__( + self, + process_group: Optional[torch.distributed.ProcessGroup] = None, + reduce_in_backward: bool = True, + ) -> None: + super().__init__() + self.process_group: Optional[torch.distributed.ProcessGroup] = process_group + self._reduce_in_backward: bool = reduce_in_backward + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Trivial case + if torch.distributed.get_world_size(self.process_group) == 1: + return input_ + + # Perform all-reduce + x = input_ + if is_float8_tensor(x): + x = x.from_float8() + x = x.contiguous() + torch.distributed.all_reduce(x, group=self.process_group) + return x + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + return grad_output, () diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py new file mode 100644 index 0000000000..49923e7af8 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -0,0 +1,1047 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
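[Editorial aside, not part of the patch] A minimal sketch of how the communication ops above are intended to be used: they are ordinary fusible ops, so they compose with the rest of the API through te_ops.Sequential. This assumes torch.distributed.init_process_group() has already been called and a CUDA device is available.

import torch
import transformer_engine.pytorch.ops as te_ops

group = None  # default (world) process group; assumes torch.distributed is initialized
model = te_ops.Sequential(
    te_ops.AllGather(process_group=group),      # concatenate shards along dim 0
    te_ops.Identity(),                          # stand-in for local compute
    te_ops.ReduceScatter(process_group=group),  # sum across ranks and re-shard dim 0
)
y = model(torch.randn(8, 16, device="cuda"))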
+ +"""Fusible operation for linear layer without bias.""" + +from __future__ import annotations +from collections.abc import Callable, Iterable +import contextlib +import math +from typing import Any, Optional + +import torch + +from transformer_engine.pytorch.cpp_extensions import fp8_gemm, gemm +from transformer_engine.pytorch.distributed import ( + CudaRNGStatesTracker, + gather_along_first_dim, + reduce_scatter_along_first_dim, +) +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.fp8 import ( + FP8GlobalStateManager, + get_fp8_te_dtype, +) +from transformer_engine.pytorch.module.base import get_workspace +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) +from .._common import ( + canonicalize_device, + canonicalize_dtype, + convert_tensor, + is_float8_tensor, + reshape, +) +from ...utils import clear_tensor_data + + +def _wait_async(handle: Optional[Any]) -> None: + """Wait for asynchronous communication to finish, if needed""" + if handle is not None: + handle.wait() + + +class BasicLinear(BasicOperation): + """Apply linear transformation: :math:`y = x A^T` + + This is a drop-in replacement for `torch.nn.Linear` with + `bias=False`. + + Parameters + ---------- + in_features: int + Inner dimension of input tensor + out_features: int + Inner dimension of output tensor + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + Mode for tensor parallelism + tensor_parallel_group: torch.distributed.ProcessGroup, default = world group + Process group for tensor parallelism + sequence_parallel: bool, default = `False` + Whether to apply sequence parallelism together with tensor + parallelism, i.e. distributing input or output tensors along + outer dimension (sequence or batch dim) when not distributing + along inner dimension (embedding dim) + rng_state_tracker_function: callable + Function that returns `CudaRNGStatesTracker`, which is used + for model-parallel weight initialization + accumulate_into_main_grad: bool, default = `False` + Whether to directly accumulate weight gradients into the + weight's `main_grad` attribute instead of relying on PyTorch + autograd. The weight's `main_grad` must be set externally and + there is no guarantee that `grad` will be set or be + meaningful. 
+ + """ + + def __init__( + self, + in_features: int, + out_features: int, + *, + device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, + tensor_parallel_mode: Optional[str] = None, + tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + sequence_parallel: bool = False, + rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, + accumulate_into_main_grad: bool = False, + ) -> None: + super().__init__() + + # Weight tensor dimensions + self.in_features: int = in_features + self.out_features: int = out_features + + # Weight tensor device + defer_param_init = False + device = canonicalize_device(device) + if device.type == "meta": + defer_param_init = True + device = canonicalize_device(None) + if device.type != "cuda": + raise ValueError(f"Only CUDA devices are supported (got {device})") + self.device: torch.device = device + + # Weight tensor datatype + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + self.dtype: torch.dtype = canonicalize_dtype(dtype) + + # Tensor parallel configuration + self.tensor_parallel_mode: Optional[str] + self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup] + self.tensor_parallel_size: int + self.sequence_parallel: bool + self.local_in_features: int + self.local_out_features: int + ( + self.tensor_parallel_mode, + self.tensor_parallel_group, + self.tensor_parallel_size, + self.sequence_parallel, + self.local_in_features, + self.local_out_features, + ) = self._canonicalize_tensor_parallelism( + mode=tensor_parallel_mode, + process_group=tensor_parallel_group, + sequence_parallel=sequence_parallel, + in_features=in_features, + out_features=out_features, + ) + + # Whether weight tensor is natively in FP8 + self._with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + if self._with_fp8_parameters: + self._fp8_metas = self._make_fp8_metas() + + # Initialize parameters if needed + weight = torch.empty( + self.local_out_features, + self.local_in_features, + device="meta", + dtype=dtype, + ) + weight = torch.nn.Parameter(weight) + self.weight: torch.nn.Parameter + self.register_parameter("weight", weight) + self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] + self._rng_state_tracker_function = rng_state_tracker_function + if not defer_param_init: + self.reset_parameters() + + # Whether to accumulate weight gradient into main_grad + self._accumulate_into_main_grad = accumulate_into_main_grad + + @classmethod + def _canonicalize_tensor_parallelism( + cls, + *, + mode: Optional[str], + process_group: Optional[torch.distributed.ProcessGroup], + sequence_parallel: bool, + in_features: int, + out_features: int, + ) -> tuple[ + Optional[str], + Optional[torch.distributed.ProcessGroup], + int, + bool, + int, + int, + ]: + """Check configuration for tensor parallelism + + Parameters + ---------- + mode: {`None`, "column", "row"} + Mode for tensor parallelism + process_group: torch.distributed.ProcessGroup + Process group for tensor parallelism + sequence_parallel: bool + Whether to apply sequence parallelism together with tensor + parallelism, i.e. 
distributing input or output tensors + along outer dimension (sequence or batch dim) when not + distributing along inner dimension (embedding dim) + in_features: int + Inner dimension of global input tensor + out_features: int + Inner dimension of global output tensor + + Returns + ------- + mode: {`None`, "column", "row"} + Mode for tensor parallelism + process_group: torch.distributed.ProcessGroup + Process group for tensor parallelism + group_size: int + Size of tensor-parallel process group + sequence_parallel: bool + Whether to apply sequence parallelism + local_in_features: int + Inner dimension of local input tensor + local_out_features: int + Inner dimension of local output tensor + + """ + + # Tensor-parallel group size + if mode is None: + group_size = 1 + else: + group_size = torch.distributed.get_world_size(process_group) + + # Disable tensor parallelism if not needed + if group_size == 1: + mode = None + process_group = None + sequence_parallel = False + + # Determine local tensor dims + local_in_features = in_features + local_out_features = out_features + if mode is None: + pass + elif mode == "column": + # Distribute output tensor + if out_features % group_size != 0: + raise ValueError( + "Invalid configuration for tensor parallelism " + f"({mode=}, {out_features=}, {group_size=})" + ) + local_out_features //= group_size + elif mode == "row": + # Distribute input tensor + if in_features % group_size != 0: + raise ValueError( + "Invalid configuration for tensor parallelism " + f"({mode=}, {in_features=}, {group_size=})" + ) + local_in_features //= group_size + else: + raise ValueError( + "Supported modes for tensor parallelism are " + f'`None`, "row", and "column" (got {mode=})' + ) + + return ( + mode, + process_group, + group_size, + sequence_parallel, + local_in_features, + local_out_features, + ) + + def num_fp8_scales(self, mode: str) -> int: + if mode in ("input", "param", "grad_output"): + return 1 + return 0 + + def reset_parameters(self) -> None: + """Initialize parameter buffers and values""" + + # Make sure parameter is initialized + weight = self.weight + if weight.device.type != "cuda" or is_float8_tensor(weight): + weight = torch.empty_like(weight, device=self.device) + weight = weight.to(device=self.device, dtype=self.dtype) + + # Initialize values + init_context = contextlib.nullcontext + if self._rng_state_tracker_function is not None: + init_context = self._rng_state_tracker_function().fork + with init_context(): + torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) + + # Cast to FP8 if needed + if self._with_fp8_parameters: + weight = Float8Tensor.to_float8( + weight, + fp8_meta=self.get_fp8_meta("param"), + fp8_meta_index=0, + ) + + # Save updated parameter + if not isinstance(weight, torch.nn.Parameter): + weight = torch.nn.Parameter(weight) + self.weight = weight + + def pre_forward(self) -> None: + super().pre_forward() + if self.weight.device.type == "meta": + self.reset_parameters() + + @staticmethod + def _functional_forward( + input: torch.Tensor, # pylint: disable=redefined-builtin + weight: torch.Tensor, + *, + bias: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + tensor_parallel_mode: Optional[str] = None, + tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + sequence_parallel: bool = False, + with_fp8_compute: bool = False, + input_fp8_meta: Optional[dict[str, Any]] = None, + weight_fp8_meta: Optional[dict[str, Any]] = None, + output_fp8_meta: 
Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Functional API for forward pass + + Parameters + ---------- + input: torch.Tensor + Input tensor + weight: torch.Tensor + Weight tensor + bias: torch.Tensor, optional + Bias tensor + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + Mode for tensor parallelism + tensor_parallel_group: torch.distributed.ProcessGroup, default = world group + Process group for tensor parallelism + sequence_parallel: bool, default = `False` + Whether to apply sequence parallelism together with tensor + parallelism, i.e. distributing input or output tensors + along outer dimension (sequence or batch dim) when not + distributing along inner dimension (embedding dim) + with_fp8_compute: bool, default = `False` + Whether to perform compute in FP8 + input_fp8_meta: dict, optional + FP8 metadata for casting input tensor to FP8. Required for + FP8 compute if input is not already in FP8. + weight_fp8_meta: dict, optional + FP8 metadata for casting weight tensor to FP8. Required for + FP8 compute if weight is not already in FP8. + output_fp8_meta: dict, optional + FP8 metadata for casting output tensor to FP8 + + Returns + ------- + torch.Tensor + Output tensor + torch.Tensor + Input tensor used in GEMM, possibly cast and reshaped from + provided input tensor + torch.Tensor + Weight tensor used in GEMM, possibly cast and reshaped from + provided weight tensor + + """ + + # Check device + if device is None: + device = weight.device + device = canonicalize_device(device) + if device.type != "cuda": + raise ValueError(f"Only CUDA devices are supported (got {device})") + + # Check datatype + if dtype is None: + dtype = weight.dtype + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + + # Check tensor dims + input_dims = tuple(input.size()) + weight_dims = tuple(weight.size()) + if len(weight_dims) != 2: + raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") + if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: + raise ValueError( + f"Input tensor (shape={input_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + + # Check if FP8 is enabled + if with_fp8_compute: + if input_fp8_meta is None and not is_float8_tensor(input): + raise ValueError("No FP8 metadata was provided for casting input to FP8") + if weight_fp8_meta is None and not is_float8_tensor(weight): + raise ValueError("No FP8 metadata was provided for casting weight to FP8") + else: + input_fp8_meta = None + weight_fp8_meta = None + output_fp8_meta = None + with_fp8_output = ( + with_fp8_compute and tensor_parallel_mode != "row" and output_fp8_meta is not None + ) + + # Check input tensor + x_local = reshape( + input, + (-1, input_dims[-1]), + device=device, + dtype=dtype, + ) + if with_fp8_compute and not is_float8_tensor(x_local): + fp8_dtype = get_fp8_te_dtype( + input_fp8_meta["recipe"], + fprop_tensor=True, + ) + x_fp8 = Float8Tensor( + data=torch.empty_like(x_local, dtype=torch.uint8), + fp8_meta=input_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), + dtype=dtype, + ) + with_cast_transpose = weight.requires_grad + if 
tensor_parallel_mode == "column" and sequence_parallel: + with_cast_transpose = False + if with_cast_transpose: + x_fp8.cast_transpose_(x_local) + else: + x_fp8.copy_(x_local) + x_local = x_fp8 + elif not with_fp8_compute and is_float8_tensor(x_local): + x_local = x_local.from_float8() + x = x_local + x_async = None + if tensor_parallel_mode == "column" and sequence_parallel: + x, x_async = gather_along_first_dim( + x_local, + tensor_parallel_group, + async_op=True, + ) + + # Check weight tensor + w = convert_tensor( + weight, + device=device, + dtype=dtype, + memory_format=torch.contiguous_format, + ) + if with_fp8_compute and not is_float8_tensor(w): + fp8_dtype = get_fp8_te_dtype( + weight_fp8_meta["recipe"], + fprop_tensor=True, + ) + w = Float8Tensor.to_float8( + w, + fp8_meta=weight_fp8_meta, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + ) + elif not with_fp8_compute and is_float8_tensor(w): + w = w.from_float8() + + # Check bias tensor + b = None + if bias is not None: + b = convert_tensor( + bias, + device=device, + dtype=dtype, + memory_format=torch.contiguous_format, + ) + + # Construct output tensor + y = None + if with_fp8_output: + fp8_dtype = get_fp8_te_dtype( + output_fp8_meta["recipe"], + fprop_tensor=True, + ) + data = torch.empty( + (x.size(0), weight_dims[0]), + dtype=torch.uint8, + device=device, + ) + y = Float8Tensor( + data=data, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + dtype=dtype, + ) + else: + y = torch.empty( + (x.size(0), weight_dims[0]), + dtype=dtype, + device=device, + ) + + # Perform GEMM + _wait_async(x_async) + x_async = None + if with_fp8_compute: + kwargs = dict( + out=y, + bias=b, + use_bias=(b is not None), + ) + if with_fp8_output: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=y._fp8_meta_forward, + ) + kwargs.update( + dict( + out=y._data, + out_index=y._fp8_meta_index, + fp8_meta_tensor=y._fp8_meta[fp8_meta_key], + D_dtype=y._fp8_dtype, + ) + ) + fp8_gemm( + w._data, + w._scale_inv, + 0, + w._fp8_dtype, + x._data, + x._scale_inv, + 0, + x._fp8_dtype, + y.dtype, + get_workspace(), + **kwargs, + ) + else: + gemm( + w, + x, + y.dtype, + get_workspace(), + out=y, + bias=b, + use_bias=(b is not None), + ) + + # Reduce tensor-parallel output if needed + if tensor_parallel_mode == "row": + if sequence_parallel: + y, _ = reduce_scatter_along_first_dim(y, tensor_parallel_group) + else: + torch.distributed.all_reduce(y, group=tensor_parallel_group) + + # Reshape output tensor + output_dims = list(input_dims) + output_dims[0] = -1 + output_dims[-1] = weight_dims[0] + output = reshape(y, output_dims) + + return output, x_local, w + + @staticmethod + def _functional_backward( + grad_output: torch.Tensor, + input: Optional[torch.Tensor], # pylint: disable=redefined-builtin + weight: Optional[torch.Tensor], + input_dims: Iterable[int], + weight_dims: Iterable[int], + *, + input_requires_grad: bool = True, + weight_requires_grad: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + tensor_parallel_mode: Optional[str] = None, + tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + sequence_parallel: bool = False, + with_fp8_compute: bool = False, + input_fp8_meta: Optional[dict[str, Any]] = None, + weight_fp8_meta: Optional[dict[str, Any]] = None, + grad_output_fp8_meta: Optional[dict[str, Any]] = None, + grad_input_fp8_meta: Optional[dict[str, Any]] = None, + accumulate_into_grad_weight: bool = False, + grad_weight: 
Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Functional API for backward pass + + Parameters + ---------- + grad_output: torch.Tensor + Loss gradient w.r.t. output tensor + input: torch.Tensor, optional + Input tensor. Required to compute loss gradient w.r.t. + weight. + weight: torch.Tensor, optional + Weight tensor. Required to compute loss gradient w.r.t. + input. + input_dims: iterable of int + Input tensor dimensions + weight_dims: iterable of int + Weight tensor dimensions + input_requires_grad: bool + Whether to compute loss gradient w.r.t. input tensor + weight_requires_grad: bool + Whether to compute loss gradient w.r.t. weight tensor + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + Mode for tensor parallelism + tensor_parallel_group: torch.distributed.ProcessGroup, default = world group + Process group for tensor parallelism + sequence_parallel: bool, default = `False` + Whether to apply sequence parallelism together with tensor + parallelism, i.e. distributing input or output tensors + along outer dimension (sequence or batch dim) when not + distributing along inner dimension (embedding dim) + with_fp8_compute: bool, default = `False` + Whether to perform compute in FP8 + input_fp8_meta: dict, optional + FP8 metadata for casting input tensor to FP8. Required for + FP8 compute if input is not already in FP8. + weight_fp8_meta: dict, optional + FP8 metadata for casting weight tensor to FP8. Required for + FP8 compute if weight is not already in FP8. + grad_output_fp8_meta: dict, optional + FP8 metadata for casting loss gradient w.r.t. output + tensor to FP8. Required if output grad is not already in + FP8. + grad_input_fp8_meta: dict, optional + FP8 metadata for casting loss gradient w.r.t. input + tensor to FP8 + accumulate_into_grad_weight: bool, default = `False` + Accumulate into weight grad instead of overwriting + grad_weight: torch.Tensor, optional + Loss gradient w.r.t. weight tensor + + Returns + ------- + torch.Tensor + Loss gradient w.r.t. input tensor + torch.Tensor + Loss gradient w.r.t.
weight tensor + + """ + + # Check device + if device is None: + device = weight.device + device = canonicalize_device(device) + if device.type != "cuda": + raise ValueError(f"Only CUDA devices are supported (got {device})") + + # Check datatype + if dtype is None: + dtype = weight.dtype + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + + # Check tensor dims + output_dims = tuple(grad_output.size()) + input_dims = tuple(input_dims) + weight_dims = tuple(weight_dims) + if len(weight_dims) != 2: + raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") + if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: + raise ValueError( + f"Input tensor (shape={input_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + if weight_dims[0] != output_dims[-1]: + raise ValueError( + f"Grad output tensor (shape={output_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + + # Check if FP8 is enabled + if with_fp8_compute: + if grad_output_fp8_meta is None and not is_float8_tensor(grad_output): + raise ValueError("No FP8 metadata was provided for casting output gradient to FP8") + else: + input_fp8_meta = None + weight_fp8_meta = None + grad_output_fp8_meta = None + grad_input_fp8_meta = None + with_fp8_grad_input = ( + with_fp8_compute + and input_requires_grad + and tensor_parallel_mode != "column" + and grad_input_fp8_meta is not None + ) + + # Check grad output tensor + dy_async = None + dy = reshape( + grad_output, + (-1, output_dims[-1]), + device=device, + dtype=dtype, + ) + if with_fp8_compute and not is_float8_tensor(dy): + fp8_dtype = get_fp8_te_dtype( + grad_output_fp8_meta["recipe"], + fprop_tensor=False, + ) + dy_fp8 = Float8Tensor( + data=torch.empty_like(dy, dtype=torch.uint8), + fp8_meta=grad_output_fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), + dtype=dtype, + ) + with_cast_transpose = weight_requires_grad + if tensor_parallel_mode == "row" and sequence_parallel: + with_cast_transpose = False + if with_cast_transpose: + dy_fp8.cast_transpose_(dy) + else: + dy_fp8.copy_(dy) + dy = dy_fp8 + elif not with_fp8_compute and is_float8_tensor(dy): + dy = dy.from_float8() + if tensor_parallel_mode == "row" and sequence_parallel: + dy, dy_async = gather_along_first_dim( + dy, + tensor_parallel_group, + async_op=True, + ) + + # Check input tensor + x = None + x_async = None + if weight_requires_grad: + if input is None: + raise ValueError("Input tensor is required to compute weight grad") + x_local = reshape( + input, + (-1, input_dims[-1]), + device=device, + dtype=dtype, + ) + if with_fp8_compute and not is_float8_tensor(x_local): + fp8_dtype = get_fp8_te_dtype( + input_fp8_meta["recipe"], + fprop_tensor=True, + ) + x_fp8 = Float8Tensor( + data=torch.empty_like(x_local, dtype=torch.uint8), + fp8_meta=input_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), + dtype=dtype, + ) + x_fp8.cast_transpose_(x_local) + x_local = x_fp8 + elif not with_fp8_compute and is_float8_tensor(x_local): + x_local = x_local.from_float8() + x = x_local + if tensor_parallel_mode == "column" and sequence_parallel: + x, x_async = gather_along_first_dim( + x_local, + tensor_parallel_group, + async_op=True, + ) + + # 
Compute grad input + dx = None + dx_async = None + if input_requires_grad: + + # Check weight tensor + if weight is None: + raise ValueError("Weight tensor is required to compute input grad") + w = convert_tensor( + weight, + device=device, + dtype=dtype, + memory_format=torch.contiguous_format, + ) + if with_fp8_compute and not is_float8_tensor(w): + fp8_dtype = get_fp8_te_dtype( + weight_fp8_meta["recipe"], + fprop_tensor=True, + ) + w_fp8 = Float8Tensor( + data=torch.empty_like(w, dtype=torch.uint8), + fp8_meta=weight_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), + dtype=dtype, + ) + w_fp8.cast_transpose_(w) + w = w_fp8 + elif not with_fp8_compute and is_float8_tensor(w): + w = w.from_float8() + + # Construct grad input tensor + if with_fp8_grad_input: + fp8_dtype = get_fp8_te_dtype( + grad_input_fp8_meta["recipe"], + fprop_tensor=False, + ) + data = torch.empty( + (dy.size(0), weight_dims[1]), + dtype=torch.uint8, + device=device, + ) + dx = Float8Tensor( + data=data, + fp8_meta=grad_input_fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + dtype=dtype, + ) + else: + dx = torch.empty( + (dy.size(0), weight_dims[1]), + dtype=dtype, + device=device, + ) + + # Perform dgrad GEMM + _wait_async(dy_async) + dy_async = None + if with_fp8_compute: + kwargs = dict(out=dx) + if with_fp8_grad_input: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=dx._fp8_meta_forward, + ) + kwargs.update( + dict( + out=dx._data, + out_index=dx._fp8_meta_index, + fp8_meta_tensor=dx._fp8_meta[fp8_meta_key], + D_dtype=dx._fp8_dtype, + ) + ) + fp8_gemm( + w.transpose_2d(), + w._scale_inv, + 0, + w._fp8_dtype, + dy._data, + dy._scale_inv, + 0, + dy._fp8_dtype, + dx.dtype, + get_workspace(), + **kwargs, + ) + else: + gemm( + w, + dy, + dx.dtype, + get_workspace(), + layout="NN", + out=dx, + ) + + # Reduce tensor-parallel grad input if needed + if tensor_parallel_mode == "column": + if sequence_parallel: + dx, dx_async = reduce_scatter_along_first_dim( + dx, + tensor_parallel_group, + async_op=True, + ) + else: + dx_async = torch.distributed.all_reduce( + dx, + group=tensor_parallel_group, + async_op=True, + ) + + # Perform wgrad GEMM + if not weight_requires_grad: + grad_weight = None + else: + if grad_weight is None: + if accumulate_into_grad_weight: + raise ValueError( + "Attempted to accumulate into grad weight buffer" + "without providing grad weight" + ) + grad_weight = torch.empty( + weight_dims, + dtype=dtype, + device=device, + memory_format=torch.contiguous_format, + ) + _wait_async(dy_async) + _wait_async(x_async) + dy_async = None + x_async = None + if with_fp8_compute: + fp8_gemm( + x.transpose_2d(), + x._scale_inv, + 0, + x._fp8_dtype, + dy.transpose_2d(), + dy._scale_inv, + 0, + dy._fp8_dtype, + grad_weight.dtype, + get_workspace(), + accumulate=accumulate_into_grad_weight, + out=grad_weight, + ) + else: + gemm( + x, + dy, + x.dtype, + get_workspace(), + accumulate=accumulate_into_grad_weight, + layout="NT", + out=grad_weight, + ) + + # Clean up and return grads + _wait_async(dy_async) + _wait_async(x_async) + _wait_async(dx_async) + grad_input = None + if dx is not None: + grad_input = reshape(dx, input_dims) + return grad_input, grad_weight + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # FP8 metadata + 
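# (Editorial note) FP8 casting metadata is negotiated with neighboring ops:
# this op owns the "input", "param", and "grad_output" scales, while the
# output is produced directly in FP8 only when the next op exposes an "input"
# scale, and the grad input is produced in FP8 only when the previous op
# exposes a "grad_output" scale. That hand-off is what lets adjacent FP8 ops
# skip intermediate casts.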
with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() + input_fp8_meta = None + weight_fp8_meta = None + output_fp8_meta = None + grad_output_fp8_meta = None + grad_input_fp8_meta = None + if with_fp8_compute: + input_fp8_meta = self.get_fp8_meta("input") + weight_fp8_meta = self.get_fp8_meta("param") + if next_op is not None and next_op.num_fp8_scales("input") > 0: + output_fp8_meta = next_op.get_fp8_meta("input") + grad_output_fp8_meta = self.get_fp8_meta("grad_output") + if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: + grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + + # Linear forward + output, x_local, _ = BasicLinear._functional_forward( + input=input_, + weight=self.weight, + device=self.device, + dtype=self.dtype, + tensor_parallel_mode=self.tensor_parallel_mode, + tensor_parallel_group=self.tensor_parallel_group, + sequence_parallel=self.sequence_parallel, + with_fp8_compute=with_fp8_compute, + input_fp8_meta=input_fp8_meta, + weight_fp8_meta=weight_fp8_meta, + output_fp8_meta=output_fp8_meta, + ) + + # Save state for backward pass + ctx.save_for_backward(x_local) + ctx.with_fp8_compute = with_fp8_compute + ctx.weight_fp8_meta = weight_fp8_meta + ctx.grad_output_fp8_meta = grad_output_fp8_meta + ctx.grad_input_fp8_meta = grad_input_fp8_meta + ctx.input_dims = input_.size() + ctx.input_requires_grad = input_.requires_grad + ctx.weight_requires_grad = self.weight.requires_grad + ctx.has_prev_op = prev_op is not None + + return output + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]: + + # Saved tensors from forward pass + (x_local,) = ctx.saved_tensors + + # wgrad fusion + accumulate_into_main_grad = self._accumulate_into_main_grad + grad_weight = None + if ctx.weight_requires_grad and accumulate_into_main_grad: + if not hasattr(self.weight, "main_grad"): + raise RuntimeError( + "BasicLinear op is configured with " + "accumulate_into_main_grad=True, " + "but weight parameter does not have main_grad attribute" + ) + grad_weight = self.weight.main_grad.detach() + else: + accumulate_into_main_grad = False + + # Linear backward pass + grad_input, grad_weight = BasicLinear._functional_backward( + grad_output=grad_output, + input=x_local, + weight=self.weight, + input_dims=ctx.input_dims, + weight_dims=self.weight.size(), + input_requires_grad=ctx.input_requires_grad, + weight_requires_grad=ctx.weight_requires_grad, + device=self.device, + dtype=self.dtype, + tensor_parallel_mode=self.tensor_parallel_mode, + tensor_parallel_group=self.tensor_parallel_group, + sequence_parallel=self.sequence_parallel, + with_fp8_compute=ctx.with_fp8_compute, + weight_fp8_meta=ctx.weight_fp8_meta, + grad_output_fp8_meta=ctx.grad_output_fp8_meta, + grad_input_fp8_meta=ctx.grad_input_fp8_meta, + accumulate_into_grad_weight=accumulate_into_main_grad, + grad_weight=grad_weight, + ) + + # Clear input tensor if possible + if ctx.has_prev_op: + clear_tensor_data(x_local) + + if accumulate_into_main_grad: + grad_weight = None + return grad_input, [grad_weight] diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py new file mode 100644 index 0000000000..b8e8cc5e56 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -0,0 +1,142 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
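[Editorial aside, not part of the patch] A minimal usage sketch of the op defined above, mirroring how the unit tests drive it: BasicLinear acts like torch.nn.Linear with bias=False, and switches to FP8 compute inside the usual autocast context (assuming an FP8-capable GPU).

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

op = te_ops.BasicLinear(256, 256)  # weight on the default CUDA device, float32
x = torch.randn(32, 256, device="cuda", requires_grad=True)

y = op(x)  # high-precision path

with te.fp8_autocast(enabled=True):  # FP8 path; scales come from the op's own FP8 metas
    y = op(x)
y.backward(torch.ones_like(y))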
+ +"""Fusible operation for bias.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) +from .._common import ( + canonicalize_device, + canonicalize_dtype, +) + + +class Bias(BasicOperation): + """Apply additive bias + + This is equivalent to the additive bias in `torch.nn.Linear`. + + Parameters + ---------- + size: int + Inner dimension of input tensor + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + tensor_parallel: bool, default = `False` + Whether to distribute input tensor and bias tensors along + inner dimension + tensor_parallel_group: torch.distributed.ProcessGroup, default = world group + Process group for tensor parallelism + + """ + + def __init__( + self, + size: int, + *, + device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, + tensor_parallel: bool = False, + tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + ) -> None: + super().__init__() + + # Bias size + self._size = size + + # Bias tensor device + defer_param_init = False + device = canonicalize_device(device) + if device.type == "meta": + defer_param_init = True + device = canonicalize_device(None) + self.device: torch.device = device + + # Bias tensor datatype + self.dtype: torch.dtype = canonicalize_dtype(dtype) + + # Tensor parallel configuration + tensor_parallel_size = 1 + local_size = size + if tensor_parallel: + tensor_parallel_size = torch.distributed.get_world_size(tensor_parallel_group) + tensor_parallel = tensor_parallel_size > 1 + if size % tensor_parallel_size != 0: + raise ValueError( + "Invalid configuration for tensor parallelism " + f"({size=}, {tensor_parallel_size=})" + ) + local_size //= tensor_parallel_size + else: + tensor_parallel_group = None + self.tensor_parallel: bool = tensor_parallel + self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = tensor_parallel_group + self.tensor_parallel_size: int = tensor_parallel_size + self.local_size: int = local_size + + # Initialize parameters if needed + bias = torch.empty( + local_size, + device="meta", + dtype=dtype, + ) + bias = torch.nn.Parameter(bias) + self.bias: torch.nn.Parameter + self.register_parameter("bias", bias) + if not defer_param_init: + self.reset_parameters() + + def reset_parameters(self) -> None: + """Initialize parameter buffers and values""" + + # Make sure parameter is initialized + bias = self.bias + if bias.device.type != "cuda": + bias = torch.empty_like(bias, device=self.device) + bias = bias.to(device=self.device, dtype=self.dtype) + + # Initialize values + bias.zero_() + + # Save updated parameter + if not isinstance(bias, torch.nn.Parameter): + bias = torch.nn.Parameter(bias) + self.bias = bias + + def pre_forward(self) -> None: + super().pre_forward() + if self.bias.device.type == "meta": + self.reset_parameters() + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + x = input_ + b = self.bias.reshape([1] * (x.dim() - 1) + [self.local_size]) + return x + b + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + dy = grad_output + if dy.dim() > 1: + db = dy.sum(tuple(range(dy.dim() - 1))) + else: + db = dy + return dy, (db,) diff --git 
a/transformer_engine/pytorch/ops/basic/identity.py b/transformer_engine/pytorch/ops/basic/identity.py new file mode 100644 index 0000000000..73179c68a6 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/identity.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for identity.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) + + +class Identity(BasicOperation): + """Return input tensor""" + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + return input_ + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + return grad_output, () diff --git a/transformer_engine/pytorch/ops/basic/reduce_scatter.py b/transformer_engine/pytorch/ops/basic/reduce_scatter.py new file mode 100644 index 0000000000..996ca2da31 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/reduce_scatter.py @@ -0,0 +1,121 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for reduce-scatter.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) +from .._common import convert_tensor, is_float8_tensor + + +class ReduceScatter(BasicOperation): + """Reduce-scatter tensor along outer dimension + + Equivalent to summing tensors from all processes and splitting + along the first dimension. 
+ + Parameters + ---------- + process_group: torch.distributed.ProcessGroup, default = world group + Process group for communication + + """ + + def __init__( + self, + process_group: Optional[torch.distributed.ProcessGroup] = None, + ) -> None: + super().__init__() + self.process_group: Optional[torch.distributed.ProcessGroup] = process_group + self.process_group_size: int = torch.distributed.get_world_size(process_group) + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Trivial case + if self.process_group_size == 1: + return input_ + + # Tensor dimensions + input_dims = input_.size() + if not input_dims or input_dims[0] % self.process_group_size != 0: + raise RuntimeError( + "Attempted to reduce-scatter a tensor " + f"with shape={list(input_dims)} " + f"over {self.process_group_size} processes" + ) + output_dims = list(input_dims) + output_dims[0] //= self.process_group_size + + # Check input tensor + x = input_ + if is_float8_tensor(x): + x = x.from_float8() + x = x.contiguous() + + # Perform reduce-scatter + y = torch.empty(output_dims, dtype=x.dtype, device=x.device) + torch.distributed.reduce_scatter_tensor(y, x, group=self.process_group) + return y + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Trivial case + if self.process_group_size == 1: + return grad_output, () + + # Tensor dimensions + output_dims = grad_output.size() + if not output_dims: + raise RuntimeError( + "Attempted to all-gather a tensor " + f"with shape={list(output_dims)} " + f"over {self.process_group_size} processes" + ) + input_dims = list(output_dims) + input_dims[0] *= self.process_group_size + + # Perform all-gather + dy = convert_tensor(grad_output, memory_format=torch.contiguous_format) + dx = None + if is_float8_tensor(dy): + dx = Float8Tensor.make_like( + dy, + data=torch.empty( + input_dims, + dtype=torch.uint8, + device=dy.device, + ), + ) + torch.distributed.all_gather_into_tensor( + dx._data, + dy._data, + group=self.process_group, + ) + else: + dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device) + torch.distributed.all_gather_into_tensor( + dx, + dy, + group=self.process_group, + ) + + return dx, () diff --git a/transformer_engine/pytorch/ops/basic/reshape.py b/transformer_engine/pytorch/ops/basic/reshape.py new file mode 100644 index 0000000000..c3b1816635 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/reshape.py @@ -0,0 +1,52 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for reshape.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Optional + +import torch + +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) +from .._common import reshape + + +class Reshape(BasicOperation): + """Reshape tensor + + See `torch.reshape`. + + Parameters + ---------- + shape: iterable of int + Output tensor dimensions. If one dimension is -1, it is + inferred based on input tensor dimensions. 
+ + """ + + def __init__(self, shape: Iterable[int]) -> None: + super().__init__() + self._shape = tuple(shape) + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + ctx.input_shape = input_.size() + return reshape(input_, self._shape) + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + return reshape(grad_output, ctx.input_shape), () diff --git a/transformer_engine/pytorch/ops/fused_forward/__init__.py b/transformer_engine/pytorch/ops/fused_forward/__init__.py new file mode 100644 index 0000000000..ed523a067a --- /dev/null +++ b/transformer_engine/pytorch/ops/fused_forward/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Compound tensor operation supported by the operation fuser.""" + +from .linear_bias_activation import ( + ForwardLinearBiasActivation, + fuse_forward_linear_bias_activation, +) diff --git a/transformer_engine/pytorch/ops/fused_forward/linear_bias_activation.py b/transformer_engine/pytorch/ops/fused_forward/linear_bias_activation.py new file mode 100644 index 0000000000..1504dc4a53 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused_forward/linear_bias_activation.py @@ -0,0 +1,191 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fused operation for GEMM, bias, activation in the forward pass.""" + +from __future__ import annotations +from typing import Any, Optional + +import torch + +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.ops.basic import BasicLinear, Bias +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + FusedOperation, + FusibleOperation, + OperationContext, +) + + +class ForwardLinearBiasActivation(FusedOperation): + """Fused GEMM, bias, activation in the forward pass + + Bias and activation are both optional. Row tensor parallelism is + not supported since that requires communication immediately after + the GEMM. 
+ + """ + + def __init__( + self, + *, + linear: BasicLinear, + bias: Optional[Bias], + activation: None, + ) -> None: + + # Basic operations that comprise this fused operation + op_idxs = dict( + linear=0, + bias=None, + activation=None, + ) + ops = [linear] + if bias is not None: + op_idxs["bias"] = len(ops) + ops.append(bias) + if activation is not None: + op_idxs["activation"] = len(ops) + ops.append(activation) + + # Initialize base class + super().__init__(ops) + + # Index of each basic operations + self._op_idxs: dict[str, Optional[int]] = op_idxs + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + basic_op_prev_ops: list[Optional[BasicOperation]], + basic_op_next_ops: list[Optional[BasicOperation]], + basic_op_kwargs: list[dict[str, Any]], + ) -> torch.Tensor: + + # Get basic operations + idx = self._op_idxs["linear"] + linear_op = self.basic_ops[idx] + linear_op_ctx = basic_op_ctxs[idx] + if self._op_idxs["bias"] is None: + bias_op = None + bias = None + else: + idx = self._op_idxs["bias"] + bias_op = self.basic_ops[idx] + bias = bias_op.bias + if basic_op_kwargs[idx]: + raise ValueError("Bias operation forward does not expect keyword arguments") + if self._op_idxs["activation"] is None: + activation_op = None # pylint: disable=unused-variable + else: + raise NotImplementedError("Activations are not yet supported") + + # FP8 metadata + with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() + input_fp8_meta = None + weight_fp8_meta = None + output_fp8_meta = None + grad_output_fp8_meta = None + grad_input_fp8_meta = None + if with_fp8_compute: + input_fp8_meta = linear_op.get_fp8_meta("input") + weight_fp8_meta = linear_op.get_fp8_meta("param") + next_op = basic_op_next_ops[-1] + if next_op is not None and next_op.num_fp8_scales("input") > 0: + output_fp8_meta = next_op.get_fp8_meta("input") + grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output") + prev_op = basic_op_prev_ops[0] + if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: + grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + + # Linear forward + output, x_local, _ = BasicLinear._functional_forward( + input=input_, + weight=linear_op.weight, + bias=bias, + device=linear_op.device, + dtype=linear_op.dtype, + tensor_parallel_mode=linear_op.tensor_parallel_mode, + tensor_parallel_group=linear_op.tensor_parallel_group, + sequence_parallel=linear_op.sequence_parallel, + with_fp8_compute=with_fp8_compute, + input_fp8_meta=input_fp8_meta, + weight_fp8_meta=weight_fp8_meta, + output_fp8_meta=output_fp8_meta, + ) + + # Save state for backward pass + linear_op_ctx.save_for_backward(x_local) + linear_op_ctx.with_fp8_compute = with_fp8_compute + linear_op_ctx.weight_fp8_meta = weight_fp8_meta + linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta + linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta + linear_op_ctx.input_dims = input_.size() + linear_op_ctx.input_requires_grad = input_.requires_grad + linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad + linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None + + return output + + +def fuse_forward_linear_bias_activation( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Fuse GEMM, bias, activation in the forward pass + + Parameters + ---------- + ops: list of tuples + Forward pass operations and the indices of the corresponding + basic operations. 
+ + Returns + ------- + ops: list of tuples + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while len(ops) >= 2: + out.extend(window) + + # Check if first op is linear + window, ops = ops[:1], ops[1:] + op1, _ = window[0] + if not isinstance(op1, BasicLinear): + continue + if op1.tensor_parallel_mode == "row": + # Row tensor-parallelism requires communication after the + # GEMM + continue + if op1.dtype not in (torch.float16, torch.bfloat16): + # cuBLAS only supports fused GEMM+bias+activation with + # FP16 and BF16 output + continue + + # Check if second op is bias + op2, _ = ops[0] + if not isinstance(op2, Bias): + continue + window.extend(ops[:1]) + ops = ops[1:] + + # Replace window with fused op + op = ForwardLinearBiasActivation( + linear=window[0][0], + bias=window[1][0], + activation=None, + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + out.extend(ops) + return out diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py new file mode 100644 index 0000000000..06ea608ed8 --- /dev/null +++ b/transformer_engine/pytorch/ops/fuser.py @@ -0,0 +1,270 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Manager class for a pipeline of fusible operations.""" + +from __future__ import annotations +from typing import Any, Optional + +import torch + +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.graph import is_graph_capturing +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + FusibleOperation, + OperationContext, +) +from transformer_engine.pytorch.ops.fused_forward import ( + fuse_forward_linear_bias_activation, +) + + +class _OperationFuserAutogradFunction(torch.autograd.Function): + """Autograd function for a pipeline of operations + + Autograd must be done at the pipeline level since we may apply + different fusions in the forward and backward passes. + + """ + + # pylint: disable=unused-argument + @staticmethod + def forward( + func_ctx: torch.autograd.function.FunctionCtx, + input_: torch.Tensor, + forward_ops: list[tuple[FusibleOperation, list[int]]], + backward_ops: list[tuple[FusibleOperation, list[int]]], + basic_ops: list[BasicOperation], + basic_op_kwargs: list[dict[str, Any]], + *params: torch.nn.Parameter, + ) -> torch.Tensor: + """Forward pass + + Parameters + ---------- + func_ctx: torch.autograd.function.FunctionCtx + Context for PyTorch autograd function + input_: torch.Tensor + Input to first operation in pipeline + forward_ops: list of tuple + Forward pass operations and the indices of the + corresponding basic operations. The order should match + basic_ops. + backward_ops: list of tuple + Backward pass operations and the indices of the + corresponding basic operations. The order should be the + reverse of basic_ops. 
+ basic_ops: list of BasicOperation + Basic operations + basic_op_kwargs: list of dict + Keyword arguments to BasicOperation + *params: torch.nn.Parameter + Parameters in operation pipeline + + """ + + # Operation autograd contexts + basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))] + + # Apply forward ops + x = input_ + requires_grad = x.requires_grad + for op, basic_op_idxs in forward_ops: + + # Forward op + prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs] + next_ops = [ + basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs + ] + x = op.fuser_forward( + [basic_op_ctxs[idx] for idx in basic_op_idxs], + x, + prev_ops, + next_ops, + [basic_op_kwargs[idx] for idx in basic_op_idxs], + ) + + # Check if backward op is required + if not requires_grad: + requires_grad = any(param.requires_grad for param in op.parameters()) + for idx in basic_op_idxs: + basic_op_ctxs[idx]._requires_grad = requires_grad + x.requires_grad_(requires_grad=requires_grad) + + # Flatten list of saved tensors + to_save = [] + for ctx in basic_op_ctxs: + range_start = len(to_save) + if ctx.to_save is not None: + to_save.extend(ctx.to_save) + range_end = len(to_save) + ctx.to_save = None + ctx._saved_tensors_range = (range_start, range_end) + func_ctx.save_for_backward(*to_save) + + # Other context for backward pass + func_ctx.backward_ops = backward_ops + func_ctx.basic_ops = basic_ops + func_ctx.basic_op_ctxs = basic_op_ctxs + func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() + + return x + + @staticmethod + @torch.autograd.function.once_differentiable + def backward( + func_ctx: Any, + grad_output: torch.Tensor, + ) -> tuple[Optional[torch.Tensor], ...]: + """Backward pass""" + + # Operations and autograd state + backward_ops = func_ctx.backward_ops + basic_ops = func_ctx.basic_ops + basic_op_ctxs = func_ctx.basic_op_ctxs + + # Unflatten list of saved tensors + saved_tensors = func_ctx.saved_tensors + for ctx in basic_op_ctxs: + ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)] + ctx._saved_tensors_range = None + del saved_tensors + + # Apply backward ops + dx = grad_output + grad_params = [None for _ in range(len(basic_ops))] + for op, basic_op_idxs in backward_ops: + + # Stop if no more gradients are required + if all(not basic_op_ctxs[idx]._requires_grad for idx in basic_op_idxs): + dx = None + break + + # Backward op + dx, fused_op_dparams = op.fuser_backward( + [basic_op_ctxs[idx] for idx in basic_op_idxs], + dx, + ) + for idx, basic_op_dparams in zip(basic_op_idxs, fused_op_dparams): + grad_params[idx] = basic_op_dparams + basic_op_ctxs[idx].saved_tensors = None + + # Flatten list of parameter gradients + grad_params_flat = [] + for idx, dparams in enumerate(grad_params): + params = list(basic_ops[idx].parameters()) + if dparams is None: + dparams = [None for _ in range(len(params))] + else: + dparams = list(dparams) + if len(dparams) != len(params): + raise RuntimeError( + f"Expected op {idx} to generate {len(params)} param grads, " + f"but got {len(dparams)}" + ) + grad_params_flat.extend(dparams) + + # Update FP8 scaling factors + if func_ctx.is_first_module and not is_graph_capturing(): + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + + return ( + dx, # input_ + None, # forward_ops + None, # backward_ops + None, # basic_ops + None, # basic_op_kwargs + *grad_params_flat, # params + ) + + +class OperationFuser: + """Manages forward and backward passes for a pipeline of 
operations + + Parameters + ---------- + ops: list of FusibleOperation + Pipeline of operations + fuse_ops: bool, default = `True` + Whether to attempt fusing operations + + """ + + def __init__( + self, + ops: list[FusibleOperation], + fuse_ops: bool = True, + ) -> None: + + # Get list of basic operations + basic_ops = [] + for op in ops: + if op.is_fused_op: + basic_ops.extend(op.basic_ops) + else: + basic_ops.append(op) + self._num_basic_ops: int = len(basic_ops) + self._basic_ops: list[BasicOperation] = basic_ops + + # Ops for forward and backward pass + self._forward_ops: list[tuple[FusibleOperation, list[int]]] + self._backward_ops: list[tuple[FusibleOperation, list[int]]] + self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)] + self._backward_ops = list(reversed(self._forward_ops)) + + # Fuse ops if needed + if fuse_ops: + self.fuse_ops() + + @classmethod + def _fuse_forward_ops( + cls, + ops: list[tuple[FusibleOperation, list[int]]], + ) -> list[tuple[FusibleOperation, list[int]]]: + """Attempt to fuse operations in forward pass""" + ops = fuse_forward_linear_bias_activation(ops) + return ops + + @classmethod + def _fuse_backward_ops( + cls, + ops: list[tuple[FusibleOperation, list[int]]], + ) -> list[tuple[FusibleOperation, list[int]]]: + """Attempt to fuse operations in backward pass""" + return ops + + def fuse_ops(self) -> None: + """Attempt to fuse operations""" + self._forward_ops = self._fuse_forward_ops(self._forward_ops) + self._backward_ops = self._fuse_backward_ops(self._backward_ops) + + def __call__( + self, + input: torch.Tensor, # pylint: disable=redefined-builtin + basic_op_kwargs: Optional[list[dict[str, Any]]] = None, + ) -> torch.Tensor: + + # Initialization before forward pass + for op in self._basic_ops: + op.pre_forward() + + # Canonicalize op kwargs + if basic_op_kwargs is None: + basic_op_kwargs = [{} for _ in range(len(self._basic_ops))] + + # Flatten list of parameters + params = [] + for op in self._basic_ops: + params.extend(op.parameters()) + + # Fuser forward pass + return _OperationFuserAutogradFunction.apply( + input, + self._forward_ops, + self._backward_ops, + self._basic_ops, + basic_op_kwargs, + *params, + ) diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py new file mode 100644 index 0000000000..13cec30fa2 --- /dev/null +++ b/transformer_engine/pytorch/ops/linear.py @@ -0,0 +1,138 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for linear layer.""" + +from __future__ import annotations +from collections.abc import Callable +from typing import Optional + +import torch + +from transformer_engine.pytorch.ops.basic import ( + AllReduce, + BasicLinear, + Bias, + ReduceScatter, +) +from transformer_engine.pytorch.distributed import CudaRNGStatesTracker +from transformer_engine.pytorch.ops.op import FusedOperation + + +class Linear(FusedOperation): + """Apply linear transformation: :math:`y = x A^T + b` + + This is a drop-in replacement for `torch.nn.Linear`. 
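+
+    A minimal usage sketch (illustrative only; feature sizes are arbitrary
+    and a CUDA device is assumed)::
+
+        layer = Linear(768, 3072, bias=True, device="cuda")
+        y = layer(torch.randn(16, 768, device="cuda"))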
+ + Parameters + ---------- + in_features: int + Inner dimension of input tensor + out_features: int + Inner dimension of output tensor + bias: bool, default = `True` + Apply additive bias + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + Mode for tensor parallelism + tensor_parallel_group: torch.distributed.ProcessGroup, default = world group + Process group for tensor parallelism + sequence_parallel: bool, default = `False` + Whether to apply sequence parallelism together with tensor + parallelism, i.e. distributing input or output tensors along + outer dimension (sequence or batch dim) when not distributing + along inner dimension (embedding dim) + rng_state_tracker_function: callable + Function that returns CudaRNGStatesTracker, which is used for + model-parallel weight initialization + accumulate_into_main_grad: bool, default = `False` + Whether to directly accumulate weight gradients into the + weight's `main_grad` attribute instead of relying on PyTorch + autograd. The weight's `main_grad` must be set externally and + there is no guarantee that `grad` will be set or be + meaningful. + + """ + + def __init__( + self, + in_features: int, + out_features: int, + *, + bias: bool = True, + device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, + tensor_parallel_mode: Optional[str] = None, + tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + sequence_parallel: bool = False, + rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, + accumulate_into_main_grad: bool = False, + ) -> None: + + # Tensor parallel configuration + ( + tensor_parallel_mode, + tensor_parallel_group, + tensor_parallel_size, + sequence_parallel, + local_in_features, + local_out_features, + ) = BasicLinear._canonicalize_tensor_parallelism( + mode=tensor_parallel_mode, + process_group=tensor_parallel_group, + sequence_parallel=sequence_parallel, + in_features=in_features, + out_features=out_features, + ) + + # Construct basic ops + ops = [] + linear_kwargs = dict( + in_features=in_features, + out_features=out_features, + device=device, + dtype=dtype, + tensor_parallel_mode=tensor_parallel_mode, + tensor_parallel_group=tensor_parallel_group, + sequence_parallel=sequence_parallel, + rng_state_tracker_function=rng_state_tracker_function, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + bias_kwargs = dict( + size=out_features, + device=device, + dtype=dtype, + tensor_parallel=(tensor_parallel_mode is not None), + tensor_parallel_group=tensor_parallel_group, + ) + if tensor_parallel_mode == "row": + # Row TP: GEMM + bias + reduction + linear_kwargs["in_features"] = local_in_features + linear_kwargs["out_features"] = local_out_features + linear_kwargs["tensor_parallel_mode"] = None + linear_kwargs["tensor_parallel_group"] = None + linear_kwargs["sequence_parallel"] = False + bias_kwargs["size"] *= tensor_parallel_size + ops.append(BasicLinear(**linear_kwargs)) + if bias: + ops.append(Bias(**bias_kwargs)) + if sequence_parallel: + ops.append(ReduceScatter(tensor_parallel_group)) + else: + ops.append(AllReduce(tensor_parallel_group)) + else: + # Column TP or no TP: (gather + GEMM) + bias + ops.append(BasicLinear(**linear_kwargs)) + if bias: + ops.append(Bias(**bias_kwargs)) + + # Initialize base class + super().__init__(ops) + + # Register parameters + self.register_parameter("weight", 
self.basic_ops[0].weight) + self.register_parameter("bias", self.basic_ops[1].bias if bias else None) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py new file mode 100644 index 0000000000..3d90d07b84 --- /dev/null +++ b/transformer_engine/pytorch/ops/op.py @@ -0,0 +1,427 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Base classes for fusible operations.""" + +from __future__ import annotations +import abc +from collections.abc import Iterable +import dataclasses +from typing import Any, Optional + +import torch + +import transformer_engine_torch as tex +from transformer_engine.pytorch.fp8 import ( + FP8GlobalStateManager, + get_default_fp8_recipe, +) +from ._common import canonicalize_device, is_float8_tensor + + +@dataclasses.dataclass +class OperationContext: + """State needed to apply an operation + + Saves state from forward pass for use in backward pass. + + """ + + # Tensors that have been saved from forward function + # Note: Available in the backward function, matching tensors from + # to_save. + saved_tensors: Optional[tuple[Optional[torch.Tensor], ...]] = None + # Tensors to save for backward function + # Note: Expected to be set in the forward function, either + # directly or with save_for_backward. + to_save: Optional[tuple[Optional[torch.Tensor], ...]] = None + + # Corresponding range in pipeline's list of saved tensors + _saved_tensors_range: Optional[tuple[int, int]] = None + + # Whether backward pass is required + _requires_grad: bool = False + + def save_for_backward(self, *tensors: Optional[torch.Tensor]) -> None: + """Register tensors to be saved for the backward function + + Expected to be called in the forward function. + + """ + self.to_save = tensors + + +class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): + """Tensor operation supported by the operation fuser""" + + @property + @abc.abstractmethod + def is_fused_op(self) -> bool: + """Whether this op is the fusion of one or more basic ops""" + + def pre_forward(self) -> None: + """Preprocessing before forward pass""" + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + basic_op_prev_ops: list[Optional[BasicOperation]], + basic_op_next_ops: list[Optional[BasicOperation]], + basic_op_kwargs: list[dict[str, Any]], + ) -> torch.Tensor: + """Forward pass + + This op is either a basic op or the fusion of basic ops, so + several of this function's arguments are lists of arguments to + forward functions of corresponding basic ops. + + Called by `OperationFuser`. + + Parameters + ---------- + basic_op_ctxs: list of OperationContext + Contexts for corresponding basic operations + input_: torch.Tensor + Input tensor + basic_op_prev_ops: list of BasicOperation + Basic operations that preceed each of the corresponding + basic operations (or `None` if corresponding basic op is + first) + basic_op_next_ops: list of BasicOperation + Basic operations that follow each of the corresponding + basic operations (or `None` if corresponding basic op is + last) + basic_op_kwargs: list of dict + Keyword arguments to forward functions of corresponding + basic operations + + Returns + ------- + torch.Tensor: Output tensor. 
+ + """ + raise NotImplementedError( + f"Forward pass is not implemented for operation ({self.__class__.__name__})" + ) + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, Iterable[Iterable[Optional[torch.Tensor]]]]: + """Backward pass + + This op is either a basic op or the fusion of basic ops, so + several of this function's arguments are lists of arguments to + backward functions of corresponding basic ops. + + Called by `OperationFuser`. + + Parameters + ---------- + basic_op_ctxs: list of OperationContext + Contexts for corresponding basic operations. + grad_output: torch.Tensor + Loss gradient w.r.t. operation output. + basic_op_prev_ops: list of BasicOperation + Basic operations that preceed each of the corresponding + basic operations (or `None` if corresponding basic op is + first) + basic_op_next_ops: list of BasicOperation + Basic operations that follow each of the corresponding + basic operations (or `None` if corresponding basic op is + last) + + Returns + ------- + torch.Tensor: + Loss gradient w.r.t. operation input + Iterable of iterable of torch.Tensor: + Loss gradients w.r.t. parameters for corresponding basic + operations + + """ + raise NotImplementedError( + f"Backward pass is not implemented for operation ({self.__class__.__name__})" + ) + + +class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): + """Single tensor operation supported by the operation fuser + + This class holds parameters and state, even if the actual forward + and backward passes are performed by a fused operation. + + """ + + def __init__(self) -> None: + super().__init__() + + # FP8 metadata objects + self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None + + @property + def is_fused_op(self) -> bool: + return False + + # pylint: disable=no-self-use + def num_fp8_scales( + self, + mode: str, # pylint: disable=unused-argument + ) -> int: + """Number of FP8 scaling factors + + Parameters + ---------- + mode: {"input", "param", "grad_output"} + Type of FP8 scaling factor + + """ + return 0 + + def _make_fp8_metas(self) -> dict[str, Optional[dict[str, Any]]]: + """Construct FP8 metadata""" + + # Shared objects for FP8 metadata + dtype = torch.float32 + device = canonicalize_device(None) + recipe = get_default_fp8_recipe() + + def _make_meta( + num_scales: int, + is_forward: bool, + ) -> Optional[dict[str, Any]]: + """Construct FP8 metadata for one tensor type""" + if num_scales == 0: + return None + key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) + meta = tex.FP8TensorMeta() + meta.scale = torch.ones(num_scales, dtype=dtype, device=device) + meta.scale_inv = torch.ones(num_scales, dtype=dtype, device=device) + meta.amax_history = torch.zeros( + (recipe.amax_history_len, num_scales), + dtype=dtype, + device=device, + ) + return { + key: meta, + "recipe": recipe, + "fp8_group": None, + } + + # Construct FP8 metadata for all tensor types + return dict( + input=_make_meta(self.num_fp8_scales("input"), True), + param=_make_meta(self.num_fp8_scales("param"), True), + grad_output=_make_meta(self.num_fp8_scales("grad_output"), False), + ) + + @classmethod + def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None: + if fp8_meta is None: + return + + # Update FP8 recipe and communication group + recipe = FP8GlobalStateManager.get_fp8_recipe() + fp8_meta["recipe"] = recipe + fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + + # Adjust amax history length if needed + 
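+        # If the active recipe's history length differs from this op's
+        # existing metadata, resize in place: truncate along the history
+        # dimension when shrinking, or zero-pad at the end when growing.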
amax_history_len = recipe.amax_history_len + for is_forward in (True, False): + key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) + if key not in fp8_meta: + continue + meta = fp8_meta[key] + curr_len = meta.amax_history.size(0) + if curr_len == amax_history_len: + continue + with torch.no_grad(): + if curr_len > amax_history_len: + meta.amax_history = meta.amax_history[:amax_history_len].clone() + else: + meta.amax_history = torch.nn.functional.pad( + meta.amax_history, + pad=(0, 0, 0, amax_history_len - curr_len), + ) + + def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]: + """FP8 metadata + + Parameters + ---------- + mode: {"input", "param", "grad_output"} + Type of FP8 scaling factor + + """ + if self._fp8_metas is None: + self._fp8_metas = self._make_fp8_metas() + return self._fp8_metas[mode] + + def pre_forward(self) -> None: + """Preprocessing before forward pass""" + + # Initialize FP8 metadata if needed + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + if fp8_enabled: + + # Construct FP8 metadata if needed + if self._fp8_metas is None: + self._fp8_metas = self._make_fp8_metas() + + # Make sure FP8 metadata matches FP8 autocast context + for fp8_meta in self._fp8_metas.values(): + self._maybe_update_fp8_meta(fp8_meta) + + # Register FP8 metadata for amax and scale update + if not FP8GlobalStateManager.fp8_graph_capturing(): + if self.num_fp8_scales("input"): + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + self.get_fp8_meta("input"), + ) + if self.num_fp8_scales("param"): + fp8_params = list(filter(is_float8_tensor, self.parameters())) + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + self.get_fp8_meta("param"), + fp8_weights=(fp8_params if fp8_params else None), + ) + if self.num_fp8_scales("grad_output"): + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + self.get_fp8_meta("grad_output"), + ) + + @abc.abstractmethod + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + **kwargs: Any, + ) -> torch.Tensor: + """Forward pass + + Parameters + ---------- + ctx: OperationContext + Context to coordinate between forward and backward passes + input_: torch.Tensor + Input tensor + + Returns + ------- + torch.Tensor: + Output tensor + + """ + + @abc.abstractmethod + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]: + """Backward pass + + Parameters + ---------- + ctx: OperationContext + Context to coordinate between forward and backward passes + grad_output: torch.Tensor + Loss gradient w.r.t. operation output + + Returns + ------- + torch.Tensor + Loss gradient w.r.t. operation input + Iterable of torch.Tensor: + Loss gradients w.r.t. 
parameters + + """ + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + basic_op_prev_ops: list[Optional[BasicOperation]], + basic_op_next_ops: list[Optional[BasicOperation]], + basic_op_kwargs: list[dict[str, Any]], + ) -> torch.Tensor: + return self.op_forward( + basic_op_ctxs[0], + input_, + basic_op_prev_ops[0], + basic_op_next_ops[0], + **basic_op_kwargs[0], + ) + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, Iterable[Iterable[Optional[torch.Tensor]]]]: + grad_input, grad_params = self.op_backward(basic_op_ctxs[0], grad_output) + return grad_input, [grad_params] + + def forward( + self, + input: torch.Tensor, # pylint: disable=redefined-builtin + **kwargs: Any, + ) -> torch.Tensor: + """Apply operation""" + from .fuser import OperationFuser + + return OperationFuser([self], fuse_ops=False)(input, [kwargs]) + + +class FusedOperation(FusibleOperation): + """Compound tensor operation supported by the operation fuser + + If the forward or backward passes are defined, they must be + functionally equivalent to the forward/backward passes of the + corresponding basic ops. This class should hold no parameters or + other state, but should access them from the basic ops. + + Parameters + ---------- + basic_ops: iterable of FusibleOperation + Basic ops that are interchangeable with this op + + """ + + def __init__( + self, + basic_ops: Iterable[FusibleOperation], + ) -> None: + super().__init__() + + # Basic operations that comprise this fused operation + self.basic_ops: torch.nn.ModuleList = torch.nn.ModuleList(basic_ops) + if len(self.basic_ops) == 0: + raise ValueError( + "Attempted to construct a fused operation " + "without specifying its corresponding basic operations" + ) + + @property + def is_fused_op(self) -> bool: + return True + + def pre_forward(self) -> None: + """Preprocessing before forward pass""" + for op in self.basic_ops: + op.pre_forward() + + def forward( + self, + input: torch.Tensor, # pylint: disable=redefined-builtin + basic_op_kwargs: Optional[list[dict[str, Any]]] = None, + ) -> torch.Tensor: + """Apply operation""" + if basic_op_kwargs is None: + basic_op_kwargs = [{} for _ in range(len(self.basic_ops))] + from .fuser import OperationFuser + + return OperationFuser([self], fuse_ops=False)(input, basic_op_kwargs) diff --git a/transformer_engine/pytorch/ops/sequential.py b/transformer_engine/pytorch/ops/sequential.py new file mode 100644 index 0000000000..95499a9e80 --- /dev/null +++ b/transformer_engine/pytorch/ops/sequential.py @@ -0,0 +1,178 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Sequential container for fusible operations.""" + +from __future__ import annotations +from collections import OrderedDict +from collections.abc import Iterable, Iterator +from typing import Optional + +import torch + +from transformer_engine.pytorch.ops import FusibleOperation +from transformer_engine.pytorch.ops.fuser import OperationFuser + + +class Sequential(torch.nn.Module): + """Sequential container for fusible operations + + This is a drop-in replacement for `torch.nn.Sequential`, with + support for fusing `FusibleOperation`s. 
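+
+    A small illustrative sketch mixing fusible ops with a plain PyTorch
+    module (``Linear`` is the fusible op from
+    ``transformer_engine.pytorch.ops.linear``; sizes are arbitrary and a
+    CUDA device is assumed)::
+
+        model = Sequential(
+            Linear(1024, 1024, device="cuda"),  # fusible op
+            torch.nn.Identity(),                # plain PyTorch module
+            Linear(1024, 1024, device="cuda"),  # fusible op
+        )
+        y = model(torch.randn(8, 1024, device="cuda"))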
+ + Parameters + ---------- + *args: FusibleOperation or torch.nn.Module + Neural network modules + + """ + + def __init__( + self, + *args: FusibleOperation | torch.nn.Module, + ) -> None: + super().__init__() + + # List of modules, with fusible operations grouped together + self._module_groups: Optional[list[OperationFuser | torch.nn.Module]] + self._module_groups = None + + # Add modules + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for module in args: + self.append(module) + + def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None: + self._module_groups = None + super().add_module(name, module) + + def _get_keys_by_idx(self, idx: int | slice) -> list[str]: + """Get module keys corresponding to indices""" + if isinstance(idx, slice): + return list(self._modules.keys())[idx] + size = len(self._modules) + if not -size <= idx < size: + raise IndexError(f"Attempted to access index {idx}, but there are {size} entries") + if idx < 0: + idx += size + for i, key in enumerate(self._modules.keys()): + if i == idx: + return [key] + raise RuntimeError(f"Could not access index {idx}") + + def _next_key(self) -> str: + """Key for a newly added module""" + idx = 0 + for key in self._modules.keys(): + try: + key_idx = int(key) + except (ValueError, TypeError): + pass + else: + idx = max(idx, key_idx + 1) + return str(idx) + + def __getitem__( + self, + idx: slice | int, + ) -> Sequential | torch.nn.Module: + keys = self._get_keys_by_idx(idx) + if isinstance(idx, slice): + modules = OrderedDict((str(i), self._modules[key]) for i, key in enumerate(keys)) + return self.__class__(modules) + return self._modules[keys[0]] + + def __setitem__(self, idx: int, module: torch.nn.Module) -> None: + self._module_groups = None + key = self._get_keys_by_idx(idx)[0] + self._modules[key] = module + + def __delitem__(self, idx: slice | int) -> None: + self._module_groups = None + for key in self._get_keys_by_idx(idx): + del self._modules[key] + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[torch.nn.Module]: + return iter(self._modules.values()) + + def append(self, module: torch.nn.Module) -> Sequential: + """Add module at the end of the container""" + self.add_module(self._next_key(), module) + return self + + def extend(self, modules: Iterable[torch.nn.Module]) -> Sequential: + """Add modules at the end of the container""" + for module in modules: + self.append(module) + return self + + def insert(self, idx: int, module: torch.nn.Module) -> Sequential: + """Add modules at a position in the container""" + self._module_groups = None + keys = self._get_keys_by_idx(slice(idx, None)) + keys.append(self._next_key()) + for i in reversed(range(1, len(keys))): + self._modules[keys[i]] = self._modules[keys[i - 1]] + self._modules[keys[0]] = module + return self + + def pop(self, idx: slice | int) -> torch.nn.Module: + """Remove module at a position in the container""" + out = self[idx] + del self[idx] + return out + + def __iadd__(self, other: Sequential) -> Sequential: + return self.extend(other) + + def __add__(self, modules: Iterable[torch.nn.Modules]) -> Sequential: + out = self.__class__(self._modules) + out.extend(modules) + return out + + @classmethod + def _make_module_groups( + cls, + modules: Iterable[torch.nn.Module], + ) -> list[OperationFuser | torch.nn.Module]: + """Make list of modules, with fusible operations grouped together""" + module_groups = [] + 
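+        # Greedily collect consecutive FusibleOperations into a single
+        # OperationFuser group; a non-fusible module closes the current
+        # run and is kept as its own group.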
fusible_ops = [] + + def maybe_add_fuser(): + nonlocal fusible_ops + if fusible_ops: + module_groups.append(OperationFuser(fusible_ops, fuse_ops=True)) + fusible_ops = [] + + for module in modules: + if isinstance(module, FusibleOperation): + fusible_ops.append(module) + else: + maybe_add_fuser() + module_groups.append(module) + maybe_add_fuser() + return module_groups + + def forward( + self, + input: torch.Tensor, # pylint: disable=redefined-builtin + ) -> torch.Tensor: + """Forward pass""" + + # Create module groups if needed + if self._module_groups is None: + self._module_groups = self._make_module_groups(self._modules.values()) + + # Forward pass for each module group + x = input + for module_group in self._module_groups: + x = module_group(x) + return x diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 9dd713ac66..5e3fa05f52 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -32,10 +32,9 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: if t is not None: if isinstance(t, Float8Tensor): t._data.data = torch.Tensor() - del t else: t.data = torch.Tensor() - del t + del t def get_device_compute_capability() -> Tuple[int, int]: