-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable ONNX export of 5B GPT trained with TE FP8 modules #6458
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
e57472e
add GPT FP8 ONNX export support
asfiyab-nvidia f64d546
changes
asfiyab-nvidia 0f5b9b7
Conform to Python style guidelines
asfiyab-nvidia 2c5d5fb
refactor to avoid typecasting bf16 string
asfiyab-nvidia 162cb6c
fix attribute error in export_utils
asfiyab-nvidia 29917ee
set constant_folding to False by default
asfiyab-nvidia bb16f60
refactor exportable wrapper into model class definition
asfiyab-nvidia 580c026
remove conditional replacement of modules
asfiyab-nvidia 2357e77
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] eaeafd7
set fp8_recipe to None by default
asfiyab-nvidia 26470b9
address all comments
asfiyab-nvidia 76e3d8f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 7948f35
typecast precision check for fp16
asfiyab-nvidia a2f5bce
Merge branch 'dev-gpt-fp7-export' of github.com:asfiyab-nvidia/NeMo i…
asfiyab-nvidia 4af9195
Merge branch 'main' into dev-gpt-fp8-export
borisfom 4a83c47
rename export script
asfiyab-nvidia File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
25 changes: 25 additions & 0 deletions
25
examples/nlp/language_modeling/conf/megatron_gpt_export.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
trainer: | ||
devices: 1 | ||
num_nodes: 1 | ||
accelerator: gpu | ||
logger: False # logger provided by exp_manager | ||
precision: bf16 # 16, 32, or bf16 | ||
|
||
model_type: gpt | ||
tensor_model_parallel_size: 1 | ||
pipeline_model_parallel_size: 1 | ||
pipeline_model_parallel_split_rank: -1 # used for encoder and decoder model (0 for others) | ||
gpt_model_file: null # GPT nemo file path | ||
onnx_model_file: null # ONNX file path | ||
checkpoint_dir: null # Checkpoint directory | ||
checkpoint_name: null # Checkpoint name | ||
hparams_file: null # hparams filepath | ||
|
||
export_options: | ||
runtime_check: False | ||
verbose: False | ||
onnx_opset: 17 | ||
do_constant_folding: True | ||
cache_support: False | ||
device: 'cuda' | ||
check_tolerance: 0.01 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Copyright 2017 Johns Hopkins University (Shinji Watanabe) | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
|
||
from omegaconf import OmegaConf, open_dict | ||
from pytorch_lightning import Trainer | ||
|
||
from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel | ||
from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel | ||
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel | ||
from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel | ||
from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model | ||
from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel | ||
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel | ||
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector | ||
from nemo.core import ModelPT | ||
from nemo.core.config import hydra_runner | ||
from nemo.utils import logging | ||
from nemo.utils.app_state import AppState | ||
from nemo.utils.model_utils import inject_model_parallel_rank | ||
|
||
|
||
def get_model_class(cfg): | ||
if cfg.model_type == 'gpt': | ||
return MegatronGPTModel | ||
elif cfg.model_type == 'bert': | ||
return MegatronBertModel | ||
elif cfg.model_type == 't5': | ||
return MegatronT5Model | ||
elif cfg.model_type == 'bart': | ||
return MegatronBARTModel | ||
elif cfg.model_type == 'nmt': | ||
return MegatronNMTModel | ||
elif cfg.model_type == 'retro': | ||
return MegatronRetrievalModel | ||
else: | ||
raise ValueError("Invalid Model Type") | ||
|
||
|
||
@hydra_runner(config_path="conf", config_name="megatron_gpt_export") | ||
def nemo_export(cfg): | ||
"""Convert a nemo model into .onnx ONNX format.""" | ||
nemo_in = None | ||
if cfg.gpt_model_file: | ||
nemo_in = cfg.gpt_model_file | ||
elif cfg.checkpoint_dir: | ||
nemo_in = os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name) | ||
assert nemo_in is not None, "NeMo model not provided. Please provide the path to the .nemo or .ckpt file" | ||
|
||
onnx_out = cfg.onnx_model_file | ||
|
||
trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer) | ||
assert ( | ||
cfg.trainer.devices * cfg.trainer.num_nodes | ||
== cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size | ||
), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size" | ||
|
||
logging.info("Restoring NeMo model from '{}'".format(nemo_in)) | ||
try: | ||
if cfg.gpt_model_file: | ||
save_restore_connector = NLPSaveRestoreConnector() | ||
if os.path.isdir(cfg.gpt_model_file): | ||
save_restore_connector.model_extracted_dir = cfg.gpt_model_file | ||
|
||
pretrained_cfg = ModelPT.restore_from( | ||
restore_path=cfg.gpt_model_file, | ||
trainer=trainer, | ||
return_config=True, | ||
save_restore_connector=save_restore_connector, | ||
) | ||
OmegaConf.set_struct(pretrained_cfg, True) | ||
with open_dict(pretrained_cfg): | ||
pretrained_cfg.sequence_parallel = False | ||
pretrained_cfg.activations_checkpoint_granularity = None | ||
pretrained_cfg.activations_checkpoint_method = None | ||
pretrained_cfg.precision = trainer.precision | ||
if trainer.precision == "16": | ||
pretrained_cfg.megatron_amp_O2 = False | ||
model = ModelPT.restore_from( | ||
restore_path=cfg.gpt_model_file, | ||
trainer=trainer, | ||
override_config_path=pretrained_cfg, | ||
save_restore_connector=save_restore_connector, | ||
) | ||
elif cfg.checkpoint_dir: | ||
app_state = AppState() | ||
if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1: | ||
app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size | ||
app_state.tensor_model_parallel_size = cfg.tensor_model_parallel_size | ||
app_state.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size | ||
( | ||
app_state.tensor_model_parallel_rank, | ||
app_state.pipeline_model_parallel_rank, | ||
app_state.model_parallel_size, | ||
app_state.data_parallel_size, | ||
app_state.pipeline_model_parallel_split_rank, | ||
app_state.virtual_pipeline_model_parallel_rank, | ||
) = fake_initialize_model_parallel( | ||
world_size=app_state.model_parallel_size, | ||
rank=trainer.global_rank, | ||
tensor_model_parallel_size_=cfg.tensor_model_parallel_size, | ||
pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size, | ||
pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank, | ||
) | ||
checkpoint_path = inject_model_parallel_rank(os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name)) | ||
model_cls = get_model_class(cfg) | ||
model = model_cls.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer) | ||
else: | ||
raise ValueError("need at least a nemo file or checkpoint dir") | ||
except Exception as e: | ||
logging.error( | ||
"Failed to restore model from NeMo file : {}. Please make sure you have the latest NeMo package installed with [all] dependencies.".format( | ||
nemo_in | ||
) | ||
) | ||
raise e | ||
|
||
logging.info("Model {} restored from '{}'".format(model.__class__.__name__, nemo_in)) | ||
|
||
# Export | ||
check_trace = cfg.export_options.runtime_check | ||
|
||
try: | ||
model.to(device=cfg.export_options.device).freeze() | ||
model.eval() | ||
model.export( | ||
onnx_out, | ||
onnx_opset_version=cfg.export_options.onnx_opset, | ||
do_constant_folding=cfg.export_options.do_constant_folding, | ||
dynamic_axes={ | ||
'input_ids': {0: "sequence", 1: "batch"}, | ||
'position_ids': {0: "sequence", 1: "batch"}, | ||
'logits': {0: "sequence", 1: "batch"}, | ||
}, | ||
check_trace=check_trace, | ||
check_tolerance=cfg.export_options.check_tolerance, | ||
verbose=cfg.export_options.verbose, | ||
) | ||
except Exception as e: | ||
logging.error( | ||
"Export failed. Please make sure your NeMo model class ({}) has working export() and that you have the latest NeMo package installed with [all] dependencies.".format( | ||
model.__class__ | ||
) | ||
) | ||
raise e | ||
|
||
|
||
if __name__ == '__main__': | ||
nemo_export() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Check notice
Code scanning / CodeQL
Explicit returns mixed with implicit (fall through) returns