PyTorch implementation of focal loss for multi-class semantic segmentation

cloudpark93/pytorch-multi-class-segmentation-focal-loss

Focal loss for Multi-class semantic segmentation in PyTorch

PyTorch implementation of focal loss for multi-class semantic segmentation.

To use the alpha form of focal loss, do two things:

  1. Prepare an alpha weight for each class.
  2. Swap which line is commented out in the code so that the alpha version is active, as below:
focal_loss = self.alpha[targets] * (1 - pt)**self.gamma * ce_loss
#focal_loss = (1 - pt) ** self.gamma * ce_loss
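
For context, here is a minimal sketch of how a loss module built around those two lines could look. It is not the repository's actual code: the constructor signature, the default `gamma=2.0`, the use of `F.cross_entropy` to obtain `ce_loss`, the `exp(-ce_loss)` step for `pt`, and the mean reduction are all assumptions made for illustration. Instead of commenting lines in and out, this sketch switches between the two forms depending on whether `alpha` was passed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Illustrative multi-class focal loss (assumed layout, not the repo's code)."""

    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.gamma = gamma  # assumed default; the focusing parameter
        # Store alpha as a tensor buffer so it can be indexed by a target tensor
        # and follows the module across devices. None means "non-alpha form".
        alpha_tensor = None if alpha is None else torch.as_tensor(alpha, dtype=torch.float32)
        self.register_buffer("alpha", alpha_tensor)

    def forward(self, inputs, targets):
        # inputs: (N, C, H, W) raw logits; targets: (N, H, W) int64 class indices.
        ce_loss = F.cross_entropy(inputs, targets, reduction="none")  # (N, H, W)
        pt = torch.exp(-ce_loss)  # predicted probability of the true class
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.alpha is not None:
            # Alpha form: look up one weight per pixel from its target class.
            focal_loss = self.alpha[targets] * focal_loss
        return focal_loss.mean()
```

With `gamma=0` and no `alpha`, this sketch reduces to plain mean cross-entropy, which is a quick sanity check for any focal loss implementation.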

Non-alpha form Focal loss

...
fn_loss = FocalLoss()

pred = model(x)
loss = fn_loss(pred, target)

...

Alpha form Focal loss

...
class_weights = [...]  # one alpha value per class
fn_loss = FocalLoss(alpha=class_weights)

pred = model(x)
loss = fn_loss(pred, target)

...
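
The README does not say how the per-class alpha values should be chosen. One common option, shown here purely as an assumption, is to weight each class by the inverse of its pixel frequency in the training masks. The helper name `inverse_frequency_alpha` and its normalization (weights sum to 1, absent classes get 0) are hypothetical and not part of this repository.

```python
import torch


def inverse_frequency_alpha(masks, num_classes):
    """Hypothetical helper: per-class alpha from inverse pixel frequency.

    masks: integer tensor of class indices (any shape), values in [0, num_classes).
    Returns a float tensor of length num_classes whose entries sum to 1.
    """
    counts = torch.bincount(masks.flatten(), minlength=num_classes).float()
    present = counts > 0
    # Inverse frequency for classes that occur; clamp avoids dividing by zero.
    inv = counts.sum() / counts.clamp(min=1.0)
    # Classes that never occur get weight 0 instead of an arbitrarily large one.
    inv = inv * present.float()
    return inv / inv.sum()


# Example: three pixels of class 0 and one pixel of class 1.
masks = torch.tensor([[0, 0, 0, 1]])
class_weights = inverse_frequency_alpha(masks, num_classes=2)
```

In this example the rarer class 1 receives three times the weight of class 0. The resulting tensor can be passed as `alpha` if the loss indexes it by target class, as in `self.alpha[targets]`.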

Extra

Please visit my colleague's GitHub as well!
https://github.com/jinsoo9595
