-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Set up fixtures and data for tests Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> * Add basic unit tests Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> * Setting upper bound for transformers Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> * Ignore aim log files Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> * Include int num_train_epochs Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> * Fix formatting Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> * Add copyright notice Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> * Address review comments Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> * Run inference on tuned model Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> * Trainer downloads model Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> * add more unit tests and refactor Signed-off-by: Anh-Uong <anh.uong@ibm.com> * Fix formatting Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> * Add FT unit test and refactor Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> * Removing transformers upper bound cap Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> * Address review comments Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --------- Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Signed-off-by: Anh-Uong <anh.uong@ibm.com> Co-authored-by: Anh-Uong <anh.uong@ibm.com>
- Loading branch information
1 parent
0729820
commit a716cd7
Showing
7 changed files
with
421 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,3 +25,6 @@ venv/ | |
|
||
# Tox envs | ||
.tox | ||
|
||
# Aim | ||
.aim |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright The IBM Tuning Team | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright The IBM Tuning Team | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Helpful datasets for configuring individual unit tests. | ||
""" | ||
# Standard | ||
import os | ||
|
||
### Constants used for data | ||
DATA_DIR = os.path.join(os.path.dirname(__file__)) | ||
TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
{"Tweet text":"@HMRCcustomers No this is my first job","ID":0,"Label":2,"text_label":"no complaint","output":"### Text: @HMRCcustomers No this is my first job\n\n### Label: no complaint"} | ||
{"Tweet text":"@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.","ID":1,"Label":2,"text_label":"no complaint","output":"### Text: @KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.\n\n### Label: no complaint"} | ||
{"Tweet text":"If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService","ID":2,"Label":1,"text_label":"complaint","output":"### Text: If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService\n\n### Label: complaint"} | ||
{"Tweet text":"@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.","ID":3,"Label":1,"text_label":"complaint","output":"### Text: @EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.\n\n### Label: complaint"} | ||
{"Tweet text":"Couples wallpaper, so cute. :) #BrothersAtHome","ID":4,"Label":2,"text_label":"no complaint","output":"### Text: Couples wallpaper, so cute. :) #BrothersAtHome\n\n### Label: no complaint"} | ||
{"Tweet text":"@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG","ID":5,"Label":2,"text_label":"no complaint","output":"### Text: @mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG\n\n### Label: no complaint"} | ||
{"Tweet text":"@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?","ID":6,"Label":2,"text_label":"no complaint","output":"### Text: @Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?\n\n### Label: no complaint"} | ||
{"Tweet text":"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?","ID":7,"Label":1,"text_label":"complaint","output":"### Text: @nationalgridus I have no water and the bill is current and paid. Can you do something about this?\n\n### Label: complaint"} | ||
{"Tweet text":"Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora","ID":8,"Label":1,"text_label":"complaint","output":"### Text: Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora\n\n### Label: complaint"} | ||
{"Tweet text":"@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd","ID":9,"Label":2,"text_label":"no complaint","output":"### Text: @JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd\n\n### Label: no complaint"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright The IBM Tuning Team | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Third Party | ||
import transformers | ||
|
||
# Local | ||
from tuning.config import configs, peft_config | ||
|
||
|
||
def causal_lm_train_kwargs(train_kwargs): | ||
"""Parse the kwargs for a valid train call to a Causal LM.""" | ||
parser = transformers.HfArgumentParser( | ||
dataclass_types=( | ||
configs.ModelArguments, | ||
configs.DataArguments, | ||
configs.TrainingArguments, | ||
peft_config.LoraConfig, | ||
peft_config.PromptTuningConfig, | ||
) | ||
) | ||
( | ||
model_args, | ||
data_args, | ||
training_args, | ||
lora_config, | ||
prompt_tuning_config, | ||
) = parser.parse_dict(train_kwargs, allow_extra_keys=True) | ||
return ( | ||
model_args, | ||
data_args, | ||
training_args, | ||
lora_config | ||
if train_kwargs.get("peft_method") == "lora" | ||
else prompt_tuning_config, | ||
) |
Oops, something went wrong.