Skip to content

Commit

Permalink
fix dim
Browse files Browse the repository at this point in the history
  • Loading branch information
qingpeng9802 committed May 23, 2023
1 parent cb5ed04 commit 78eebaf
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion monai/losses/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# parameterize if necessary. (Or justify why the mean should be there)
average_spatial_dims = True
if average_spatial_dims:
loss = loss.mean(dim=target.shape[2:])
loss = loss.mean(dim=list(range(2, len(target.shape))))
loss = loss.sum()
elif self.reduction == LossReduction.MEAN.value:
loss = loss.mean()
Expand Down

0 comments on commit 78eebaf

Please sign in to comment.