Determine framework automatically before ONNX export #18615
src/transformers/onnx/features.py

@@ -1,10 +1,11 @@
+import os
 from functools import partial, reduce
 from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, Union

 import transformers

 from .. import PretrainedConfig, is_tf_available, is_torch_available
-from ..utils import logging
+from ..utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging
 from .config import OnnxConfig

@@ -552,9 +553,59 @@ def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type:
             )
         return task_to_automodel[task]

+    @staticmethod
+    def determine_framework(model: str, framework: str = None) -> str:
+        """
+        Determines the framework to use for the export.
+
+        The priority is in the following order:
+            1. User input via `framework`.
+            2. If a local checkpoint is provided, use the same framework as the checkpoint.
+            3. Available framework in the environment, with priority given to PyTorch.
+
+        Args:
+            model (`str`):
+                The name of the model to export.
+            framework (`str`, *optional*, defaults to `None`):
+                The framework to use for the export. See above for the priority if none is provided.
+
+        Returns:
+            The framework to use for the export.
+        """
+        if framework is not None:
+            return framework
+
+        framework_map = {"pt": "PyTorch", "tf": "TensorFlow"}
+        exporter_map = {"pt": "torch", "tf": "tf2onnx"}
+
+        if os.path.isdir(model):
+            if os.path.isfile(os.path.join(model, WEIGHTS_NAME)):
+                framework = "pt"
+            elif os.path.isfile(os.path.join(model, TF2_WEIGHTS_NAME)):
+                framework = "tf"
+            else:
+                raise FileNotFoundError(
+                    "Cannot determine framework from given checkpoint location."
+                    f" There should be a {WEIGHTS_NAME} for PyTorch"
+                    f" or {TF2_WEIGHTS_NAME} for TensorFlow."
+                )
+            logger.info(f"Local {framework_map[framework]} model found.")
+        else:
+            if is_torch_available():
+                framework = "pt"
+            elif is_tf_available():
+                framework = "tf"
+            else:
+                raise EnvironmentError("Neither PyTorch nor TensorFlow found in environment. Cannot export to ONNX.")
+
+        logger.info(f"Framework not requested. Using {exporter_map[framework]} to export to ONNX.")
+
+        return framework
+
     @staticmethod
     def get_model_from_feature(
-        feature: str, model: str, framework: str = "pt", cache_dir: str = None
+        feature: str, model: str, framework: str = None, cache_dir: str = None
     ) -> Union["PreTrainedModel", "TFPreTrainedModel"]:
         """
         Attempts to retrieve a model from a model's name and the feature to be enabled.
@@ -564,20 +615,24 @@ def get_model_from_feature(
             The feature required.
         model (`str`):
             The name of the model to export.
-        framework (`str`, *optional*, defaults to `"pt"`):
-            The framework to use for the export.
+        framework (`str`, *optional*, defaults to `None`):
+            The framework to use for the export. See `FeaturesManager.determine_framework` for the priority if
+            none is provided.

         Returns:
             The instance of the model.
         """
+        framework = FeaturesManager.determine_framework(model, framework)
         model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
         try:
             model = model_class.from_pretrained(model, cache_dir=cache_dir)
         except OSError:
             if framework == "pt":
+                logger.info("Loading TensorFlow model in PyTorch before exporting to ONNX.")
                 model = model_class.from_pretrained(model, from_tf=True, cache_dir=cache_dir)
             else:
+                logger.info("Loading PyTorch model in TensorFlow before exporting to ONNX.")
                 model = model_class.from_pretrained(model, from_pt=True, cache_dir=cache_dir)
         return model

Reviewer (on the new logging calls): Nice idea to log these steps for the user!

Author: Thanks! It helped me figure out the behavior, so hope it's helpful for others!
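Taken together, the two hunks above mean callers no longer have to pass a framework explicitly. The following is a minimal usage sketch, not part of the diff: it assumes this branch of transformers is installed and both frameworks are available, and bert-base-cased is just an illustrative Hub model id (in transformers, WEIGHTS_NAME and TF2_WEIGHTS_NAME resolve to pytorch_model.bin and tf_model.h5).

# Minimal usage sketch (illustrative, not part of the PR).
from transformers.onnx.features import FeaturesManager

# 1. Explicit user input always wins.
assert FeaturesManager.determine_framework("bert-base-cased", framework="tf") == "tf"

# 2. A local directory is inspected for pytorch_model.bin (-> "pt") or
#    tf_model.h5 (-> "tf"); a directory containing neither raises FileNotFoundError.
# framework = FeaturesManager.determine_framework("./my-local-checkpoint")  # hypothetical path

# 3. A Hub model id falls back to whatever is installed, PyTorch first.
framework = FeaturesManager.determine_framework("bert-base-cased")

# get_model_from_feature now resolves the framework the same way when none is given.
model = FeaturesManager.get_model_from_feature("default", "bert-base-cased")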
tests/onnx/test_onnx_v2.py

@@ -1,13 +1,20 @@
 import os
 from pathlib import Path
-from tempfile import NamedTemporaryFile
+from tempfile import NamedTemporaryFile, TemporaryDirectory
 from unittest import TestCase
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch

 import pytest

 from parameterized import parameterized
-from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    PreTrainedTokenizerBase,
+    TFAutoModel,
+    is_tf_available,
+    is_torch_available,
+)
 from transformers.onnx import (
     EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
     OnnxConfig,
@@ -39,6 +46,38 @@ class OnnxUtilsTestCaseV2(TestCase):
     Cover all the utilities involved to export ONNX models
     """

+    @classmethod
+    def setUpClass(cls):
+        # Create local checkpoints - one-time setup
+        test_model = "mrm8488/bert-tiny-finetuned-squadv2"
+        cls.test_model = test_model
+
+        pt_temp_dir = TemporaryDirectory()
+        local_pt_ckpt = pt_temp_dir.name
+        model_pt = AutoModel.from_pretrained(test_model)
+        model_pt.save_pretrained(local_pt_ckpt)
+        cls.pt_temp_dir = pt_temp_dir
+        cls.local_pt_ckpt = local_pt_ckpt
+
+        tf_temp_dir = TemporaryDirectory()
+        local_tf_ckpt = tf_temp_dir.name
+        model_tf = TFAutoModel.from_pretrained(test_model, from_pt=True)
+        model_tf.save_pretrained(local_tf_ckpt)
+        cls.tf_temp_dir = tf_temp_dir
+        cls.local_tf_ckpt = local_tf_ckpt
+
+        invalid_temp_dir = TemporaryDirectory()
+        local_invalid_ckpt = invalid_temp_dir.name
+        cls.invalid_temp_dir = invalid_temp_dir
+        cls.local_invalid_ckpt = local_invalid_ckpt
+
+    @classmethod
+    def tearDownClass(cls):
+        # Remove local checkpoints
+        cls.pt_temp_dir.cleanup()
+        cls.tf_temp_dir.cleanup()
+        cls.invalid_temp_dir.cleanup()
+
     @require_torch
     @patch("transformers.onnx.convert.is_torch_onnx_dict_inputs_support_available", return_value=False)
     def test_ensure_pytorch_version_ge_1_8_0(self, mock_is_torch_onnx_dict_inputs_support_available):

Reviewer: Do the tests also work with SMALL_MODEL_IDENTIFIER?

Author: Didn't realize that was an available variable to use! That works, done in 67416f2.

Author (replying to a comment on the manual TemporaryDirectory handling): I originally did this manually because I didn't want files to be written again between unit tests, but it has been refactored in 67416f2 to match the suggested approach.
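In line with the reviewer's suggestion above, here is a sketch of how the fixtures could be built around SMALL_MODEL_IDENTIFIER instead of a hard-coded Hub id. It approximates the idea rather than reproducing commit 67416f2, and assumes SMALL_MODEL_IDENTIFIER (from transformers.testing_utils) points to a tiny dummy checkpoint and that PyTorch is installed.

# Sketch only - approximates the reviewer's suggestion, not commit 67416f2.
from tempfile import TemporaryDirectory

from transformers import AutoModel
from transformers.onnx.features import FeaturesManager
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER

with TemporaryDirectory() as local_pt_ckpt:
    # Save a tiny PyTorch checkpoint, then exercise framework detection on it.
    AutoModel.from_pretrained(SMALL_MODEL_IDENTIFIER).save_pretrained(local_pt_ckpt)
    assert FeaturesManager.determine_framework(local_pt_ckpt) == "pt"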
@@ -94,6 +133,67 @@ def test_flatten_output_collection_property(self):
         },
     )

+    def test_determine_framework(self):
+        """
+        Ensure the expected framework is determined.
+        """
+        torch_str = "pt"
+        tf_str = "tf"
+        mock_framework = "mock_framework"
+
+        # Framework provided - return whatever the user provides
+        result = FeaturesManager.determine_framework(self.test_model, mock_framework)
+        self.assertEqual(result, mock_framework)
+
+        # Local checkpoint provided - still return whatever the user provides
+        result = FeaturesManager.determine_framework(self.local_pt_ckpt, mock_framework)
+        self.assertEqual(result, mock_framework)
+
+        result = FeaturesManager.determine_framework(self.local_tf_ckpt, mock_framework)
+        self.assertEqual(result, mock_framework)
+
+        # Framework not provided and a local checkpoint is used
+        result = FeaturesManager.determine_framework(self.local_pt_ckpt)
+        self.assertEqual(result, torch_str)
+
+        result = FeaturesManager.determine_framework(self.local_tf_ckpt)
+        self.assertEqual(result, tf_str)
+
+        # Framework not provided and an invalid local checkpoint is used
+        with self.assertRaises(FileNotFoundError):
+            result = FeaturesManager.determine_framework(self.local_invalid_ckpt)
+
+        # Framework not provided, hub model is used (no local checkpoint directory)
+        # TensorFlow not in environment -> use PyTorch
+        mock_tf_available = MagicMock(return_value=False)
+        with patch("transformers.onnx.features.is_tf_available", mock_tf_available):
+            result = FeaturesManager.determine_framework(self.test_model)
+            self.assertEqual(result, torch_str)
+
+        # PyTorch not in environment -> use TensorFlow
+        mock_torch_available = MagicMock(return_value=False)
+        with patch("transformers.onnx.features.is_torch_available", mock_torch_available):
+            result = FeaturesManager.determine_framework(self.test_model)
+            self.assertEqual(result, tf_str)
+
+        # Both in environment -> use PyTorch
+        mock_tf_available = MagicMock(return_value=True)
+        mock_torch_available = MagicMock(return_value=True)
+        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
+            "transformers.onnx.features.is_torch_available", mock_torch_available
+        ):
+            result = FeaturesManager.determine_framework(self.test_model)
+            self.assertEqual(result, torch_str)
+
+        # Neither in environment -> raise an error
+        mock_tf_available = MagicMock(return_value=False)
+        mock_torch_available = MagicMock(return_value=False)
+        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
+            "transformers.onnx.features.is_torch_available", mock_torch_available
+        ):
+            with pytest.raises(EnvironmentError):
+                result = FeaturesManager.determine_framework(self.test_model)
+
+
 class OnnxConfigTestCaseV2(TestCase):
     """

Reviewer: I suggest promoting this to a standalone class.

Author: Thanks for the suggestion! Done in 67416f2. I broke it down to the 3 main paths, hope that is sufficient.
Reviewer: I would love to see the logic in this function unit tested if you're up for it, e.g. under tests/onnx/test_features.py. You could use SMALL_MODEL_IDENTIFIER to save a tiny torch/tf model to a temporary directory as follows:

Author: Added a unit test in b832988, but I put it under tests/onnx/test_onnx_v2::OnnxUtilsTestCaseV2. I just noticed you specified test_features.py, but it does not exist yet. I can create it if you'd like, or should I leave it as is?

Reviewer: Thank you!! Yes, please create a new test_features.py file for this test (we usually map transformers/path/to/module.py with tests/path/to/test_module.py).

Author: Thanks! Done in 67416f2.

In 8da5990 I registered the tests in utils/tests_fetcher.py because of a failure I got in CI saying that the test would not be discovered. Is this the correct way to add them?

In 63198fd I added tf to the pip install steps for run_tests_onnxruntime and run_tests_onnxruntime_all in .circleci/config.yml so that TFAutoModel can be used. I also added -rA flags so that the results are more verbose. The logs for run_tests_onnxruntime show that the new unit tests run.

Reviewer: Thank you for registering the test - this is indeed the way to include it :)
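Since the final contents of 67416f2 are not shown in this thread, here is a hedged sketch of what the promoted standalone class in tests/onnx/test_features.py might look like, split along the three resolution paths the author mentions. The class name DetermineFrameworkTest is hypothetical, and the sketch assumes PyTorch is installed in the test environment.

# Hypothetical sketch of the promoted test class; not the contents of 67416f2.
from unittest import TestCase
from unittest.mock import MagicMock, patch

from transformers.onnx.features import FeaturesManager
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER


class DetermineFrameworkTest(TestCase):  # hypothetical name
    def test_framework_provided(self):
        # Path 1: explicit user input always wins.
        self.assertEqual(FeaturesManager.determine_framework(SMALL_MODEL_IDENTIFIER, "pt"), "pt")

    def test_framework_from_local_checkpoint(self):
        # Path 2: a local checkpoint decides the framework (see the setUpClass
        # fixtures earlier in this thread for how the checkpoints are created).
        ...

    def test_framework_from_environment(self):
        # Path 3: fall back to whatever is installed, PyTorch first
        # (assumes torch is available in the test environment).
        mock_tf_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available):
            self.assertEqual(FeaturesManager.determine_framework(SMALL_MODEL_IDENTIFIER), "pt")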