Skip to content

Commit

Permalink
additional fmt fixes
Browse files Browse the repository at this point in the history
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
  • Loading branch information
achew010 committed Aug 29, 2024
1 parent 9cb999e commit 00d17e7
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tests/acceleration/test_acceleration_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
)
from tuning.config.acceleration_configs.attention_and_distributed_packing import (
AttentionAndDistributedPackingConfig,
PaddingFree,
MultiPack,
PaddingFree,
)
from tuning.config.acceleration_configs.fused_ops_and_kernels import (
FastKernelsConfig,
Expand Down
5 changes: 3 additions & 2 deletions tests/acceleration/test_acceleration_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
import torch

# First Party
from tests.data import TWITTER_COMPLAINTS_JSON_FORMAT, TWITTER_COMPLAINTS_TOKENIZED
from tests.test_sft_trainer import DATA_ARGS, MODEL_ARGS, PEFT_LORA_ARGS, TRAIN_ARGS

# Local
from ..data import TWITTER_COMPLAINTS_JSON_FORMAT, TWITTER_COMPLAINTS_TOKENIZED
from .spying_utils import create_mock_plugin_class_and_spy
from tuning import sft_trainer
from tuning.config.acceleration_configs import (
Expand All @@ -40,8 +40,8 @@
)
from tuning.config.acceleration_configs.attention_and_distributed_packing import (
AttentionAndDistributedPackingConfig,
PaddingFree,
MultiPack,
PaddingFree,
)
from tuning.config.acceleration_configs.fused_ops_and_kernels import (
FastKernelsConfig,
Expand Down Expand Up @@ -596,6 +596,7 @@ def test_error_raised_with_paddingfree_and_flash_attn_disabled():
attention_and_distributed_packing_config=attention_and_distributed_packing_config,
)


def test_error_raised_with_multipack_and_paddingfree_disabled():
    """Ensure error raised when multipack is used without padding-free"""
with pytest.raises(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def _verify_configured_dataclasses(self):
# this also ensures that the attention implementation for multipack
# will be flash attention as sfttrainer will enforce flash attn to be
# set for padding free
assert self.padding_free is not None, "`--multipack` is currently only supported with `--padding_free`"
assert (
self.padding_free is not None
), "`--multipack` is currently only supported with `--padding_free`"

@staticmethod
def from_dataclasses(*dataclasses: Type):
Expand Down

0 comments on commit 00d17e7

Please sign in to comment.