
Support Export of OpenAI Whisper [Batched decoding ver] #18815

Conversation

@shubhambhokare1 shubhambhokare1 commented Dec 14, 2023

Depends on the merge of #17316.

Two additional inputs are added to the encoderdecoderinit subgraph (for the first decoder run):

  • left_pad_mask: the left-pad mask is added to the qk node in the QKV attention function
  • position_ids: used to select indices of the positional embeddings [x = self.token_embedding(x) + self.positional_embedding[position_ids, :]]
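As a rough illustration of what the two inputs do (a minimal NumPy sketch; all sizes, values, and variable names besides `left_pad_mask` and `position_ids` are made up, and the real computation happens inside the exported ONNX subgraph):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes for illustration only.
batch, seq_len, d_model = 2, 4, 8
vocab, max_pos = 16, 10
token_embedding = rng.standard_normal((vocab, d_model))
positional_embedding = rng.standard_normal((max_pos, d_model))

# Row 0 is left-padded with two pad tokens (id 0 here); position_ids
# restart the real positions after the padding instead of counting pads.
x_ids = np.array([[0, 0, 3, 5], [1, 2, 3, 4]])
position_ids = np.array([[0, 0, 0, 1], [0, 1, 2, 3]])

# x = self.token_embedding(x) + self.positional_embedding[position_ids, :]
x = token_embedding[x_ids] + positional_embedding[position_ids]

# left_pad_mask is added to the raw QK scores so attention to the padded
# key positions becomes -inf before softmax.
qk = np.zeros((batch, seq_len, seq_len))   # stands in for Q @ K^T scores
left_pad_mask = np.zeros((batch, 1, seq_len))
left_pad_mask[0, :, :2] = -np.inf          # mask row 0's two pad keys
qk = qk + left_pad_mask
```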

@@ -0,0 +1,87 @@
# -------------------------------------------------------------------------

Check warning

Code scanning / lintrunner

BLACK-ISORT/format Warning

Run lintrunner -a to apply this patch.
@@ -0,0 +1,486 @@
# -------------------------------------------------------------------------

Check warning

Code scanning / lintrunner

BLACK-ISORT/format Warning

Run lintrunner -a to apply this patch.
@@ -312,8 +378,15 @@
"tensor(uint8)": np.uint8,
}

# Generate prompts
prompt_text = "Christians"
prompt_ids = processor.get_prompt_ids(prompt_text)

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable prompt_ids is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable
Comment on lines 129 to 134
self.attention_fusion = FusionBartAttentionOpenai(
self,
self.hidden_size,
self.num_heads,
self.attention_mask
)

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute attention_fusion, which was previously defined in superclass BertOnnxModel.
self.attention_mask
)
else:
self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask)

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute attention_fusion, which was previously defined in superclass BertOnnxModel.
@@ -343,21 +420,24 @@
diff = pt_outputs - ort_outputs
max_diff = max(diff.min(), diff.max(), key=abs)

if max_diff > 0:
if True:

Check warning

Code scanning / CodeQL

Constant in conditional expression or statement Warning

Testing a constant will always give the same result.
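In isolation, the parity check in this hunk works as in the following sketch (the values are invented; the non-constant guard at the end shows one possible way to satisfy the CodeQL warning, not necessarily what the PR intends):

```python
import numpy as np

# PyTorch-vs-ORT parity check from the diff: the largest signed deviation
# is found by comparing the min and max of the difference by magnitude.
pt_outputs = np.array([1.0, 2.0, 3.0])
ort_outputs = np.array([1.0, 2.5, 2.9])
diff = pt_outputs - ort_outputs
max_diff = max(diff.min(), diff.max(), key=abs)
print(max_diff)  # -0.5: the entry with the largest magnitude, sign kept

# Guarding on the actual value instead of `if True:` keeps the same
# reporting behaviour without a constant conditional.
if max_diff != 0:
    print(f"Max difference between PyTorch and ORT: {max_diff}")
```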
@@ -312,8 +378,15 @@
"tensor(uint8)": np.uint8,
}

# Generate prompts
prompt_text = "Christians"
prompt_ids = processor.get_prompt_ids(prompt_text)

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable prompt_ids is not used.
@shubhambhokare1 shubhambhokare1 force-pushed the sbhokare/whisper_batch_decode branch from 1c2d388 to 12f0ff1 Compare February 13, 2024 20:04
@shubhambhokare1 shubhambhokare1 marked this pull request as ready for review February 13, 2024 20:07
@@ -8,6 +8,7 @@
import logging
import os
import tempfile
import copy

Check warning

Code scanning / lintrunner

RUFF/F811 Warning

Redefinition of unused copy from line 7.
See https://docs.astral.sh/ruff/rules/redefined-while-unused
@@ -6,6 +6,7 @@

import logging
import os
import io

Check warning

Code scanning / lintrunner

RUFF/F401 Warning

io imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import

@@ -19,6 +20,10 @@
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper

from whisper.model import Whisper, ModelDimensions
from whisper import _MODELS, _ALIGNMENT_HEADS
from whisper import _download

Check warning

Code scanning / lintrunner

RUFF/F401 Warning

whisper.model.Whisper imported but unused.
whisper.model.ModelDimensions imported but unused.
whisper._MODELS imported but unused.
whisper._ALIGNMENT_HEADS imported but unused.
whisper._download imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import
@@ -6,6 +6,7 @@

import logging
import os
import io

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'io' is not used.
@@ -50,6 +50,8 @@ ONNX_OPERATOR_KERNEL_EX(
.InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU
.InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU
.InputMemoryType(OrtMemTypeCPUInput, 14) // 'temperature' needs to be on CPU
.InputMemoryType(OrtMemTypeCPUInput, 15) // 'left_pad_mask' needs to be on CPU
@shubhambhokare1 (Contributor, author)

Update new input descriptions

@@ -374,8 +381,15 @@ def verify_onnx(
"tensor(uint8)": np.uint8,
}

# Generate prompts
prompt_text = "Christians"
@shubhambhokare1 (Contributor, author)

Clean up testing module

@thiagocrepaldi (Contributor) left a comment:
Try running lintrunner -a to fix the lint issues.

@@ -124,7 +125,12 @@ class BartOnnxModel(BertOnnxModel):
def __init__(self, model, num_heads, hidden_size, model_impl="hf"):
super().__init__(model, num_heads, hidden_size)
self.attention_mask = AttentionMask(self)
self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
if model_impl == "openai":
@tianleiwu (Contributor), Feb 20, 2024:
I think we do not need model_impl. We can do the following:

  • fuse attention with FusionBartAttentionOpenai
  • then fuse attention with FusionBartAttention

That way it can handle the graph patterns of both OpenAI and HF.
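Under this suggestion the selection flag goes away, because each fusion only rewrites subgraphs matching its own pattern. A minimal, self-contained sketch of that control flow (the real classes live in onnxruntime's transformers optimizer; these stubs only model the dispatch):

```python
# Stub fusions: each one "fuses" only the nodes that match its pattern
# and leaves everything else untouched, so they compose safely.
class FusionBartAttentionOpenai:
    def apply(self, nodes):
        # fuse only nodes matching the OpenAI export's attention pattern
        return ["fused_openai" if n == "openai_attn" else n for n in nodes]

class FusionBartAttention:
    def apply(self, nodes):
        # fuse the remaining HF-pattern attention nodes
        return ["fused_hf" if n == "hf_attn" else n for n in nodes]

def fuse_attention(nodes):
    # Run both fusions in sequence; no model_impl flag is needed because
    # a node matched by neither pattern simply passes through unchanged.
    nodes = FusionBartAttentionOpenai().apply(nodes)
    nodes = FusionBartAttention().apply(nodes)
    return nodes

print(fuse_attention(["openai_attn", "hf_attn", "other"]))
# → ['fused_openai', 'fused_hf', 'other']
```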

@@ -59,6 +59,7 @@ def __init__(self, model_type):

if model_type == "clip":
self.enable_embed_layer_norm = False
self.model_impl = "hf"
Contributor

There is no need for this option. See my other comment.

@thiagocrepaldi (Contributor)

@shubhambhokare1 any update for this pr?

@shubhambhokare1 (Contributor, author)

#19854 has the same functionality without the requirement of adding new inputs to the subgraphs
