Skip to content

Commit

Permalink
fix recompute (#42128)
Browse files Browse the repository at this point in the history
* fix recompute

* modify return
  • Loading branch information
sljlp authored Apr 25, 2022
1 parent f4ce8a9 commit f21824d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/paddle/incubate/distributed/models/moe/moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def forward(self, inp):
def experts_fwd(x, fwd_expert_count, experts):

if x.shape[0] == 0:
return paddle.empty(x.shape, x.dtype)
return x
y = []
last_index = 0
assert isinstance(fwd_expert_count, np.ndarray)
Expand All @@ -411,7 +411,7 @@ def experts_fwd(x, fwd_expert_count, experts):
last_index = expert_count + last_index
return paddle.concat(y, axis=0)

if self.recompute_interval <= 0:
if self.recompute_interval <= 0 or x.shape[0] == 0:
x = experts_fwd(x, fwd_expert_count.numpy(), self.experts)
else:
x = _hp_recompute(experts_fwd, x,
Expand Down

0 comments on commit f21824d

Please sign in to comment.