Skip to content

Commit

Permalink
Do not import transformer_engine on import (#3056)
Browse files Browse the repository at this point in the history
* Do not import `transformer_engine` on import

* fix message

* add test

* Update test_imports.py

* resolve comment 1/2

* resolve comment 1.5/2

* lint

* more lint

* Update tests/test_imports.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fmt

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
  • Loading branch information
oraluben and muellerzr authored Aug 28, 2024
1 parent 939ce40 commit 3fcc946
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
12 changes: 9 additions & 3 deletions src/accelerate/utils/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from .operations import GatheredParameters


if is_fp8_available():
import transformer_engine.pytorch as te
# Do not import `transformer_engine` at package level to avoid potential issues


def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
Expand All @@ -30,6 +29,8 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
"""
if not is_fp8_available():
raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
import transformer_engine.pytorch as te

for name, module in model.named_children():
if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
has_bias = module.bias is not None
Expand Down Expand Up @@ -87,6 +88,8 @@ def has_transformer_engine_layers(model):
"""
if not is_fp8_available():
raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
import transformer_engine.pytorch as te

for m in model.modules():
if isinstance(m, (te.LayerNorm, te.Linear, te.TransformerLayer)):
return True
Expand All @@ -98,6 +101,8 @@ def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
Wrapper for a model's forward method to apply FP8 autocast. Is context aware, meaning that by default it will
disable FP8 autocast during eval mode, which is generally better for more accurate metrics.
"""
if not is_fp8_available():
raise ImportError("Using `contextual_fp8_autocast` requires transformer_engine to be installed.")
from transformer_engine.pytorch import fp8_autocast

def forward(self, *args, **kwargs):
Expand All @@ -115,7 +120,8 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):
"""
Applies FP8 context manager to the model's forward method
"""
# Import here to keep base imports fast
if not is_fp8_available():
raise ImportError("Using `apply_fp8_autowrap` requires transformer_engine to be installed.")
import transformer_engine.common.recipe as te_recipe

kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
Expand Down
19 changes: 18 additions & 1 deletion tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import sys

from accelerate.test_utils import require_transformer_engine
from accelerate.test_utils.testing import TempDirTestCase, require_import_timer
from accelerate.utils import is_import_timer_available

Expand All @@ -31,7 +33,7 @@ def convert_list_to_string(data):


def run_import_time(command: str):
output = subprocess.run(["python3", "-X", "importtime", "-c", command], capture_output=True, text=True)
output = subprocess.run([sys.executable, "-X", "importtime", "-c", command], capture_output=True, text=True)
return output.stderr


Expand Down Expand Up @@ -81,3 +83,18 @@ def test_cli_import(self):
paths_above_threshold = get_paths_above_threshold(sorted_data, 0.05, max_depth=7)
err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
self.assertLess(pct_more, 20, err_msg)


@require_transformer_engine
class LazyImportTester(TempDirTestCase):
"""
Test suite which checks that specific packages are lazy-loaded.
An eager import can trigger a circular import in some cases,
e.g. in huggingface/accelerate#3056.
"""

def test_te_import(self):
output = run_import_time("import accelerate, accelerate.utils.transformer_engine")

self.assertFalse(" transformer_engine" in output, "`transformer_engine` should not be imported on import")

0 comments on commit 3fcc946

Please sign in to comment.