Added DirectCLR loss #963

Open · wants to merge 18 commits into base: master
1 change: 1 addition & 0 deletions lightly/loss/__init__.py
@@ -11,3 +11,4 @@
from lightly.loss.ntx_ent_loss import NTXentLoss
from lightly.loss.swav_loss import SwaVLoss
from lightly.loss.sym_neg_cos_sim_loss import SymNegCosineSimilarityLoss
from lightly.loss.directclr_loss import InfoNCELoss
46 changes: 46 additions & 0 deletions lightly/loss/directclr_loss.py
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn

# Adapted from https://github.com/facebookresearch/directclr/blob/main/directclr/main.py
class InfoNCELoss(nn.Module):
    """Implementation of the InfoNCE loss as used in DirectCLR."""
Contributor: You can leave the reference to DirectCLR away here.

    def __init__(self, dim: int, temperature: float = 0.1):
        """
        Args:
            dim: Dimension of the subvector used to compute the InfoNCE loss.
            temperature: The value used to scale the logits.
        """
        super().__init__()
        self.temperature = temperature
Contributor: typo, it's temperature 🙂

Contributor (Author): Sorry :(

        # Dimension of the subvector passed to the InfoNCE loss.
        self.dim = dim

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
Contributor: I would raise the question if it's necessary to put this in its own function.

Contributor (Author): Well, technically not, but I would avoid writing torch.nn.functional.normalize(x, dim=1) again and again :)

"""Function to normalize the tensor
Args:
x : The torch tensor to be normalized.
Atharva-Phatak marked this conversation as resolved.
Show resolved Hide resolved
"""
return nn.functional.normalize(x, dim = 1)

    def compute_loss(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        """Computes the InfoNCE loss between the two sets of representations.

        Args:
            z1, z2: The representations from the encoder.
        """
        z1 = self.normalize(z1)
        z2 = self.normalize(z2)
        # The reference implementation gathers representations across devices (DDP)
        # before computing the similarity matrix.
        logits = z1 @ z2.T
        logits = logits / self.temperature
        # Positive pairs sit on the diagonal, so sample i has target class i.
        labels = torch.arange(z2.shape[0], device=logits.device)
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return loss

    def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        """Forward pass for the InfoNCE computation."""
        # DirectCLR applies the loss only to the first `dim` dimensions of each representation.
        z1 = z1[:, :self.dim]
        z2 = z2[:, :self.dim]
        loss = self.compute_loss(z1, z2) + self.compute_loss(z2, z1)
        return loss / 2

__all__ = ["InfoNCELoss"]