Skip to content

Commit

Permalink
[release/2.1] fix tp_shard import (#2239)
Browse files Browse the repository at this point in the history
  • Loading branch information
chunyuan-w authored Nov 6, 2023
1 parent 7ccf44d commit 64b0681
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion intel_extension_for_pytorch/nn/utils/_weight_prepack.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ def _all_reduce_and_bias_add(mp_group, original_bias, output):

def _pre_ipex_gemm(input, world_size, rank):
assert "deepspeed" in installed_pkg, "_pre_ipex_gemm requires deepspeed installed"
from deepspeed.utils.tp_shard import get_shard_size, get_shard_size_list
try:
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
except ImportError:
from deepspeed.utils.tp_shard import get_shard_size, get_shard_size_list

input_shard_size = get_shard_size(input.shape[-1], world_size)
input_shard_offset = sum(get_shard_size_list(input.shape[-1], world_size)[0:rank])
Expand Down

0 comments on commit 64b0681

Please sign in to comment.