Skip Triton import for AMD (microsoft#5110)
When testing DeepSpeed inference on an `AMD Instinct MI250X/MI250` GPU,
the `pytorch-triton-rocm` module would break the `torch.cuda` device
API. To address this, importing `triton` is skipped when the GPU is
determined to be `AMD`.

This change allows DeepSpeed to be executed on an AMD GPU without kernel
injection in the DeepSpeedExamples [text-generation
example](https://github.com/microsoft/DeepSpeedExamples/tree/master/inference/huggingface/text-generation)
using the following command:
```bash
deepspeed --num_gpus 1 inference-test.py --model facebook/opt-125m
```

TODO: Root-cause the interaction between `pytorch-triton-rocm` and
DeepSpeed to understand why this is causing the `torch.cuda` device API
to break.
lekurile authored and rraminen committed May 9, 2024
1 parent 57d49f1 commit fdf81ff
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions `deepspeed/__init__.py`

```diff
@@ -12,10 +12,14 @@
 from torch.optim.lr_scheduler import _LRScheduler
 from packaging import version as pkg_version
 
-try:
-    import triton  # noqa: F401 # type: ignore
-    HAS_TRITON = True
-except ImportError:
+# Skip Triton import for AMD due to pytorch-triton-rocm module breaking device API in DeepSpeed
+if not (hasattr(torch.version, 'hip') and torch.version.hip is not None):
+    try:
+        import triton  # noqa: F401 # type: ignore
+        HAS_TRITON = True
+    except ImportError:
+        HAS_TRITON = False
+else:
     HAS_TRITON = False
 
 from . import ops
```
