-
Notifications
You must be signed in to change notification settings - Fork 468
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add TransformerEngine to PT 2.0 training images (#3315)
- Loading branch information
Showing
6 changed files
with
142 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
12 changes: 12 additions & 0 deletions
12
test/dlc_tests/container_tests/bin/transformerengine/testPTTransformerEngine
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#!/bin/bash
# Container test entry point: run NVIDIA TransformerEngine's PyTorch test
# suite (release_v0.12) inside a PT 2.0 training image.
# -e: abort on the first failing command; -x: echo commands for CI logs.
set -ex

# Pin to the release branch matching the TE version installed in the image.
git clone --branch release_v0.12 https://github.com/NVIDIA/TransformerEngine.git
cd TransformerEngine/tests/pytorch

# Pinned test-only dependencies; onnx is left unpinned deliberately here.
pip install pytest==6.2.5 onnxruntime==1.13.1 onnx
pytest -v -s test_sanity.py
# Force deterministic algorithms and disable TorchScript JIT for numerics checks.
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s test_numerics.py
# torch.compile is disabled for ONNX export tests.
NVTE_TORCH_COMPILE=0 pytest -v -s test_onnx_export.py
pytest -v -s test_jit.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
import test.test_utils.ec2 as ec2_utils | ||
from test.test_utils import CONTAINER_TESTS_PREFIX, is_pr_context, is_efa_dedicated | ||
from test.test_utils.ec2 import get_efa_ec2_instance_type, filter_efa_instance_type | ||
|
||
# Container-side path of the TransformerEngine test runner script executed on
# the EC2 instance.
PT_TE_TESTS_CMD = os.path.join(CONTAINER_TESTS_PREFIX, "transformerengine", "testPTTransformerEngine")
|
||
|
||
# (instance_type, region) pairs for EFA-capable GPU instances, defaulting to
# p4d.24xlarge; consumed by the parametrize mark on the test below.
EC2_EFA_GPU_INSTANCE_TYPE_AND_REGION = get_efa_ec2_instance_type(
    default="p4d.24xlarge", filter_function=filter_efa_instance_type
)
|
||
|
||
@pytest.mark.processor("gpu")
@pytest.mark.model("N/A")
@pytest.mark.integration("transformerengine")
@pytest.mark.usefixtures("sagemaker")
@pytest.mark.allow_p4de_use
@pytest.mark.parametrize("ec2_instance_type,region", EC2_EFA_GPU_INSTANCE_TYPE_AND_REGION)
@pytest.mark.skipif(
    is_pr_context() and not is_efa_dedicated(),
    reason="Skip heavy instance test in PR context unless explicitly enabled",
)
def test_pytorch_transformerengine(
    pytorch_training, ec2_connection, region, ec2_instance_type, gpu_only, py3_only
):
    """Run the TransformerEngine container test script on a GPU EC2 instance.

    Executes PT_TE_TESTS_CMD inside the ``pytorch_training`` image over the
    ``ec2_connection`` fixture. Skipped in PR context unless EFA-dedicated
    testing is explicitly enabled; ``gpu_only``/``py3_only`` are fixture
    gates (presumably restricting the image variants — confirm in conftest).
    """
    ec2_utils.execute_ec2_training_test(ec2_connection, pytorch_training, PT_TE_TESTS_CMD)