Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Improve FSDP + QLora #25

Merged
merged 10 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "DataDreamer"
version = "0.31.0"
version = "0.32.0"
description = "Prompt. Generate Synthetic Data. Train & Align Models."
license = "MIT"
authors= [
Expand Down
2 changes: 1 addition & 1 deletion scripts/.cluster/slurm/_sbatch_config.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#SBATCH --output=.cluster/slurm/.last_job/submission.out
#SBATCH --ntasks 1
#SBATCH --cpus-per-task 16
#SBATCH --mem=10G
#SBATCH --mem=30G
#SBATCH --gpus=2

# Source the user's bashrc
Expand Down
10 changes: 9 additions & 1 deletion src/_cachable/_cachable.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _notify_adaptive_batch_sizing(model_logger: Logger, progress_state: dict[str
class _StrWithSeed(str):
seed: Any

def __new__(cls, value: str, seed: "Any | _StrWithSeed"):
def __new__(cls, value: str, seed: "Any | _StrWithSeed" = None):
obj = str.__new__(cls, value)
obj.seed = seed.seed if isinstance(seed, _StrWithSeed) else seed
return obj
Expand All @@ -75,6 +75,14 @@ def __eq__(self, __value: object) -> bool:
def __hash__(self):
return hash((self.seed, str(self)))

def __getstate__(self):
state = {"str": str(self), "seed": self.seed}

return state

def __setstate__(self, state):
self.seed = state["seed"]

@staticmethod
def total_per_input_seeds(inputs: list["str | _StrWithSeed"]) -> int:
return sum(
Expand Down
4 changes: 4 additions & 0 deletions src/embedders/sentence_transformers_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from ..utils.hf_model_utils import (
convert_dtype,
filter_model_warnings,
get_model_max_context_length,
get_tokenizer,
)
Expand Down Expand Up @@ -122,6 +123,9 @@ def model(self) -> SentenceTransformer:
# torch._dynamo.config.suppress_errors = True
# model = torch.compile(model)

# Filter any warnings from the model
filter_model_warnings()

# Finished loading
log_if_timeout.stop(
partial(
Expand Down
13 changes: 4 additions & 9 deletions src/llms/hf_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
HF_TRANSFORMERS_CITATION,
PEFT_CITATION,
convert_dtype,
filter_model_warnings,
get_attn_implementation,
get_config,
get_model_max_context_length,
Expand Down Expand Up @@ -273,6 +274,9 @@ def model(self) -> PreTrainedModel:
torch._dynamo.config.suppress_errors = True
model = torch.compile(model)

# Filter any warnings from the model
filter_model_warnings()

# Finished loading
log_if_timeout.stop(
partial(
Expand Down Expand Up @@ -323,15 +327,6 @@ def count_tokens(self, value: str) -> int:
Returns:
The number of tokens in the string.
"""
pass
"""_summary_

Args:
value (_type_): _description_

Returns:
_type_: _description_
"""
return len(self.tokenizer.encode(value))

@torch.no_grad()
Expand Down
4 changes: 4 additions & 0 deletions src/llms/petals.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..utils.arg_utils import AUTO, Default
from ..utils.background_utils import RunIfTimeout
from ..utils.fs_utils import safe_fn
from ..utils.hf_model_utils import filter_model_warnings
from ..utils.import_utils import ignore_hivemind_warnings, ignore_transformers_warnings
from .hf_transformers import HFTransformers

Expand Down Expand Up @@ -161,6 +162,9 @@ def model(self) -> PreTrainedModel:
# torch._dynamo.config.suppress_errors = True
# model = torch.compile(model)

# Filter any warnings from the model
filter_model_warnings()

# Finished loading
log_if_timeout.stop(
partial(
Expand Down
4 changes: 4 additions & 0 deletions src/task_models/hf_classification_task_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
HF_TRANSFORMERS_CITATION,
PEFT_CITATION,
convert_dtype,
filter_model_warnings,
get_config,
get_model_max_context_length,
get_tokenizer,
Expand Down Expand Up @@ -152,6 +153,9 @@ def model(self) -> PreTrainedModel:
# torch._dynamo.config.suppress_errors = True
# model = torch.compile(model)

# Filter any warnings from the model
filter_model_warnings()

# Finished loading
log_if_timeout.stop(
partial(
Expand Down
22 changes: 21 additions & 1 deletion src/tests/llms/test_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
from time import sleep
from types import GeneratorType

import dill
import psutil
import pytest
import torch
from flaky import flaky
from sortedcontainers import SortedDict

from ... import DataDreamer
from ..._cachable._cachable import _is_primitive
from ..._cachable._cachable import _is_primitive, _StrWithSeed
from ...llms import (
AI21,
VLLM,
Expand Down Expand Up @@ -338,6 +339,25 @@ def test_is_primitive(self):
assert _is_primitive({"foo": 5})
assert not _is_primitive({"foo": object()})

def test_StrWithSeed(self):
seed_a = _StrWithSeed("hello", seed=1)
seed_b = _StrWithSeed("hello", seed=2)
seed_c = _StrWithSeed("hello", seed=1)
assert (
isinstance(seed_a, str)
and isinstance(seed_b, str)
and isinstance(seed_c, str)
)
assert seed_a.seed == 1
assert seed_b.seed == 2
assert seed_c.seed == 1
assert str(seed_a) == "hello"
assert str(seed_b) == "hello"
assert str(seed_c) == "hello"
assert hash(seed_a) != hash(seed_b)
assert hash(seed_a) == hash(seed_c)
assert hash(seed_a) == hash(dill.loads(dill.dumps(seed_c)))

def test_check_temperature_and_top_p(self):
assert _check_temperature_and_top_p(
temperature=0.3,
Expand Down
5 changes: 3 additions & 2 deletions src/tests/trainers/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
TrainHFPPO,
TrainSentenceTransformer,
)
from ...trainers._train_hf_base import CustomDataCollatorWithPadding
from ...utils.arg_utils import AUTO
from ...utils.hf_model_utils import get_orig_model, is_bnb_quantized
from ...utils.hf_training_utils import CustomDataCollatorWithPadding
from ...utils.import_utils import ignore_transformers_warnings

with ignore_transformers_warnings():
Expand Down Expand Up @@ -422,12 +422,13 @@ def test_fsdp_peft(self, qlora, create_datadreamer, mocker):
validation_output=val_dataset.output["outputs"],
epochs=1,
batch_size=8,
gradient_checkpointing=qlora,
)
assert data_collator_spy.call_count == 0
trainer_path = cast(str, trainer._output_folder_path)
with open(os.path.join(trainer_path, "fingerprint.json"), "r") as f:
assert (
json.load(f) == "ce4179deefbddefd" if qlora else "6b385aca0ce684b3"
json.load(f) == "42a7bd193f804a4a" if qlora else "6b385aca0ce684b3"
)
assert train_result is trainer
assert (
Expand Down
2 changes: 1 addition & 1 deletion src/tests/trainers/test_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
TrainSentenceTransformer,
TrainSetFitClassifier,
)
from ...trainers._train_hf_base import CustomDataCollatorWithPadding
from ...utils.fs_utils import clear_dir
from ...utils.hf_model_utils import get_orig_model, validate_peft_config
from ...utils.hf_training_utils import CustomDataCollatorWithPadding
from ...utils.import_utils import ignore_transformers_warnings

with ignore_transformers_warnings():
Expand Down
5 changes: 4 additions & 1 deletion src/tests/utils/test_device_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,7 @@ def test_get_device_env_variables(self):
get_device_env_variables([0, 2, 999999, 0, 1, -1, -1])
with pytest.raises(AssertionError):
get_device_env_variables([0, 2, 0, 1])
assert get_device_env_variables([0, 2, 1]) == {"CUDA_VISIBLE_DEVICES": "6,3,4"}
assert get_device_env_variables([0, 2, 1]) == {
"CUDA_VISIBLE_DEVICES": "6,3,4",
"NCCL_P2P_DISABLE": "1",
}
Loading
Loading