Unit tests #83

Merged: 17 commits, Mar 12, 2024
Changes from 10 commits
3 changes: 3 additions & 0 deletions .gitignore
@@ -25,3 +25,6 @@ venv/

# Tox envs
.tox

# Aim
.aim
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
numpy
accelerate>=0.20.3
packaging
transformers>=4.34.1
transformers>=4.34.1,<4.38.0
Collaborator:

#53 is merged, so you shouldn't need this cap, thanks.

Collaborator @tedhtchang, Mar 11, 2024:

@Ssukriti Should we keep the cap (or a static version) for the transformers package to avoid unintended errors like the xla_fsdp_v2 one? We could create a GitHub workflow to run the tests and then update the cap regularly.

Collaborator:

We don't need to keep a static version, but yes, in the optional-dependencies PR @gkumbhat is looking into how to cap, and we may cap to the next major release. Since the CI/CD will automatically pull new release versions, if we see failing builds we will update accordingly.
The errors we were seeing with xla_fsdp_v2 were actually due to code we wrote, which was good to catch and fix. It was not an API change from transformers; we were setting env variables incorrectly.
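For illustration only, a cap to the next major release in requirements.txt might look like the line below; the exact upper bound is an assumption here, not something this PR settles on:

transformers>=4.34.1,<5.0.0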

Collaborator:

In general, if there is a specific version that doesn't work or has a bug, we can also ask pip to ignore that particular version.

#54 (comment)
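As a rough sketch of that approach, a PEP 440 exclusion specifier in requirements.txt skips one known-bad release while leaving the rest of the range open; the excluded version shown is purely hypothetical:

transformers>=4.34.1,!=4.36.0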

torch
aim==3.18.1
sentencepiece
13 changes: 13 additions & 0 deletions tests/__init__.py
@@ -0,0 +1,13 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
22 changes: 22 additions & 0 deletions tests/data/__init__.py
@@ -0,0 +1,22 @@
# Copyright The IBM Tuning Team
Collaborator:

Curious about the copyright notice. Where is this coming from?

Collaborator (Author):

"IBM Tuning Team" was suggested by Raghu; the rest is from caikit.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpful datasets for configuring individual unit tests.
"""
# Standard
import os

### Constants used for data
DATA_DIR = os.path.join(os.path.dirname(__file__))
TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json")
10 changes: 10 additions & 0 deletions tests/data/twitter_complaints_small.json
@@ -0,0 +1,10 @@
{"Tweet text":"@HMRCcustomers No this is my first job","ID":0,"Label":2,"text_label":"no complaint","output":"### Text: @HMRCcustomers No this is my first job\n\n### Label: no complaint"}
{"Tweet text":"@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.","ID":1,"Label":2,"text_label":"no complaint","output":"### Text: @KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.\n\n### Label: no complaint"}
{"Tweet text":"If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService","ID":2,"Label":1,"text_label":"complaint","output":"### Text: If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService\n\n### Label: complaint"}
{"Tweet text":"@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.","ID":3,"Label":1,"text_label":"complaint","output":"### Text: @EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.\n\n### Label: complaint"}
{"Tweet text":"Couples wallpaper, so cute. :) #BrothersAtHome","ID":4,"Label":2,"text_label":"no complaint","output":"### Text: Couples wallpaper, so cute. :) #BrothersAtHome\n\n### Label: no complaint"}
{"Tweet text":"@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG","ID":5,"Label":2,"text_label":"no complaint","output":"### Text: @mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG\n\n### Label: no complaint"}
{"Tweet text":"@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?","ID":6,"Label":2,"text_label":"no complaint","output":"### Text: @Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?\n\n### Label: no complaint"}
{"Tweet text":"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?","ID":7,"Label":1,"text_label":"complaint","output":"### Text: @nationalgridus I have no water and the bill is current and paid. Can you do something about this?\n\n### Label: complaint"}
{"Tweet text":"Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora","ID":8,"Label":1,"text_label":"complaint","output":"### Text: Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora\n\n### Label: complaint"}
{"Tweet text":"@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd","ID":9,"Label":2,"text_label":"no complaint","output":"### Text: @JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd\n\n### Label: no complaint"}
47 changes: 47 additions & 0 deletions tests/helpers.py
@@ -0,0 +1,47 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Third Party
import transformers

# Local
from tuning.config import configs, peft_config


def causal_lm_train_kwargs(train_kwargs):
"""Parse the kwargs for a valid train call to a Causal LM."""
parser = transformers.HfArgumentParser(
dataclass_types=(
configs.ModelArguments,
configs.DataArguments,
configs.TrainingArguments,
peft_config.LoraConfig,
peft_config.PromptTuningConfig,
)
)
(
model_args,
data_args,
training_args,
lora_config,
prompt_tuning_config,
) = parser.parse_dict(train_kwargs, allow_extra_keys=True)
return (
model_args,
data_args,
training_args,
lora_config
if train_kwargs.get("peft_method") == "lora"
else prompt_tuning_config,
)
111 changes: 111 additions & 0 deletions tests/test_sft_trainer.py
@@ -0,0 +1,111 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit Tests for SFT Trainer.
"""

# Standard
import json
import os
import tempfile

# First Party
from scripts.run_inference import TunedCausalLM
from tests.data import TWITTER_COMPLAINTS_DATA
from tests.helpers import causal_lm_train_kwargs

# Local
from tuning import sft_trainer

HAPPY_PATH_KWARGS = {
"model_name_or_path": "Maykeye/TinyLLama-v0",
"data_path": TWITTER_COMPLAINTS_DATA,
"num_train_epochs": 5,
"per_device_train_batch_size": 4,
"per_device_eval_batch_size": 4,
"gradient_accumulation_steps": 4,
"learning_rate": 0.00001,
"weight_decay": 0,
"warmup_ratio": 0.03,
"lr_scheduler_type": "cosine",
"logging_steps": 1,
"include_tokens_per_second": True,
"packing": False,
"response_template": "\n### Label:",
"dataset_text_field": "output",
"use_flash_attn": False,
"torch_dtype": "float32",
"model_max_length": 4096,
"peft_method": "pt",
"prompt_tuning_init": "RANDOM",
"num_virtual_tokens": 8,
"prompt_tuning_init_text": "hello",
"tokenizer_name_or_path": "Maykeye/TinyLLama-v0",
"save_strategy": "epoch",
}


def test_run_causallm_pt():
"""Check if we can bootstrap and run causallm models"""
with tempfile.TemporaryDirectory() as tempdir:
HAPPY_PATH_KWARGS["output_dir"] = tempdir
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
HAPPY_PATH_KWARGS
)
sft_trainer.train(model_args, data_args, training_args, tune_config)
_validate_training(tempdir, "PROMPT_TUNING")

# Load the tuned model
loaded_model = TunedCausalLM.load(
checkpoint_path=os.path.join(tempdir, "checkpoint-5"),
)

# Run inference on the text using the tuned model
loaded_model.run(
"Simply put, the theory of relativity states that ", max_new_tokens=500
)


def test_run_causallm_lora():
"""Check if we can bootstrap and run causallm models"""
with tempfile.TemporaryDirectory() as tempdir:
HAPPY_PATH_KWARGS["output_dir"] = tempdir
HAPPY_PATH_KWARGS["peft_method"] = "lora"
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
HAPPY_PATH_KWARGS
)
sft_trainer.train(model_args, data_args, training_args, tune_config)
_validate_training(tempdir, "LORA")

# Load the tuned model
loaded_model = TunedCausalLM.load(
checkpoint_path=os.path.join(tempdir, "checkpoint-5"),
)

# Run inference on the text using the tuned model
loaded_model.run(
"Simply put, the theory of relativity states that ", max_new_tokens=500
)


def _validate_training(tempdir, peft_type):
assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir))
train_loss_file_path = "{}/train_loss.jsonl".format(tempdir)
assert os.path.exists(train_loss_file_path) == True
assert os.path.getsize(train_loss_file_path) > 0
adapter_config_path = os.path.join(tempdir, "checkpoint-1", "adapter_config.json")
assert os.path.exists(adapter_config_path)
with open(adapter_config_path) as f:
data = json.load(f)
assert data.get("peft_type") == peft_type
2 changes: 1 addition & 1 deletion tuning/sft_trainer.py
@@ -123,7 +123,7 @@ def train(
logger = logging.get_logger("sft_trainer")

# Validate parameters
if (not isinstance(train_args.num_train_epochs, float)) or (
if (not isinstance(train_args.num_train_epochs, (float, int))) or (
train_args.num_train_epochs <= 0
):
raise ValueError("num_train_epochs has to be an integer/float >= 1")