-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Export of Openai Whisper [Batched decoding ver] #18815
Support Export of Openai Whisper [Batched decoding ver] #18815
Conversation
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
Fixed
Show fixed
Hide fixed
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
Fixed
Show fixed
Hide fixed
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
Fixed
Show fixed
Hide fixed
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
Fixed
Show fixed
Hide fixed
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
Fixed
Show fixed
Hide fixed
@@ -0,0 +1,87 @@ | |||
# ------------------------------------------------------------------------- |
Check warning
Code scanning / lintrunner
BLACK-ISORT/format Warning
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
Fixed
Show fixed
Hide fixed
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
Fixed
Show fixed
Hide fixed
@@ -0,0 +1,486 @@ | |||
# ------------------------------------------------------------------------- |
Check warning
Code scanning / lintrunner
BLACK-ISORT/format Warning
@@ -312,8 +378,15 @@ | |||
"tensor(uint8)": np.uint8, | |||
} | |||
|
|||
# Generate prompts | |||
prompt_text = "Christians" | |||
prompt_ids = processor.get_prompt_ids(prompt_text) |
Check warning
Code scanning / lintrunner
RUFF/F841 Warning
See https://docs.astral.sh/ruff/rules/unused-variable
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
Fixed
Show fixed
Hide fixed
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
Fixed
Show fixed
Hide fixed
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
Fixed
Show fixed
Hide fixed
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
Fixed
Show fixed
Hide fixed
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
Fixed
Show fixed
Hide fixed
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
Fixed
Show fixed
Hide fixed
self.attention_fusion = FusionBartAttentionOpenai( | ||
self, | ||
self.hidden_size, | ||
self.num_heads, | ||
self.attention_mask | ||
) |
Check warning
Code scanning / CodeQL
Overwriting attribute in super-class or sub-class Warning
BertOnnxModel
self.attention_mask | ||
) | ||
else: | ||
self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) |
Check warning
Code scanning / CodeQL
Overwriting attribute in super-class or sub-class Warning
BertOnnxModel
@@ -343,21 +420,24 @@ | |||
diff = pt_outputs - ort_outputs | |||
max_diff = max(diff.min(), diff.max(), key=abs) | |||
|
|||
if max_diff > 0: | |||
if True: |
Check warning
Code scanning / CodeQL
Constant in conditional expression or statement Warning
@@ -312,8 +378,15 @@ | |||
"tensor(uint8)": np.uint8, | |||
} | |||
|
|||
# Generate prompts | |||
prompt_text = "Christians" | |||
prompt_ids = processor.get_prompt_ids(prompt_text) |
Check notice
Code scanning / CodeQL
Unused local variable Note
1c2d388
to
12f0ff1
Compare
@@ -8,6 +8,7 @@ | |||
import logging | |||
import os | |||
import tempfile | |||
import copy |
Check warning
Code scanning / lintrunner
RUFF/F811 Warning
See https://docs.astral.sh/ruff/rules/redefined-while-unused
@@ -6,6 +6,7 @@ | |||
|
|||
import logging | |||
import os | |||
import io |
Check warning
Code scanning / lintrunner
RUFF/F401 Warning
See https://docs.astral.sh/ruff/rules/unused-import
@@ -19,6 +20,10 @@ | |||
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper | |||
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper | |||
|
|||
from whisper.model import Whisper, ModelDimensions |
Check warning
Code scanning / lintrunner
RUFF/F401 Warning
See https://docs.astral.sh/ruff/rules/unused-import
@@ -19,6 +20,10 @@ | |||
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper | |||
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper | |||
|
|||
from whisper.model import Whisper, ModelDimensions |
Check warning
Code scanning / lintrunner
RUFF/F401 Warning
See https://docs.astral.sh/ruff/rules/unused-import
@@ -19,6 +20,10 @@ | |||
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper | |||
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper | |||
|
|||
from whisper.model import Whisper, ModelDimensions | |||
from whisper import _MODELS, _ALIGNMENT_HEADS |
Check warning
Code scanning / lintrunner
RUFF/F401 Warning
See https://docs.astral.sh/ruff/rules/unused-import
@@ -19,6 +20,10 @@ | |||
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper | |||
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper | |||
|
|||
from whisper.model import Whisper, ModelDimensions | |||
from whisper import _MODELS, _ALIGNMENT_HEADS |
Check warning
Code scanning / lintrunner
RUFF/F401 Warning
See https://docs.astral.sh/ruff/rules/unused-import
@@ -19,6 +20,10 @@ | |||
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper | |||
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper | |||
|
|||
from whisper.model import Whisper, ModelDimensions | |||
from whisper import _MODELS, _ALIGNMENT_HEADS | |||
from whisper import _download |
Check warning
Code scanning / lintrunner
RUFF/F401 Warning
See https://docs.astral.sh/ruff/rules/unused-import
@@ -8,6 +8,7 @@ | |||
import logging | |||
import os | |||
import tempfile | |||
import copy |
Check notice
Code scanning / CodeQL
Module is imported more than once Note
on line 7
@@ -6,6 +6,7 @@ | |||
|
|||
import logging | |||
import os | |||
import io |
Check notice
Code scanning / CodeQL
Unused import Note
@@ -19,6 +20,10 @@ | |||
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper | |||
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper | |||
|
|||
from whisper.model import Whisper, ModelDimensions |
Check notice
Code scanning / CodeQL
Unused import Note
Import of 'ModelDimensions' is not used.
@@ -19,6 +20,10 @@ | |||
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper | |||
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper | |||
|
|||
from whisper.model import Whisper, ModelDimensions | |||
from whisper import _MODELS, _ALIGNMENT_HEADS |
Check notice
Code scanning / CodeQL
Unused import Note
Import of '_ALIGNMENT_HEADS' is not used.
@@ -19,6 +20,10 @@ | |||
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper | |||
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper | |||
|
|||
from whisper.model import Whisper, ModelDimensions | |||
from whisper import _MODELS, _ALIGNMENT_HEADS | |||
from whisper import _download |
Check notice
Code scanning / CodeQL
Unused import Note
onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
Outdated
Show resolved
Hide resolved
@@ -50,6 +50,8 @@ ONNX_OPERATOR_KERNEL_EX( | |||
.InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU | |||
.InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU | |||
.InputMemoryType(OrtMemTypeCPUInput, 14) // 'temperature' needs to be on CPU | |||
.InputMemoryType(OrtMemTypeCPUInput, 15) // 'left_pad_mask' needs to be on CPU |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update new input descriptions
@@ -374,8 +381,15 @@ def verify_onnx( | |||
"tensor(uint8)": np.uint8, | |||
} | |||
|
|||
# Generate prompts | |||
prompt_text = "Christians" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clean up testing module
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Try running lintrunner -a
to fix the lint issues.
@@ -124,7 +125,12 @@ class BartOnnxModel(BertOnnxModel): | |||
def __init__(self, model, num_heads, hidden_size, model_impl="hf"): | |||
super().__init__(model, num_heads, hidden_size) | |||
self.attention_mask = AttentionMask(self) | |||
self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) | |||
if model_impl == "openai": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we do not need model_impl. We could instead do the following:
fuse attention with FusionBartAttentionOpenai
fuse attention with FusionBartAttention
That way it can handle the graph patterns of both OpenAI and HF.
@@ -59,6 +59,7 @@ def __init__(self, model_type): | |||
|
|||
if model_type == "clip": | |||
self.enable_embed_layer_norm = False | |||
self.model_impl = "hf" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no need for this option. See another of my comment.
@shubhambhokare1 Any update on this PR?
#19854 has the same functionality without the requirement of adding new inputs to the subgraphs.
This depends on the merge of #17316.
Two additional inputs are added to the encoder-decoder-init subgraph (for the first decoder run):
x = self.token_embedding(x) + self.positional_embedding[position_ids, :]
]