Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add BasePredictionWriter 3/3 #7127

Merged
merged 59 commits into from
Apr 27, 2021
Merged

[feat] Add BasePredictionWriter 3/3 #7127

merged 59 commits into from
Apr 27, 2021

Conversation

tchaton
Copy link
Contributor

@tchaton tchaton commented Apr 20, 2021

What does this PR do?

Here is PR 1/2: #7141
Here is PR 2/3: #7215

This PR adds BasePredictionWriter.

        import torch
        import os
        from pytorch_lightning.callbacks import BasePredictionWriter

        class CustomWriter(BasePredictionWriter):

            def __init__(self, output_dir: str, write_interval: str):
                super().__init__(write_interval)
                self.output_dir

            def write_on_batch(
                self, trainer, pl_module: 'LightningModule', prediction: Any, batch_indices: List[int], batch: Any,
                batch_idx: int, dataloader_idx: int
            ):
                torch.save(prediction, os.path.join(self.output_dir, dataloader_idx, f"{batch_idx}.pt")

            def write_on_epoch(self, trainer, pl_module: 'LightningModule', predictions: List[Any], batch_indices: List[Any]):
                torch.save(predictions, os.path.join(self.output_dir, "predictions.pt")

Fixes #7113

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@tchaton tchaton self-assigned this Apr 20, 2021
@tchaton tchaton added this to the v1.3 milestone Apr 20, 2021
@pep8speaks
Copy link

pep8speaks commented Apr 20, 2021

Hello @tchaton! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-04-27 19:59:37 UTC

@tchaton tchaton requested review from awaelchli and carmocca April 20, 2021 18:37
@codecov
Copy link

codecov bot commented Apr 20, 2021

Codecov Report

Merging #7127 (b3f60b8) into master (c6d9f52) will decrease coverage by 4%.
The diff coverage is 98%.

@@           Coverage Diff           @@
##           master   #7127    +/-   ##
=======================================
- Coverage      91%     87%    -4%     
=======================================
  Files         198     199     +1     
  Lines       12728   12772    +44     
=======================================
- Hits        11614   11145   -469     
- Misses       1114    1627   +513     

pytorch_lightning/callbacks/predictions.py Outdated Show resolved Hide resolved
pytorch_lightning/callbacks/base.py Outdated Show resolved Hide resolved
pytorch_lightning/callbacks/predictions.py Outdated Show resolved Hide resolved
pytorch_lightning/callbacks/predictions.py Outdated Show resolved Hide resolved
pytorch_lightning/callbacks/predictions.py Outdated Show resolved Hide resolved
pytorch_lightning/callbacks/predictions.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/connectors/callback_connector.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/data_loading.py Outdated Show resolved Hide resolved
pytorch_lightning/overrides/distributed.py Outdated Show resolved Hide resolved
@tchaton tchaton marked this pull request as ready for review April 22, 2021 11:02
@tchaton tchaton changed the title [feat] Add BasePredictionWriter [feat] Add BasePredictionWriter 2/2 Apr 22, 2021
@carmocca carmocca added feature Is an improvement or enhancement callback labels Apr 27, 2021
Copy link
Member

@ethanwharris ethanwharris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM 😃 small changes

pytorch_lightning/callbacks/prediction_writer.py Outdated Show resolved Hide resolved
pytorch_lightning/callbacks/prediction_writer.py Outdated Show resolved Hide resolved
Copy link
Contributor

@ananthsub ananthsub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm!

@Borda Borda requested a review from SeanNaren April 27, 2021 17:56
Args:
write_interval: When to write.

.. testcode::
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does not have any effect here in py files
cc: @awaelchli

pytorch_lightning/trainer/predict_loop.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/properties.py Outdated Show resolved Hide resolved
@tchaton tchaton enabled auto-merge (squash) April 27, 2021 18:11
Copy link
Contributor

@awaelchli awaelchli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tchaton the pr shows many coverage warnings

Comment on lines +1552 to +1554
def predict(
tmpdir, accelerator, gpus, num_processes, model=None, plugins=None, datamodule=True, pbrr=None, use_callbacks=True
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to be careful that we don't overload these test helper functions with complexity. If the test functions are too complex we would need tests for the tests and so this goes in circles xD

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also rather keep them protected so noone would import them...

@mergify mergify bot removed the has conflicts label Apr 27, 2021
@tchaton tchaton merged commit e76ebd6 into master Apr 27, 2021
@tchaton tchaton deleted the predict_loop_1 branch April 27, 2021 20:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working callback feature Is an improvement or enhancement
Projects
None yet
Development

Successfully merging this pull request may close these issues.

PredictLoop: Missing hooks + results aren't being returned properly in spawn.
9 participants