Unit tests (#83)
* Set up fixtures and data for tests

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>

* Add basic unit tests

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>

* Setting upper bound for transformers

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>

* Ignore aim log files

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>

* Include int num_train_epochs

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>

* Fix formatting

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>

* Add copyright notice

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>

* Address review comments

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>

* Run inference on tuned model

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>

* Trainer downloads model

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>

* add more unit tests and refactor

Signed-off-by: Anh-Uong <anh.uong@ibm.com>

* Fix formatting

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>

* Add FT unit test and refactor

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>

* Removing transformers upper bound cap

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>

* Address review comments

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>

---------

Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com>
Signed-off-by: Anh-Uong <anh.uong@ibm.com>
Co-authored-by: Anh-Uong <anh.uong@ibm.com>
tharapalanivel and anhuong authored Mar 12, 2024
1 parent 0729820 commit a716cd7
Showing 7 changed files with 421 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
@@ -25,3 +25,6 @@ venv/

# Tox envs
.tox

# Aim
.aim
13 changes: 13 additions & 0 deletions tests/__init__.py
@@ -0,0 +1,13 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
22 changes: 22 additions & 0 deletions tests/data/__init__.py
@@ -0,0 +1,22 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpful datasets for configuring individual unit tests.
"""
# Standard
import os

### Constants used for data
DATA_DIR = os.path.join(os.path.dirname(__file__))
TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json")
10 changes: 10 additions & 0 deletions tests/data/twitter_complaints_small.json
@@ -0,0 +1,10 @@
{"Tweet text":"@HMRCcustomers No this is my first job","ID":0,"Label":2,"text_label":"no complaint","output":"### Text: @HMRCcustomers No this is my first job\n\n### Label: no complaint"}
{"Tweet text":"@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.","ID":1,"Label":2,"text_label":"no complaint","output":"### Text: @KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.\n\n### Label: no complaint"}
{"Tweet text":"If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService","ID":2,"Label":1,"text_label":"complaint","output":"### Text: If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService\n\n### Label: complaint"}
{"Tweet text":"@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.","ID":3,"Label":1,"text_label":"complaint","output":"### Text: @EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.\n\n### Label: complaint"}
{"Tweet text":"Couples wallpaper, so cute. :) #BrothersAtHome","ID":4,"Label":2,"text_label":"no complaint","output":"### Text: Couples wallpaper, so cute. :) #BrothersAtHome\n\n### Label: no complaint"}
{"Tweet text":"@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG","ID":5,"Label":2,"text_label":"no complaint","output":"### Text: @mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG\n\n### Label: no complaint"}
{"Tweet text":"@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?","ID":6,"Label":2,"text_label":"no complaint","output":"### Text: @Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?\n\n### Label: no complaint"}
{"Tweet text":"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?","ID":7,"Label":1,"text_label":"complaint","output":"### Text: @nationalgridus I have no water and the bill is current and paid. Can you do something about this?\n\n### Label: complaint"}
{"Tweet text":"Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora","ID":8,"Label":1,"text_label":"complaint","output":"### Text: Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora\n\n### Label: complaint"}
{"Tweet text":"@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd","ID":9,"Label":2,"text_label":"no complaint","output":"### Text: @JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd\n\n### Label: no complaint"}
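Each line of the data file above is a standalone JSON object, and its "output" field joins the tweet and its label with "### Text:" and "### Label:" markers. A minimal sketch of reading such records follows. The helper names `load_jsonl` and `split_output` are hypothetical and are not part of this commit; the two sample records are copied from the file.

```python
import io
import json

# Two records copied verbatim from twitter_complaints_small.json.
# The raw string keeps the JSON "\n" escapes intact.
SAMPLE_JSONL = r"""
{"Tweet text":"@HMRCcustomers No this is my first job","ID":0,"Label":2,"text_label":"no complaint","output":"### Text: @HMRCcustomers No this is my first job\n\n### Label: no complaint"}
{"Tweet text":"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?","ID":7,"Label":1,"text_label":"complaint","output":"### Text: @nationalgridus I have no water and the bill is current and paid. Can you do something about this?\n\n### Label: complaint"}
"""


def load_jsonl(stream):
    """Parse one JSON object per non-empty line (hypothetical helper)."""
    return [json.loads(line) for line in stream if line.strip()]


def split_output(output):
    """Split an "output" string into (text, label) at the last label marker."""
    text_part, label = output.rsplit("\n\n### Label: ", 1)
    return text_part[len("### Text: "):], label


records = load_jsonl(io.StringIO(SAMPLE_JSONL))
for record in records:
    text, label = split_output(record["output"])
    # The label embedded in "output" should match the "text_label" field.
    print(record["ID"], label == record["text_label"])
```

To read the real file, the same loader could be given `open(TWITTER_COMPLAINTS_DATA, encoding="utf-8")` instead of the in-memory sample.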
47 changes: 47 additions & 0 deletions tests/helpers.py
@@ -0,0 +1,47 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Third Party
import transformers

# Local
from tuning.config import configs, peft_config


def causal_lm_train_kwargs(train_kwargs):
    """Parse the kwargs for a valid train call to a Causal LM."""
    parser = transformers.HfArgumentParser(
        dataclass_types=(
            configs.ModelArguments,
            configs.DataArguments,
            configs.TrainingArguments,
            peft_config.LoraConfig,
            peft_config.PromptTuningConfig,
        )
    )
    (
        model_args,
        data_args,
        training_args,
        lora_config,
        prompt_tuning_config,
    ) = parser.parse_dict(train_kwargs, allow_extra_keys=True)
    return (
        model_args,
        data_args,
        training_args,
        lora_config
        if train_kwargs.get("peft_method") == "lora"
        else prompt_tuning_config,
    )
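The helper above builds both PEFT configs from one flat dict and returns only the one selected by "peft_method". A standard-library sketch of that selection pattern follows. It does not use `transformers.HfArgumentParser` or the `tuning` package: the two dataclasses, their fields, and the `build` and `select_peft_config` functions are hypothetical stand-ins chosen for illustration.

```python
from dataclasses import dataclass, fields


@dataclass
class LoraConfig:
    """Hypothetical stand-in, not tuning.config.peft_config.LoraConfig."""
    r: int = 8


@dataclass
class PromptTuningConfig:
    """Hypothetical stand-in, not tuning.config.peft_config.PromptTuningConfig."""
    num_virtual_tokens: int = 8


def build(cls, kwargs):
    """Instantiate cls from the keys it declares, ignoring all other keys."""
    names = {f.name for f in fields(cls)}
    return cls(**{k: v for k, v in kwargs.items() if k in names})


def select_peft_config(train_kwargs):
    """Build both configs, then return the one named by "peft_method"."""
    lora_config = build(LoraConfig, train_kwargs)
    prompt_tuning_config = build(PromptTuningConfig, train_kwargs)
    return (
        lora_config
        if train_kwargs.get("peft_method") == "lora"
        else prompt_tuning_config
    )


print(select_peft_config({"peft_method": "lora", "r": 16}))
print(select_peft_config({"peft_method": "pt", "num_virtual_tokens": 4}))
```

Note that with this conditional, any value other than "lora", including a missing "peft_method" key, falls through to the prompt tuning config.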