
Support for onnx:CumSum to improve Onnx results #66

Closed
SirMomster opened this issue Jan 6, 2023 · 7 comments

Comments

@SirMomster

This is a follow-up on #12.

We noticed a big difference in detection results between ONNX and PyTorch model inference, and we believe this might have to do with the fact that we set refine_iters=0, thus skipping the refinement iterations.

According to the following remark by @baudm in #12 (comment), the issue is with the CumSum operator in ONNX, which fails with the error:

InferenceError: [ShapeInferenceError] (op_type:CumSum, node name: CumSum_2527): x typestr: T, has unsupported type: tensor(bool)

The operator is applied to a tensor(bool), which is indeed unsupported, as documented here: https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum

Would it be possible to change the type, i.e. alter tgt_padding_mask = ((tgt_in == self.eos_id).cumsum(-1) > 0) so that cumsum operates on a supported type?

To my limited understanding, this should improve the accuracy of the ONNX-exported model, since apart from this part both models do more or less the same?
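
For context, one way the quoted shape-inference error can surface is by running ONNX's own shape inference on the exported graph. A minimal sketch, assuming the model was already exported to a file named parseq.onnx (hypothetical file name):

import onnx

model = onnx.load("parseq.onnx")
# Basic structural validation of the exported graph.
onnx.checker.check_model(model)
# Strict shape inference raises an InferenceError on the unsupported bool CumSum node.
onnx.shape_inference.infer_shapes(model, strict_mode=True)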

@baudm
Owner

baudm commented Jan 14, 2023

Oh, I didn't notice that the error was about the type of the argument to cumsum. Thanks for pointing it out. The straightforward fix is to typecast (tgt_in == self.eos_id) to int prior to calling cumsum(-1). The cumsum() here essentially propagates the EOS count to the right, allowing us to mask the EOS token and all tokens to its right.
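
To illustrate the mechanism on a toy example (a minimal sketch; the eos_id value and token ids below are made up for illustration):

import torch

eos_id = 0                                   # assumed EOS token id
tgt_in = torch.tensor([[5, 7, 0, 9, 0, 3]])  # toy decoded token ids

# Cast the bool comparison to int so the exported graph uses an int CumSum,
# then mark the first EOS position and everything to its right.
tgt_padding_mask = ((tgt_in == eos_id).int().cumsum(-1) > 0)
print(tgt_padding_mask)
# tensor([[False, False,  True,  True,  True,  True]])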

I'll try this later once I get access to my laptop.

@jturner116

@SirMomster Have you tested this fix yet? I'm still experiencing the inconsistent results at the end of tokens, like those seen in #12.

@SirMomster
Author

@jturner116 I will try it out ASAP and will post an update here.

@SirMomster
Author

@baudm was the Torch Hub version automatically updated to include your change, or do I need to build from source?

@baudm
Owner

baudm commented Jan 19, 2023

@baudm was the Torch Hub version automatically updated to include your change, or do I need to build from source?

Torch Hub caches the repo upon first use. Use force_reload=True to update Torch Hub's cache:

parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True, force_reload=True).eval()

@jturner116

I am still getting the following warnings on the conversion with the fresh torch.hub download
[screenshot: conversion warnings]

This is the tolerance failure
[screenshot: tolerance failure]

I'm using torch==1.13.1+cu117
onnx==1.12.0
onnxruntime==1.13.1

Curious to see if this resembles your experience @SirMomster :)

@jturner116

jturner116 commented Jan 19, 2023

if testing and (tgt_in == self.eos_id).any(dim=-1).all():

If I use a sufficiently long word as the dummy input (I actually used Verbandsteffe), the AssertionErrors go away for all of the images in DemoImages. Maybe if you use an image with a word with the maximum character count, the outputs will be stable?

EDIT: In my additional testing, it does seem that inference on words shorter than the dummy input is stable.
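
For reference, a minimal sketch of exporting with a long-word dummy input, following the observation above. The file names, the 32x128 input size, the normalization constants, and the opset version are assumptions for illustration, not taken from this thread:

import torch
from PIL import Image
from torchvision import transforms

# Load the (patched) model from Torch Hub, refreshing the cache.
parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True, force_reload=True).eval()

preprocess = transforms.Compose([
    transforms.Resize((32, 128), transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])

# 'long_word.png' is a placeholder for an image whose text is at or near the
# model's maximum label length (e.g. a word like "Verbandsteffe").
dummy = preprocess(Image.open('long_word.png').convert('RGB')).unsqueeze(0)

torch.onnx.export(parseq, dummy, 'parseq.onnx',
                  opset_version=14,  # assumed opset version
                  input_names=['image'], output_names=['logits'])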
