
Added directclr loss #963

Open · wants to merge 18 commits into master

Conversation

Atharva-Phatak (Contributor):

Added an implementation of the DirectCLR loss as proposed in #781.

@philippmwirth I need some help with writing tests. If you could guide me, that would be amazing.

codecov bot commented Oct 20, 2022:

Codecov Report

Base: 88.95% // Head: 88.46% // Decreases project coverage by 0.50% ⚠️

Coverage data is based on head (05dab52) compared to base (54cb38a).
Patch coverage: 41.66% of modified lines in pull request are covered.

❗ Current head 05dab52 differs from pull request most recent head 20b71f2. Consider uploading reports for the commit 20b71f2 to get more accurate results

Additional details and impacted files
```diff
@@            Coverage Diff             @@
##           master     #963      +/-   ##
==========================================
- Coverage   88.95%   88.46%   -0.50%
==========================================
  Files         108       96      -12
  Lines        5023     4567     -456
==========================================
- Hits         4468     4040     -428
+ Misses        555      527      -28
```
| Impacted Files | Coverage Δ |
|---|---|
| lightly/loss/directclr_loss.py | `39.13% <39.13%> (ø)` |
| lightly/loss/__init__.py | `100.00% <100.00%> (ø)` |
| lightly/api/version_checking.py | `84.90% <0.00%> (-10.93%)` ⬇️ |
| lightly/api/api_workflow_download_dataset.py | `87.36% <0.00%> (-7.51%)` ⬇️ |
| lightly/models/utils.py | `77.98% <0.00%> (-2.02%)` ⬇️ |
| lightly/data/collate.py | `92.85% <0.00%> (-1.20%)` ⬇️ |
| lightly/api/utils.py | `93.33% <0.00%> (-0.86%)` ⬇️ |
| lightly/models/modules/heads.py | `83.89% <0.00%> (-0.84%)` ⬇️ |
| lightly/loss/swav_loss.py | `93.75% <0.00%> (-0.59%)` ⬇️ |
| lightly/models/modules/masked_autoencoder.py | `92.00% <0.00%> (-0.21%)` ⬇️ |
| ... and 32 more | |


philippmwirth (Contributor) left a comment:

Hey @Atharva-Phatak, thank you so much for implementing this! I have left some comments and a few commit suggestions, but the overall contribution already looks great.

Regarding the tests, I'd suggest the following steps:

  1. Create a new file at tests/loss/test_InfoNCELoss.py.
  2. Add a class TestInfoNCELoss(unittest.TestCase).
  3. Add functions of the form test_xyz(self) that verify the loss is computed correctly (e.g. for a certain input we expect a certain output).

You can take a look at the other loss tests for inspiration; a sketch follows below.
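
For example, a minimal sketch of such a test file (the import path and the constructor arguments are assumptions based on the code in this PR, not the final API):

```python
# tests/loss/test_InfoNCELoss.py -- a sketch; `InfoNCELoss(dim, temperature)`
# and the import path are assumed from this PR.
import unittest

import torch

from lightly.loss import InfoNCELoss


class TestInfoNCELoss(unittest.TestCase):
    def test_forward_returns_scalar(self):
        torch.manual_seed(0)
        criterion = InfoNCELoss(dim=32, temperature=0.1)
        z0 = torch.randn(8, 64)
        z1 = torch.randn(8, 64)
        loss = criterion(z0, z1)
        # The loss should be a single non-negative scalar.
        self.assertEqual(loss.ndim, 0)
        self.assertGreaterEqual(loss.item(), 0.0)

    def test_identical_views_give_lower_loss(self):
        torch.manual_seed(0)
        criterion = InfoNCELoss(dim=32, temperature=0.1)
        z = torch.randn(8, 64)
        # Matching a view against itself should be easier than against noise.
        loss_same = criterion(z, z.clone())
        loss_rand = criterion(z, torch.randn(8, 64))
        self.assertLess(loss_same.item(), loss_rand.item())
```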


```python
# Adapted from https://github.com/facebookresearch/directclr/blob/main/directclr/main.py
class InfoNCELoss(nn.Module):
    """Implementation of InfoNCELoss as required for DIRECTCLR"""
```
philippmwirth (Contributor):

You can drop the reference to DIRECTCLR here.

```python
    dim : Dimension of subvector to be used to compute InfoNCELoss.
    temprature: The value used to scale logits.
    """
    self.temprature = temprature
```
philippmwirth (Contributor):
typo, it's temperature 🙂

Atharva-Phatak (Contributor, Author):
Sorry :(

```python
        # dimension of subvector sent to infoNCE
        self.dim = dim

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
```
philippmwirth (Contributor):

I'd question whether it's necessary to put this in its own function.

Atharva-Phatak (Contributor, Author):

Well, technically not, but I'd rather avoid writing torch.nn.functional.normalize(x, dim=1) again and again :)
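
For context, the helper in question boils down to a one-line wrapper (a sketch, assuming the dim=1 convention from the reply above):

```python
import torch
import torch.nn.functional as F

# As a method on the loss module, the helper is just:
def normalize(self, x: torch.Tensor) -> torch.Tensor:
    # One place to change dim/eps if the normalization convention ever changes.
    return F.normalize(x, dim=1)
```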

Atharva-Phatak and others added 5 commits on October 21, 2022 (co-authored by Philipp Wirth).
Atharva-Phatak (Contributor, Author):

I will add tests ASAP :)

philippmwirth (Contributor):

> I will add tests ASAP :)

Hey @Atharva-Phatak do you need more help regarding the tests?

Atharva-Phatak (Contributor, Author):

@philippmwirth I am sorry for the delay; I am currently busy with my mid-semester exams. I have my last paper tomorrow, and then I will add the tests by the end of that day.
But if this PR is holding things up, feel free to add the tests yourself :)

philippmwirth (Contributor):

> @philippmwirth I am sorry for the delay; I am currently busy with my mid-semester exams. [...]

No, don't worry! Good luck with your mid-semesters 🙂

Atharva-Phatak (Contributor, Author):

@philippmwirth Can you help me out here? I cannot spot the mistake. Have I written the test incorrectly?

guarin (Contributor) commented Nov 23, 2022:

Hi @Atharva-Phatak, I believe you have to call super().__init__() in the __init__ method of the loss. Otherwise the nn.Module is not correctly initialized.

The error message indicates that some module-related attributes are missing:

```
self = <[AttributeError("'InfoNCELoss' object has no attribute '_modules'") raised in repr()] InfoNCELoss object at 0x7f4df2b735b0>
name = '_backward_hooks'

    def __getattr__(self, name: str) -> Union[Tensor, 'Module']:
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return _parameters[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return _buffers[name]
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return modules[name]
>       raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))
E       AttributeError: 'InfoNCELoss' object has no attribute '_backward_hooks'
```

Unfortunately you have to scroll quite far back up in the unit test output to find the actual stack trace.
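
Concretely, the fix looks like this (a sketch; the argument names are taken from the docstring earlier in this PR, and the default temperature is an assumption):

```python
import torch.nn as nn


class InfoNCELoss(nn.Module):
    def __init__(self, dim: int, temperature: float = 0.1):
        # Without this call, nn.Module never sets up _modules, _parameters,
        # _buffers, or _backward_hooks, which is exactly the AttributeError above.
        super().__init__()
        self.dim = dim
        self.temperature = temperature
```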

Atharva-Phatak (Contributor, Author):

@guarin I need help fixing this merge conflict.

guarin (Contributor) commented Nov 28, 2022:

Hi @Atharva-Phatak, sorry for the late reply but I fixed the merge conflict. There still seems to be an issue with the loss implementation though.

Atharva-Phatak (Contributor, Author):

> Hi @Atharva-Phatak, sorry for the late reply but I fixed the merge conflict. There still seems to be an issue with the loss implementation though.

Let me check the implementation once again.
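
For reference, a minimal single-process sketch of the DirectCLR-style InfoNCE loss (conceptually following facebookresearch/directclr, with the distributed gather omitted; the exact interface of the class in this PR may differ):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class InfoNCELoss(nn.Module):
    """InfoNCE applied to the first `dim` components of each representation,
    as in DirectCLR."""

    def __init__(self, dim: int, temperature: float = 0.1):
        super().__init__()
        self.dim = dim
        self.temperature = temperature

    def forward(self, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor:
        # DirectCLR: only a subvector of each representation enters the loss.
        z0 = F.normalize(z0[:, : self.dim], dim=1)
        z1 = F.normalize(z1[:, : self.dim], dim=1)
        # Pairwise cosine similarities, scaled by the temperature.
        logits = z0 @ z1.T / self.temperature
        # Positive pairs lie on the diagonal.
        labels = torch.arange(z0.shape[0], device=z0.device)
        # Symmetrize over both view orders.
        return 0.5 * (
            F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)
        )
```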
