Determine framework automatically before ONNX export #18615

Merged
Changes from 6 commits
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -878,7 +878,7 @@ jobs:
- v0.5-torch-{{ checksum "setup.py" }}
- v0.5-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
-      - run: pip install .[torch,testing,sentencepiece,onnxruntime,vision,rjieba]
+      - run: pip install .[torch,tf,testing,sentencepiece,onnxruntime,vision,rjieba]
- save_cache:
key: v0.5-onnx-{{ checksum "setup.py" }}
paths:
10 changes: 9 additions & 1 deletion src/transformers/onnx/__main__.py
@@ -38,7 +38,15 @@ def main():
"--atol", type=float, default=None, help="Absolute difference tolerence when validating the model."
)
parser.add_argument(
"--framework", type=str, choices=["pt", "tf"], default="pt", help="The framework to use for the ONNX export."
"--framework",
type=str,
choices=["pt", "tf"],
default=None,
help=(
"The framework to use for the ONNX export."
" If not provided, will attempt to use the local checkpoint's original framework"
" or what is available in the environment."
),
)
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
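For context, a hedged sketch of the two invocation modes this change enables, assuming the exporter's usual --model flag and positional output path (the model name is illustrative):

# Framework resolved automatically (the new default):
python -m transformers.onnx --model=distilbert-base-uncased onnx/
# Framework forced explicitly, as before:
python -m transformers.onnx --model=distilbert-base-uncased --framework=tf onnx/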
63 changes: 59 additions & 4 deletions src/transformers/onnx/features.py
@@ -1,10 +1,11 @@
+import os
from functools import partial, reduce
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, Union

import transformers

from .. import PretrainedConfig, is_tf_available, is_torch_available
-from ..utils import logging
+from ..utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging
from .config import OnnxConfig


@@ -552,9 +553,59 @@ def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type:
)
return task_to_automodel[task]

@staticmethod
def determine_framework(model: str, framework: str = None) -> str:
"""
Determines the framework to use for the export.
Member: I would love to see the logic in this function unit tested if you're up for it, e.g. under tests/onnx/test_features.py.

You could use SMALL_MODEL_IDENTIFIER to save a tiny torch / tf model to a temporary directory as follows:

# Ditto for the TF case
model = AutoModel.from_pretrained(SMALL_MODEL_IDENTIFIER)

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)
    framework = determine_framework(tmp_dir)
    ...

Contributor Author: Added a unit test in b832988, but I put it under tests/onnx/test_onnx_v2::OnnxUtilsTestCaseV2. I just noticed you specified test_features.py, but it does not exist yet. I can create it if you'd like, or should I leave it as is?

Member: Thank you! Yes, please create a new test_features.py file for this test (we usually map transformers/path/to/module.py to tests/path/to/test_module.py).

Contributor Author: Thanks! Done in 67416f2.

In 8da5990 I registered the tests in utils/tests_fetcher.py because of a failure I got in CI saying that the test would not be discovered. Is this the correct way to add them?

In 63198fd I added tf to the pip install steps for run_tests_onnxruntime and run_tests_onnxruntime_all in .circleci/config.yml so that TFAutoModel can be used. I also added -rA flags so that the results would be more verbose. In the logs for run_tests_onnxruntime it can be seen that the new unit tests run.

Member: Thank you for registering the test - this is indeed the way to include it :)


The priority is in the following order:
1. User input via `framework`.
2. If a local checkpoint is provided, use the same framework as the checkpoint.
3. Available framework in the environment, with priority given to PyTorch.

Args:
model (`str`):
The name of the model to export.
framework (`str`, *optional*, defaults to `None`):
The framework to use for the export. See above for the priority if none is provided.

Returns:
The framework to use for the export.

"""
if framework is not None:
return framework

framework_map = {"pt": "PyTorch", "tf": "TensorFlow"}
exporter_map = {"pt": "torch", "tf": "tf2onnx"}

if os.path.isdir(model):
if os.path.isfile(os.path.join(model, WEIGHTS_NAME)):
framework = "pt"
elif os.path.isfile(os.path.join(model, TF2_WEIGHTS_NAME)):
framework = "tf"
else:
raise FileNotFoundError(
"Cannot determine framework from given checkpoint location."
f" There should be a {WEIGHTS_NAME} for PyTorch"
f" or {TF2_WEIGHTS_NAME} for TensorFlow."
)
logger.info(f"Local {framework_map[framework]} model found.")
else:
if is_torch_available():
framework = "pt"
elif is_tf_available():
framework = "tf"
else:
raise EnvironmentError("Neither PyTorch nor TensorFlow found in environment. Cannot export to ONNX.")

logger.info(f"Framework not requested. Using {exporter_map[framework]} to export to ONNX.")

return framework
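
As a hedged illustration of the three priority levels above (the local directory path is a placeholder; the model name is illustrative):

from transformers.onnx.features import FeaturesManager

# 1. Explicit user input always wins.
assert FeaturesManager.determine_framework("distilbert-base-uncased", "tf") == "tf"

# 2. A local directory is inspected for pytorch_model.bin (WEIGHTS_NAME) or
#    tf_model.h5 (TF2_WEIGHTS_NAME); a placeholder path is used here.
framework = FeaturesManager.determine_framework("path/to/local/checkpoint")

# 3. For hub model names, the installed framework decides, preferring PyTorch.
framework = FeaturesManager.determine_framework("distilbert-base-uncased")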

@staticmethod
def get_model_from_feature(
feature: str, model: str, framework: str = "pt", cache_dir: str = None
feature: str, model: str, framework: str = None, cache_dir: str = None
) -> Union["PreTrainedModel", "TFPreTrainedModel"]:
"""
Attempts to retrieve a model from a model's name and the feature to be enabled.
@@ -564,20 +615,24 @@ def get_model_from_feature(
The feature required.
model (`str`):
The name of the model to export.
-            framework (`str`, *optional*, defaults to `"pt"`):
-                The framework to use for the export.
+            framework (`str`, *optional*, defaults to `None`):
+                The framework to use for the export. See `FeaturesManager.determine_framework` for the priority if
+                none is provided.

Returns:
The instance of the model.

"""
framework = FeaturesManager.determine_framework(model, framework)
model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
try:
model = model_class.from_pretrained(model, cache_dir=cache_dir)
except OSError:
if framework == "pt":
logger.info("Loading TensorFlow model in PyTorch before exporting to ONNX.")
Member: Nice idea to log these steps for the user!

Contributor Author: Thanks! It helped me figure out the behavior, so I hope it's helpful for others!

model = model_class.from_pretrained(model, from_tf=True, cache_dir=cache_dir)
else:
logger.info("Loading PyTorch model in TensorFlow before exporting to ONNX.")
model = model_class.from_pretrained(model, from_pt=True, cache_dir=cache_dir)
return model
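
A small usage sketch of the updated entry point (the feature and model names are illustrative):

from transformers.onnx.features import FeaturesManager

# With framework=None, determine_framework() resolves "pt" or "tf" first; the
# except OSError branch then retries with from_tf/from_pt when the checkpoint
# was saved in the other framework.
model = FeaturesManager.get_model_from_feature("default", "distilbert-base-uncased")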

106 changes: 103 additions & 3 deletions tests/onnx/test_onnx_v2.py
@@ -1,13 +1,20 @@
import os
from pathlib import Path
-from tempfile import NamedTemporaryFile
+from tempfile import NamedTemporaryFile, TemporaryDirectory
from unittest import TestCase
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch

import pytest

from parameterized import parameterized
-from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    PreTrainedTokenizerBase,
+    TFAutoModel,
+    is_tf_available,
+    is_torch_available,
+)
from transformers.onnx import (
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
OnnxConfig,
@@ -39,6 +46,38 @@ class OnnxUtilsTestCaseV2(TestCase):
Cover all the utilities involved to export ONNX models
"""

@classmethod
def setUpClass(cls):
# Create local checkpoints - one time setup
test_model = "mrm8488/bert-tiny-finetuned-squadv2"
Member: Do the tests also work with SMALL_MODEL_IDENTIFIER from testing_utils.py? That checkpoint is preferred since it's maintained by the HF team and less likely to vanish unexpectedly :)

Contributor Author: Didn't realize that was an available variable to use! That works, done in 67416f2.

cls.test_model = test_model

pt_temp_dir = TemporaryDirectory()
Member: The transformers codebase follows the (standard?) convention of using context managers with TemporaryDirectory() to handle temporary files / folders in our test suite. Could you please refactor to match this approach?

Contributor Author: I originally did this manually because I didn't want files to be written again between unit tests, but it has been refactored in 67416f2 to match the approach.

local_pt_ckpt = pt_temp_dir.name
model_pt = AutoModel.from_pretrained(test_model)
model_pt.save_pretrained(local_pt_ckpt)
cls.pt_temp_dir = pt_temp_dir
cls.local_pt_ckpt = local_pt_ckpt

tf_temp_dir = TemporaryDirectory()
local_tf_ckpt = tf_temp_dir.name
model_tf = TFAutoModel.from_pretrained(test_model, from_pt=True)
model_tf.save_pretrained(local_tf_ckpt)
cls.tf_temp_dir = tf_temp_dir
cls.local_tf_ckpt = local_tf_ckpt

invalid_temp_dir = TemporaryDirectory()
local_invalid_ckpt = invalid_temp_dir.name
cls.invalid_temp_dir = invalid_temp_dir
cls.local_invalid_ckpt = local_invalid_ckpt

@classmethod
def tearDownClass(cls):
# Remove local checkpoints
cls.pt_temp_dir.cleanup()
cls.tf_temp_dir.cleanup()
cls.invalid_temp_dir.cleanup()
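
For reference, a minimal sketch of the context-manager style requested in the thread above, assuming SMALL_MODEL_IDENTIFIER from transformers.testing_utils and an installed torch (class and test names are illustrative, taken loosely from the reviewer's suggestion; the TF case mirrors this with TFAutoModel and from_pt=True):

import tempfile
from unittest import TestCase

from transformers import AutoModel
from transformers.onnx import FeaturesManager
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER


class DetermineFrameworkTest(TestCase):
    def test_local_pt_checkpoint(self):
        # Save a tiny PyTorch model, then check the directory resolves to "pt".
        model = AutoModel.from_pretrained(SMALL_MODEL_IDENTIFIER)
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)  # writes pytorch_model.bin
            self.assertEqual(FeaturesManager.determine_framework(tmp_dir), "pt")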

@require_torch
@patch("transformers.onnx.convert.is_torch_onnx_dict_inputs_support_available", return_value=False)
def test_ensure_pytorch_version_ge_1_8_0(self, mock_is_torch_onnx_dict_inputs_support_available):
Expand Down Expand Up @@ -94,6 +133,67 @@ def test_flatten_output_collection_property(self):
},
)

def test_determine_framework(self):
Member: I suggest promoting this to a standalone class like DetermineFrameWorkTest and then treating each case with a dedicated function like test_framework_provided(self) etc.

Contributor Author: Thanks for the suggestion! Done in 67416f2. I broke it down to the 3 main paths; hope that is sufficient.

"""
Ensure the expected framework is determined.
"""
torch_str = "pt"
tf_str = "tf"
mock_framework = "mock_framework"

# Framework provided - return whatever the user provides
result = FeaturesManager.determine_framework(self.test_model, mock_framework)
self.assertEqual(result, mock_framework)

# Local checkpoint provided - return whatever the user provides
result = FeaturesManager.determine_framework(self.local_pt_ckpt, mock_framework)
self.assertEqual(result, mock_framework)

result = FeaturesManager.determine_framework(self.local_tf_ckpt, mock_framework)
self.assertEqual(result, mock_framework)

# Framework not provided and local checkpoint is used
result = FeaturesManager.determine_framework(self.local_pt_ckpt)
self.assertEqual(result, torch_str)

result = FeaturesManager.determine_framework(self.local_tf_ckpt)
self.assertEqual(result, tf_str)

# Framework not provided and invalid local checkpoint is used
with self.assertRaises(FileNotFoundError):
result = FeaturesManager.determine_framework(self.local_invalid_ckpt)

# Framework not provided, hub model is used (no local checkpoint directory)
# TensorFlow not in environment -> use PyTorch
mock_tf_available = MagicMock(return_value=False)
with patch("transformers.onnx.features.is_tf_available", mock_tf_available):
result = FeaturesManager.determine_framework(self.test_model)
self.assertEqual(result, torch_str)

# PyTorch not in environment -> use TensorFlow
mock_torch_available = MagicMock(return_value=False)
with patch("transformers.onnx.features.is_torch_available", mock_torch_available):
result = FeaturesManager.determine_framework(self.test_model)
self.assertEqual(result, tf_str)

# Both in environment -> use PyTorch
mock_tf_available = MagicMock(return_value=True)
mock_torch_available = MagicMock(return_value=True)
with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
"transformers.onnx.features.is_torch_available", mock_torch_available
):
result = FeaturesManager.determine_framework(self.test_model)
self.assertEqual(result, torch_str)

# Neither in environment -> raise error
mock_tf_available = MagicMock(return_value=False)
mock_torch_available = MagicMock(return_value=False)
with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
"transformers.onnx.features.is_torch_available", mock_torch_available
):
with pytest.raises(EnvironmentError):
result = FeaturesManager.determine_framework(self.test_model)
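
If useful, these cases can be run in isolation with a standard pytest name filter (the -rA flag matches the CI discussion above):

python -m pytest tests/onnx/test_onnx_v2.py -k "determine_framework" -rA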


class OnnxConfigTestCaseV2(TestCase):
"""