From 64b0681585ba0890a720e8c08d4e2d4cf5d64556 Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Mon, 6 Nov 2023 01:28:20 +0000 Subject: [PATCH] [release/2.1] fix tp_shard import (#2239) --- intel_extension_for_pytorch/nn/utils/_weight_prepack.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/intel_extension_for_pytorch/nn/utils/_weight_prepack.py b/intel_extension_for_pytorch/nn/utils/_weight_prepack.py index a4444e7b4..4ecaeec67 100644 --- a/intel_extension_for_pytorch/nn/utils/_weight_prepack.py +++ b/intel_extension_for_pytorch/nn/utils/_weight_prepack.py @@ -124,7 +124,10 @@ def _all_reduce_and_bias_add(mp_group, original_bias, output): def _pre_ipex_gemm(input, world_size, rank): assert "deepspeed" in installed_pkg, "_pre_ipex_gemm requires deepspeed installed" - from deepspeed.utils.tp_shard import get_shard_size, get_shard_size_list + try: + from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list + except ImportError: + from deepspeed.utils.tp_shard import get_shard_size, get_shard_size_list input_shard_size = get_shard_size(input.shape[-1], world_size) input_shard_offset = sum(get_shard_size_list(input.shape[-1], world_size)[0:rank])