diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 548ba5eb..84666031 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -34,7 +34,6 @@ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging import logger from torchtitan.parallelisms.parallel_dims import ParallelDims -from torchtitan.parallelisms.utils import check_strided_sharding_enabled def parallelize_llama(