Skip to content

Commit

Permalink
Update nacl_loss.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Bala93 authored Aug 7, 2024
1 parent d33f435 commit 7deb2cc
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions monai/losses/nacl_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:
Converts the mask to one hot represenation and applies the spatial filter.
Args:
mask: the shape should be BHW[D]
mask: the shape should be BH[WD].
Returns:
torch.Tensor: the shape would be BNHW[D], N being number of classes.
torch.Tensor: the shape would be BNH[WD], N being number of classes.
"""
rmask: torch.Tensor

Expand All @@ -109,8 +109,8 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
Computes standard cross-entropy loss and constraints it neighbor aware logit penalty.
Args:
inputs: the shape should be BNHW[D], where N is the number of classes.
targets: the shape should be BHW[D].
inputs: the shape should be BNH[WD], where N is the number of classes.
targets: the shape should be BH[WD].
Returns:
torch.Tensor: value of the loss.
Expand All @@ -122,7 +122,7 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
>>> input = torch.rand(B, N, H, W)
>>> target = torch.randint(0, N, (B, H, W))
>>> criterion = NACLLoss(classes = N, dim = 2)
>>> loss = self(input, target)
>>> loss = criterion(input, target)
"""

loss_ce = self.cross_entropy(inputs, targets)
Expand Down

0 comments on commit 7deb2cc

Please sign in to comment.