
Qualcomm AI Engine Direct - support embedding op #2057

Closed · wants to merge 4 commits

Conversation


@haowhsu-quic haowhsu-quic commented Feb 23, 2024

summary:

  • support embedding op with int32 index input
  • make mobilebert / llama2 be fully delegated
  • add requantize passes for mixed precision
  • bug fixes
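As a quick illustration of the first bullet, here is a minimal sketch of an embedding lookup driven by int32 indices rather than the default int64 (assuming a recent PyTorch build, which accepts IntTensor indices; the vocabulary and embedding sizes are arbitrary, not from the PR):

```python
import torch

# Embedding table: 100 rows, 8-dim vectors (sizes are illustrative)
emb = torch.nn.Embedding(num_embeddings=100, embedding_dim=8)

# Index tensor in int32 instead of the usual int64
idx = torch.tensor([[1, 2, 3]], dtype=torch.int32)

# Lookup works the same; output shape is idx.shape + (embedding_dim,)
out = emb(idx)
print(tuple(out.shape))
```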


pytorch-bot bot commented Feb 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/2057

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f223b65 with merge base 81b3232:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 23, 2024
@facebook-github-bot (Contributor)

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


cccclai commented Feb 29, 2024

Two main comments:

  • Is there a way to repro the error on our side? Maybe it's an edge case we'd need to fix.
  • Can we leave a TODO (maybe an expected-failing unit test)?


haowhsu-quic commented Feb 29, 2024

> Two main comments:
>
> • Is there a way to repro the error on our side? Maybe it's an edge case we'd need to fix.
> • Can we leave a TODO (maybe an expected-failing unit test)?

Thanks for reviewing! I've added TODO items for the next steps.

It can be reproduced with the following patch:

```diff
[examples/qualcomm/scripts/mobilebert_fine_tune.py]
@@ -58,10 +58,18 @@ def accuracy_per_class(preds, goldens, labels):
 def get_dataset(data_val):
     # prepare input data
     inputs, input_list = [], ""
-    # max_position_embeddings defaults to 512
-    position_ids = torch.arange(512).expand((1, -1)).to(torch.int32)
     for index, data in enumerate(data_val):
         data = [d.to(torch.int32) for d in data]
-        # input_ids, attention_mask, token_type_ids, position_ids
-        inputs.append(
-            (
-                *data[:2],
-                torch.zeros(data[0].size(), dtype=torch.int32),
-                position_ids[:, : data[0].shape[1]],
-            )
-        )
+        # input_ids, attention_mask, token_type_ids
+        inputs.append((*data[:2], torch.zeros(data[0].size(), dtype=torch.int32)))
         input_text = " ".join(
             [f"input_{index}_{i}.raw" for i in range(len(inputs[-1]))]
         )
@@ -204,9 +212,6 @@ def get_fine_tuned_mobilebert(artifacts_dir, pretrained_weight, batch_size):
             map_location=torch.device("cpu"),
         ),
     )
+    # hack for changing dtype of "position_ids" from int64 to int32
+    sub_module = model.mobilebert.embeddings
+    sub_module.position_ids = sub_module.position_ids.to(torch.int32)
 
     return model.eval(), dataloader_val, labels
```

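The buffer-casting hack in the second hunk can be tried in isolation. A minimal sketch, assuming a recent PyTorch build; the `Toy` module and the 512-entry buffer are illustrative stand-ins for mobilebert's embeddings submodule:

```python
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # mirrors mobilebert's position_ids buffer, which defaults to int64
        self.register_buffer("position_ids", torch.arange(512).expand(1, -1))

model = Toy()
assert model.position_ids.dtype == torch.int64

# The hack: assigning a tensor to a registered buffer name replaces
# that buffer, so downstream code sees an int32 position_ids.
model.position_ids = model.position_ids.to(torch.int32)
print(model.position_ids.dtype)
```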

cccclai commented Feb 29, 2024

Thanks for the update. Seems like some llama-related CI jobs started failing. Those changes look legit, but is it okay to remove the changes in examples/models/llama2/model.py to get CI green? We can do a separate PR for this.

@haowhsu-quic (Collaborator Author) commented:

Thank you, I moved the datatype casting from examples/models/llama2/model.py into our own script, as the mobilebert example does.

@facebook-github-bot (Contributor)

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@cccclai merged this pull request in 57e192b.
