
Enabled auto-truncation for any pretrained models #192

Merged — 14 commits merged into opensearch-project:main on Jul 19, 2023
Conversation

@Yerzhaisang (Contributor) commented Jul 16, 2023

Description

Initially, some pretrained models such as tas-b did not truncate input documents, so documents at or beyond the maximum length resulted in an error. This change makes the truncation parameter dynamic, filling it based on the model whenever it is null.

Issues Resolved

Closes #132

Check List

  • New functionality includes testing.
    • All tests pass
  • Commits are signed per the DCO using --signoff

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check here.

@codecov bot commented Jul 16, 2023

Codecov Report

Merging #192 (627229b) into main (7622af5) will increase coverage by 0.02%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main     #192      +/-   ##
==========================================
+ Coverage   91.06%   91.08%   +0.02%     
==========================================
  Files          37       37              
  Lines        4052     4062      +10     
==========================================
+ Hits         3690     3700      +10     
  Misses        362      362              
Impacted Files                                        | Coverage Δ
...search_py_ml/ml_models/sentencetransformermodel.py | 71.27% <100.00%> (+0.76%) ⬆️

Review comment on the test code:

    False
    ), f"Creating tokenizer.json file for tracing raised an exception {exec}"
    assert tokenizer_json[
        "truncation"
    ], "truncation parameter in tokenizer.json is null"

Collaborator: We need to assert max_length as well, showing that we properly set the max_length.

Review comment on:

    model11 = SentenceTransformer(model_id)

Collaborator: We don't need to load the model again here. Can't we do test_model10.tokenizer.model_max_length in line 482?

Contributor Author: Unfortunately, I think we can't use the SentenceTransformerModel object.

[screenshot]

Collaborator (@dhrubo-os, Jul 17, 2023): Oh yeah, this is the SentenceTransformerModel class, not SentenceTransformer. In that case let's match against a static value, since we already know it. I don't want to load models unnecessarily, as this can increase the overall execution time of the integration tests.

Contributor Author: Got it. Should I remove the last commit and make a new one, or can I make the change without removing it?

Collaborator: You don't need to remove the last commit; you can push another commit with the modification.

Collaborator: Name the variable MAX_LENGTH_TASB and then assert against that. Please don't just assert with a bare number.

Contributor Author: Done
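The assertion pattern the reviewers converge on might look like this hypothetical test fragment. The constant name follows the reviewer's suggestion, and 512 is the tas-b token limit mentioned elsewhere in the thread:

```python
# Hypothetical test fragment: assert against a named constant instead of a
# bare number, as requested in review. 512 matches the tas-b token limit
# referenced in issue #132.
MAX_LENGTH_TASB = 512

def check_truncation(tokenizer_json: dict) -> None:
    # The truncation field must be filled in (not null) ...
    truncation = tokenizer_json.get("truncation")
    assert truncation is not None, "truncation parameter in tokenizer.json is null"
    # ... and max_length must be properly set, against the named constant.
    assert truncation["max_length"] == MAX_LENGTH_TASB, "max_length is not properly set"
```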

"stride": 0,
}
with open(save_json_folder_path + "/tokenizer.json", "w") as file:
json.dump(parsed_json, file, indent=2)
Copy link
Collaborator

@dhrubo-os dhrubo-os Jul 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After updating the file with new content did you compare both of the files (prev and new) on your end to verify? Please do a file comparison to make sure nothing else got updates except this object?

Contributor Author: As I understand it, I should do this comparison locally to make sure everything works as expected. Or should I add anything to the code?

Collaborator: Yeah, compare locally to make sure.

Contributor Author: develop_tokenizer is the JSON generated on the develop branch (new); main_tokenizer is the previous one. The only difference is the truncation parameter; after deleting this parameter, there is no difference between the two JSONs.

[screenshot]
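The comparison described above can be sketched as a small helper (illustrative only; the variable names mirror the file names in the discussion):

```python
# Compare two parsed tokenizer.json dicts and report which top-level keys
# differ. Per the thread, the only expected difference between the main-branch
# and develop-branch files is the "truncation" field.
def diff_keys(old: dict, new: dict) -> set:
    return {k for k in old.keys() | new.keys() if old.get(k) != new.get(k)}
```

On the files discussed above, diff_keys(main_tokenizer, develop_tokenizer) would be expected to return {"truncation"}.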

Collaborator: How did you get develop_tokenizer? After saving the model by invoking save_as_pt, did you then load the tokenizer file?

Collaborator: I just wanted to make sure that when we save the file with our changes, we aren't replacing the content, we are appending to it.

Contributor Author: I just started the kernel and ran the sixth cell (on the main branch). Then I restarted the kernel and ran the seventh cell (on the develop branch). You can see how I loaded the tokenizers.

[screenshot]

Collaborator: Cool, sounds good. Thanks for the verification.

Contributor: @Yerzhaisang Have you tried using this model with a doc whose token length exceeds 512, as mentioned in #132? Does it behave properly?

Contributor Author: Sure, it works as expected with length > 1000. I tested it with @dhrubo-os during office hours.

CHANGELOG.md Outdated
@@ -77,6 +77,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
### Fixed
- Fixing documentation issue by @dhrubo-os in ([#20]https://github.com/opensearch-project/opensearch-py-ml/pull/20)
- Increment jenkins lib version and fix GHA job name by @gaiksaya in ([#37]https://github.com/opensearch-project/opensearch-py-ml/pull/37)
- Increment jenkins lib version and fix GHA job name by @gaiksaya in ([#37]https://github.com/opensearch-project/opensearch-py-ml/pull/37)
Contributor: @Yerzhaisang Can you remove this duplicated line?

Contributor Author: Oh, I'm sorry. Removed!

CHANGELOG.md Outdated
@@ -77,6 +77,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
### Fixed
- Fixing documentation issue by @dhrubo-os in ([#20]https://github.com/opensearch-project/opensearch-py-ml/pull/20)
- Increment jenkins lib version and fix GHA job name by @gaiksaya in ([#37]https://github.com/opensearch-project/opensearch-py-ml/pull/37)
- Increment jenkins lib version and fix GHA job name by @gaiksaya in ([#37]https://github.com/opensearch-project/opensearch-py-ml/pull/37)
- Enabled auto-truncation for any pretrained models ([#192]https://github.com/opensearch-project/opensearch-py-ml/pull/192)
Contributor: We should add this as a fix under [1.1.0].

Contributor Author: Done

Contributor: @Yerzhaisang Could you please add this additional step to save_as_onnx as well to handle this problem?

Contributor Author: Ohhh, I forgot about that. Please give me a week; I can do it on the weekend.

Contributor: Thanks! I think you can just copy the code there. It should not be different.

Contributor Author: Done

@@ -765,6 +765,18 @@ def save_as_pt(

# save tokenizer.json in save_json_folder_name
model.save(save_json_folder_path)
with open(save_json_folder_path + "/tokenizer.json") as user_file:
Contributor: Can we use os.path.join instead? String concatenation will fail if the user puts a / after the folder name, but os.path.join will not.

Contributor Author: Yeah, this will be fixed.

Contributor: You can add a line to join the path, and use that path for both the read and the write.

Contributor Author: Done
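The difference the reviewer is pointing at can be shown directly (a minimal sketch; `save_json_folder_path` is the parameter from the diff above, with an assumed trailing slash):

```python
import os

save_json_folder_path = "model_folder/"  # user-supplied path ending in "/"

# Naive concatenation keeps whatever the user typed, producing a double slash:
concat_path = save_json_folder_path + "/tokenizer.json"  # "model_folder//tokenizer.json"

# os.path.join normalizes the separator regardless of the trailing slash
# (on POSIX this yields "model_folder/tokenizer.json"):
join_path = os.path.join(save_json_folder_path, "tokenizer.json")
```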

@@ -765,6 +765,18 @@ def save_as_pt(

# save tokenizer.json in save_json_folder_name
model.save(save_json_folder_path)
with open(save_json_folder_path + "/tokenizer.json") as user_file:
file_contents = user_file.read()
parsed_json = json.loads(file_contents)
Contributor: You can combine lines 769-770 and write just parsed_json = json.load(user_file). [load, not loads]

Contributor Author: Done
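The simplification being suggested, shown side by side (a minimal sketch; `io.StringIO` stands in for the real tokenizer.json file handle):

```python
import io
import json

# Before: two steps -- read the whole file into a string, then parse it.
user_file = io.StringIO('{"truncation": null}')
parsed_json_before = json.loads(user_file.read())

# After: json.load parses straight from the file object in one step.
user_file = io.StringIO('{"truncation": null}')
parsed_json_after = json.load(user_file)
```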

with open(save_json_folder_path + "/tokenizer.json") as user_file:
file_contents = user_file.read()
parsed_json = json.loads(file_contents)
if not parsed_json["truncation"]:
Contributor: Can we do if "truncation" not in parsed_json or parsed_json["truncation"] is None? I think we should handle the case where "truncation" is missing from parsed_json the same way as when parsed_json["truncation"] is None, just in case.

Contributor Author: Done

@@ -851,6 +863,18 @@ def save_as_onnx(

# save tokenizer.json in output_path
model.save(save_json_folder_path)
tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json")
Collaborator: What do you think about putting this code block into a separate common function that can be reused by save_as_pt and save_as_onnx? That way we can just test that function; no need to write separate tests for save_as_pt and save_as_onnx to cover this functionality.

Contributor Author: Then I should add one more function with type hints and a description of the inputs and outputs. I'm going to sleep now; tomorrow after work I will think about how to implement this function and its unit test.

Contributor Author (@Yerzhaisang, Jul 19, 2023): Dear @dhrubo-os, I think there is no need to change the unit test, to avoid duplication. My recently implemented reusable fix_truncation function is used by save_as_pt and save_as_onnx in the same way, so the test_truncation_parameter unit test already exercises fix_truncation through save_as_pt. I think we should leave it as implemented, but I'll be happy to see your suggestions ;)

@@ -701,6 +701,38 @@ def zip_model(
)
print("zip file is saved to " + zip_file_path + "\n")

def fix_truncation(
Contributor: Can we rename this a bit? Maybe handle_null_truncation or fill_null_truncation_field.

Collaborator: Good call-out. And since this is a private function, let's rename it to _fill_null_truncation_field.

Review comment on:

    max_length: int,
    ) -> None:
        """
        Description:

Contributor: And say "Fill truncation field in tokenizer.json when it is null" here instead, so other people know exactly what this function addresses without reading the code.

@dhrubo-os dhrubo-os merged commit e0d1750 into opensearch-project:main Jul 19, 2023
14 checks passed
opensearch-trigger-bot bot pushed a commit that referenced this pull request Jul 19, 2023
* Made truncation parameter automatically processed

Signed-off-by: Yerzhaisang Taskali <tasqali1697@gmail.com>

* Made max_length parameter dynamic

Signed-off-by: yerzhaisang <tasqali1697@gmail.com>

* Added unit test for checking truncation parameter

Signed-off-by: yerzhaisang <tasqali1697@gmail.com>

* Updated CHANGELOG.md

Signed-off-by: yerzhaisang <tasqali1697@gmail.com>

* Included the test of max_length parameter value

Signed-off-by: yerzhaisang <tasqali1697@gmail.com>

* Slightly modeified the test of max_length parameter value

Signed-off-by: yerzhaisang <tasqali1697@gmail.com>

* Modified CHANGELOG.md and removed the duplicate

Signed-off-by: yerzhaisang <tasqali1697@gmail.com>

* Enabled auto-truncation format also for ONNX

Signed-off-by: yerzhaisang <tasqali1697@gmail.com>

* Implemented reusable function

Signed-off-by: yerzhaisang <tasqali1697@gmail.com>

* Fixed the lint

Signed-off-by: yerzhaisang <tasqali1697@gmail.com>

* Change tokenizer.json only if truncation is null

Signed-off-by: yerzhaisang <tasqali1697@gmail.com>

* Removed function which had been accidentally added

Signed-off-by: yerzhaisang <tasqali1697@gmail.com>

* Renamed reusable function and added the description

Signed-off-by: yerzhaisang <tasqali1697@gmail.com>

* Fixed the lint

Signed-off-by: yerzhaisang <tasqali1697@gmail.com>

---------

Signed-off-by: Yerzhaisang Taskali <tasqali1697@gmail.com>
Signed-off-by: yerzhaisang <tasqali1697@gmail.com>
(cherry picked from commit e0d1750)
dhrubo-os pushed a commit that referenced this pull request Jul 19, 2023
(same commit message as above; cherry picked from commit e0d1750)

Co-authored-by: Yerzhaisang <55043014+Yerzhaisang@users.noreply.github.com>
Successfully merging this pull request may close: [BUG] Pre-trained tas-b model won't auto-truncate the doc

4 participants