Determine framework automatically before ONNX export #18615
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thank you so much for greatly improving the framework selection in the ONNX exporter @rachthree (also, welcome as a first-time contributor 🥳)!

Overall, the logic looks great to me, and I'd really like to see a unit test of the determine_framework function. This would give us some confidence that any future changes on the framework selection side won't accidentally break the desired behaviour.

Regarding the failing unit tests, these will be fixed by two other PRs, so we can rebase your branch on main once they're approved / merged (should be soon).
src/transformers/onnx/features.py (Outdated)

    return framework

    framework_map = {"pt": "PyTorch", "tf": "TensorFlow"}
turbo nit:
Thanks! Done in b832988
src/transformers/onnx/features.py (Outdated)

        f" or {TF2_WEIGHTS_NAME} for TensorFlow."
    )
    logger.info(f"Local {framework_map[framework]} model found.")
nit:
Thanks! Done in b832988
    model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
    try:
        model = model_class.from_pretrained(model, cache_dir=cache_dir)
    except OSError:
        if framework == "pt":
            logger.info("Loading TensorFlow model in PyTorch before exporting to ONNX.")
Nice idea to log these steps for the user!
Thanks! It helped me figure out the behavior, so I hope it's helpful for others!
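The diff snippet above cuts off mid-branch, so for readers following the thread, here is a minimal, self-contained sketch of the same try/except fallback idea. The helper name load_model_for_export is hypothetical; from_tf / from_pt are standard from_pretrained arguments, and the merged code may differ:

```python
from transformers import AutoModel, TFAutoModel


def load_model_for_export(model_name: str, framework: str):
    """Hypothetical helper mirroring the try/except fallback shown in the diff above."""
    model_class = AutoModel if framework == "pt" else TFAutoModel
    try:
        # First try to load weights native to the requested framework.
        return model_class.from_pretrained(model_name)
    except OSError:
        if framework == "pt":
            # Only TensorFlow weights are available: convert them while loading in PyTorch.
            return model_class.from_pretrained(model_name, from_tf=True)
        # Only PyTorch weights are available: convert them while loading in TensorFlow.
        return model_class.from_pretrained(model_name, from_pt=True)
```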
    @staticmethod
    def determine_framework(model: str, framework: str) -> str:
        """
        Determines the framework to use for the export.
I would love to see the logic in this function unit tested if you're up for it, e.g. under tests/onnx/test_features.py. You could use SMALL_MODEL_IDENTIFIER to save a tiny torch / tf model to a temporary directory as follows:
    # Ditto for the TF case
    model = AutoModel.from_pretrained(SMALL_MODEL_IDENTIFIER)
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        framework = determine_framework(tmp_dir)
        ...
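Filling in that outline as a self-contained sketch, and assuming determine_framework ends up as a staticmethod on FeaturesManager as the diff in this thread shows (the exact assertions in the merged test may differ):

```python
import tempfile

from transformers import AutoModel, TFAutoModel
from transformers.onnx.features import FeaturesManager
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER

# PyTorch case: a directory containing only torch weights should be detected as "pt".
pt_model = AutoModel.from_pretrained(SMALL_MODEL_IDENTIFIER)
with tempfile.TemporaryDirectory() as tmp_dir:
    pt_model.save_pretrained(tmp_dir)
    assert FeaturesManager.determine_framework(tmp_dir) == "pt"

# TensorFlow case: a directory containing only TF weights should be detected as "tf".
tf_model = TFAutoModel.from_pretrained(SMALL_MODEL_IDENTIFIER, from_pt=True)
with tempfile.TemporaryDirectory() as tmp_dir:
    tf_model.save_pretrained(tmp_dir)
    assert FeaturesManager.determine_framework(tmp_dir) == "tf"
```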
Added a unit test in b832988, but I put it under tests/onnx/test_onnx_v2::OnnxUtilsTestCaseV2. I just noticed you specified test_features.py, but it does not exist yet. I can create it if you'd like, or should I leave it as is?
Thank you!! Yes, please create a new test_features.py file for this test (we usually map transformers/path/to/module.py with tests/path/to/test_module.py).
Thanks! Done in 67416f2.

In 8da5990 I registered the tests in utils/tests_fetcher.py because of a failure I got in CI saying that the test would not be discovered. Is this the correct way to add them?
In 63198fd I added tf to the pip install steps for run_tests_onnxruntime and run_tests_onnxruntime_all in .circleci/config.yml so that TFAutoModel can be used. I also added the -rA flag so that the results would be more verbose; in the logs for run_tests_onnxruntime you can see that the new unit tests run.
Thank you for registering the test - this is indeed the way to include it :)
src/transformers/onnx/features.py (Outdated)

    @@ -552,9 +553,61 @@ def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type:
            )
            return task_to_automodel[task]

        @staticmethod
        def determine_framework(model: str, framework: str) -> str:
nit:
    - def determine_framework(model: str, framework: str) -> str:
    + def determine_framework(model: str, framework: str = None) -> str:
Thanks for the catch! Done in b832988
Thank you for the review and welcoming me! I'm excited to contribute, especially since this is my first PR in the open source community :) Glad to see the 2 PRs will fix those unit tests.
Thank you very much for adding an extensive test suite for this functionality @rachthree ! I think we just need to do a little bit of refactoring and this should be good to go 🔥 !
tests/onnx/test_onnx_v2.py (Outdated)

    @@ -94,6 +133,67 @@ def test_flatten_output_collection_property(self):
            },
        )

        def test_determine_framework(self):
I suggest promoting this to a standalone class like DetermineFrameWorkTest and then treating each case with a dedicated function like test_framework_provided(self), etc.
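A hedged sketch of what that structure could look like (class and method names are illustrative, not the exact merged tests; determine_framework is assumed to live on FeaturesManager as shown earlier in the thread):

```python
import unittest

from transformers.onnx.features import FeaturesManager
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER


class DetermineFrameworkTest(unittest.TestCase):
    def test_framework_provided(self):
        # Path 1: an explicitly provided framework is returned as-is.
        self.assertEqual(FeaturesManager.determine_framework(SMALL_MODEL_IDENTIFIER, "pt"), "pt")
        self.assertEqual(FeaturesManager.determine_framework(SMALL_MODEL_IDENTIFIER, "tf"), "tf")

    def test_local_checkpoint(self):
        # Path 2: a local directory is inspected for framework-specific weight files
        # (see the temporary-directory example earlier in this thread).
        ...

    def test_from_hub(self):
        # Path 3: for a Hub model id with no framework given, fall back to whichever
        # framework is available in the environment (PyTorch has priority).
        ...
```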
Thanks for the suggestion! Done in 67416f2. I broke it down to the 3 main paths, hope that is sufficient.
tests/onnx/test_onnx_v2.py (Outdated)

    test_model = "mrm8488/bert-tiny-finetuned-squadv2"
    cls.test_model = test_model

    pt_temp_dir = TemporaryDirectory()
The transformers codebase follows the (standard?) convention of using context managers with TemporaryDirectory() to handle temporary files / folders in our test suite. Could you please refactor to match this approach?
I originally did this manually because I didn't want files to be written again between unit tests, but it has been refactored in 67416f2 to match the approach.
tests/onnx/test_onnx_v2.py (Outdated)

    @classmethod
    def setUpClass(cls):
        # Create local checkpoints - one time setup
        test_model = "mrm8488/bert-tiny-finetuned-squadv2"
Do the tests also work with SMALL_MODEL_IDENTIFIER from testing_utils.py? That checkpoint is preferred since it's maintained by the HF team and less likely to vanish unexpectedly :)
Didn't realize that was an available variable to use! That works, done in 67416f2.
Thank you very much for iterating on this @rachthree! The test coverage looks solid and I've just left some final nits - apart from that it LGTM 🔥
Gently pinging @patrickvonplaten or @LysandreJik for final approval
.circleci/config.yml (Outdated)

    @@ -888,7 +888,7 @@ jobs:
            path: ~/transformers/test_preparation.txt
        - run: |
            if [ -f test_list.txt ]; then
    -           python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s --make-reports=tests_onnx $(cat test_list.txt) -k onnx | tee tests_output.txt
    +           python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -rA -s --make-reports=tests_onnx $(cat test_list.txt) -k onnx | tee tests_output.txt
I'm not sure if this change will interfere with our internal Slack reporting. Gently pinging @ydshieh for his input 🙏
When using --make-reports, there is no more need to use -rA (see here).

I believe -rA will display the very detailed results on the terminal, which we would like to avoid. Instead, the summary is saved to files when we use --make-reports (this is not a pytest flag, but our own custom flag); see the pytest doc.
@ydshieh I added -rA because even though the test was included in tests_fetcher.py, require_tf actually skipped it, since originally TensorFlow was not installed for run_tests_onnxruntime. I wouldn't have known that I needed to add tf for run_tests_onnxruntime without using -rA and checking that the tests I expected to run did run. The artifacts did not say exactly which tests were run, only what would be collected. Does --make-reports generate other reports that are not in CircleCI?
@rachthree In the ARTIFACTS tab of this run_tests_onnxruntime run page, you can find ~/transformers/tests_output.txt. Click it, and at the end you will see something like:

    ...
    ...
    SKIPPED [1] tests/onnx/test_onnx.py:86: test is slow
    SKIPPED [1] tests/onnx/test_onnx.py:75: test is slow
    SKIPPED [400] ../.pyenv/versions/3.7.12/lib/python3.7/unittest/case.py:93: test is slow
    SKIPPED [2] tests/onnx/test_onnx_v2.py:370: test is slow
    ================ 45 passed, 444 skipped, 45 warnings in 46.63s =================

In previous runs where TF was not installed, we should be able to see those ONNX tests being skipped for TF.
@ydshieh I went back to the CI run for commit 8da5990, which didn't have TF and -rA, and even though ~/transformers/tests_output.txt didn't show the breakdown, ~/transformers/reports/tests_onnx/summary_short.txt did. Somehow I missed this when going through the artifacts, my apologies! -rA has been removed in d8f3804. Thanks :)

There are other runs in .circleci/config.yml that still have -rA from before... will the flag be removed from those runs in the future?
It has been a long time since I wrote that code - does it still work with -rA? If so, there is no reason, but please remember that --make-reports is a hack, so it's quite possible that the flag was needed. Please try without it, and if it still produces the same output, then we may choose to remove it.

It's also possible that it is there because we wanted both: the normal log and the --make-reports individual logs.
@stas00 For run_tests_onnxruntime, it still works. ~/transformers/reports/tests_onnx/summary_short.txt always has the breakdown, so -rA is not needed as long as people know which artifact to look for. However, the terminal output in CircleCI and ~/transformers/tests_output.txt do change, as the short test summary info doesn't show without it.

As for the rest of the tests in the CircleCI config, I think removing -rA would be outside the scope of this PR, but I can try it here or in another branch if you'd like. Not having consistency confused me when I looked at the logs of the other tests.
@rachthree I think we can merge this PR first, with -rA if you prefer. As other tests also have this flag, I won't worry about having it. We can check if there is a need to remove this flag in a separate PR, but this is not a priority.
Please do whatever is best. I was just mentioning that, if I remember correctly, some of those flags were still needed for --make-reports to work as intended.
PR's been merged without the -rA flag, maintaining previous logging behavior. Thank you all for your review 😊
tests/onnx/test_features.py (Outdated)

    For the functionality to execute, local checkpoints are provided but framework is not.
    """
    torch_str = "pt"
Nit: to save a bit of code duplication and for ease of maintenance, we could move these variables to setUp and then access them with self.torch_str. We also tend to prefer putting the type in variable names, so an alternative would be to use something like framework_pt = "pt" and framework_tf = "tf" instead of torch_str, etc.
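For illustration, a minimal sketch of what that setUp refactor could look like (class and attribute names follow the suggestion above and are not necessarily the merged code):

```python
import unittest

from transformers.testing_utils import SMALL_MODEL_IDENTIFIER


class DetermineFrameworkTest(unittest.TestCase):
    def setUp(self):
        # Define shared fixtures once; individual tests read self.framework_pt /
        # self.framework_tf instead of repeating the string literals.
        self.framework_pt = "pt"
        self.framework_tf = "tf"
        self.test_model = SMALL_MODEL_IDENTIFIER
```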
Good idea, thanks! Done in 6bd7477.
Thanks!
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Cool, looks good to me. PyTorch has priority in the hierarchy when detecting the framework, so that's backwards compatible
What does this PR do?

Determines whether to use torch or tf2onnx as the ONNX exporter automatically, with the user-provided framework / --framework argument taking highest priority.

Fixes issue #18495, where PyTorch was still attempted for a local TF checkpoint even though PyTorch did not exist in the environment. This also avoids requiring users to pass --framework=tf when using the ONNX export driver script.
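To make the selection order concrete, here is a hedged, self-contained sketch of the priority described in this PR. WEIGHTS_NAME / TF2_WEIGHTS_NAME and the availability helpers are standard transformers utilities; the function name and exact error handling below are illustrative, not the merged implementation:

```python
import os

from transformers.utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, is_tf_available, is_torch_available


def determine_framework_sketch(model: str, framework: str = None) -> str:
    """Hypothetical sketch of the framework-selection priority."""
    # 1. An explicit framework / --framework argument always wins.
    if framework is not None:
        return framework
    # 2. For a local checkpoint, pick the framework whose weights file is present
    #    (PyTorch takes priority when both are found).
    if os.path.isdir(model):
        if os.path.isfile(os.path.join(model, WEIGHTS_NAME)):
            return "pt"
        if os.path.isfile(os.path.join(model, TF2_WEIGHTS_NAME)):
            return "tf"
        raise FileNotFoundError(
            f"Cannot determine framework from {model}: expected {WEIGHTS_NAME} for PyTorch"
            f" or {TF2_WEIGHTS_NAME} for TensorFlow."
        )
    # 3. Otherwise (a Hub model id), fall back to whichever framework is installed,
    #    preferring PyTorch.
    if is_torch_available():
        return "pt"
    if is_tf_available():
        return "tf"
    raise EnvironmentError("Neither PyTorch nor TensorFlow is installed in the environment.")
```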
Misc: added tf to the pip install steps for run_tests_onnxruntime and run_tests_onnxruntime_all in CI.

Tests:
- Ran the python -m transformers.onnx driver with and without --framework on local checkpoints and the Hub. Tested in containerized environments that had only PyTorch, only TensorFlow, or both.
- Ran RUN_SLOW=true pytest tests/onnx. Overall, tests passed w.r.t. main since they share the same failing tests:
  - Wrote up #18614 for the TypeError: 'module' object is not callable errors (fixed by #18336).
  - As for the AutoModel error (fixed by #18587), https://github.com/huggingface/transformers/blob/main/src/transformers/models/auto/modeling_auto.py#L363 says not to add new models, so is this failure acceptable?
Who can review?
@LysandreJik and others who may be interested :)