diff --git a/tests/acceleration/test_acceleration_dataclasses.py b/tests/acceleration/test_acceleration_dataclasses.py index 4be78e556..130159933 100644 --- a/tests/acceleration/test_acceleration_dataclasses.py +++ b/tests/acceleration/test_acceleration_dataclasses.py @@ -25,8 +25,8 @@ ) from tuning.config.acceleration_configs.attention_and_distributed_packing import ( AttentionAndDistributedPackingConfig, - PaddingFree, MultiPack, + PaddingFree, ) from tuning.config.acceleration_configs.fused_ops_and_kernels import ( FastKernelsConfig, diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py index fafb04e74..31243c0fd 100644 --- a/tests/acceleration/test_acceleration_framework.py +++ b/tests/acceleration/test_acceleration_framework.py @@ -24,10 +24,10 @@ import torch # First Party +from tests.data import TWITTER_COMPLAINTS_JSON_FORMAT, TWITTER_COMPLAINTS_TOKENIZED from tests.test_sft_trainer import DATA_ARGS, MODEL_ARGS, PEFT_LORA_ARGS, TRAIN_ARGS # Local -from ..data import TWITTER_COMPLAINTS_JSON_FORMAT, TWITTER_COMPLAINTS_TOKENIZED from .spying_utils import create_mock_plugin_class_and_spy from tuning import sft_trainer from tuning.config.acceleration_configs import ( @@ -40,8 +40,8 @@ ) from tuning.config.acceleration_configs.attention_and_distributed_packing import ( AttentionAndDistributedPackingConfig, - PaddingFree, MultiPack, + PaddingFree, ) from tuning.config.acceleration_configs.fused_ops_and_kernels import ( FastKernelsConfig, @@ -596,6 +596,7 @@ def test_error_raised_with_paddingfree_and_flash_attn_disabled(): attention_and_distributed_packing_config=attention_and_distributed_packing_config, ) + def test_error_raised_with_multipack_and_paddingfree_disabled(): """Ensure error raised when padding-free is not used with flash attention""" with pytest.raises( diff --git a/tuning/config/acceleration_configs/acceleration_framework_config.py b/tuning/config/acceleration_configs/acceleration_framework_config.py index ef14f7932..1aad6abcb 100644 --- a/tuning/config/acceleration_configs/acceleration_framework_config.py +++ b/tuning/config/acceleration_configs/acceleration_framework_config.py @@ -123,7 +123,9 @@ def _verify_configured_dataclasses(self): # this also ensures that the attention implementation for multipack # will be flash attention as sfttrainer will enforce flash attn to be # set for padding free - assert self.padding_free is not None, "`--multipack` is currently only supported with `--padding_free`" + assert ( + self.padding_free is not None + ), "`--multipack` is currently only supported with `--padding_free`" @staticmethod def from_dataclasses(*dataclasses: Type):