Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RetinaNet object detection. #1697

Closed
wants to merge 31 commits into from
Closed

Conversation

hgaiser
Copy link
Contributor

@hgaiser hgaiser commented Dec 20, 2019

As briefly discussed in #1151 , this PR intends to serve as a discussion platform for the implementation of a RetinaNet network in torchvision.

The code is more like a skeleton implementation of what needs to be done, showcasing the design choices made so far. Before I start working on the todo's, I would like to discuss the current design in case it is not in line with torchvision.

The current list of todo's (also commented in the code) :

  • Implement focal loss (is it already somewhere in pytorch? Couldn't find it in the documentation page).
  • Use Smooth L1 loss for bbox regression, or use L1 like in Faster RCNN.
  • Move some functionality for anchor matching out of rpn.RegionProposalNetwork since we can share the code.
  • Implement functionality to decode bbox regression, similarly as with anchor matching, the goal is to share as much code as possible with rpn.RegionProposalNetwork.
  • Train resnet50_fpn on COCO.
  • Make sure it works with torchscript.
  • Test with a custom additional head.

Some design choices that might be worth discussing:

  • I decided not to inherit from GeneralizedRCNN for two reasons: it is a trivial implementation and it doesn't match with how RetinaNet works.
  • I put the compute_loss methods in the heads (RetinaNetClassificationHead / RetinaNetRegressionHead) as they are tightly correlated.
  • I made a single nn.Module to represent the RetinaNet heads, so that you should be able to add different heads by making a (sub)class like RetinaNetHead. This can be useful if you want to train other things than just classification and bbox regression. I think there is some more work required to allow a variable number of heads (with filtering the detections mainly), but I don't want to worry about that for now :). Since this was an easy thing to do I already implemented the head with that concept in mind.
  • I left num_classes to include the background, however the RetinaNet paper says they predict without a background class (so using sigmoid instead of softmax). This shouldn't be an issue I suppose, but it is worth noting. I left it like this because it is in line with the other implementations in torchvision. Personally I prefer to classify without a background class and using sigmoid mainly because it allows you to do multiclass classification, which softmax does not.
  • Currently rpn.RegionProposalNetwork is not usable in RetinaNet, and I think we shouldn't modify it to fit the use-case of RetinaNet either, but it does share a lot of the required functionality. I am thinking about how I can take some of the functionality out of rpn.RegionProposalNetwork, place it somewhere else so that both rpn.RegionProposalNetwork and RetinaNet can make use of it. The functionality is mainly the matching of predictions to ground truth and the decoding of regression values and anchors to bboxes.

@fmassa I would love to hear your opinion.

@hgaiser
Copy link
Contributor Author

hgaiser commented Jan 4, 2020

@fmassa any feedback? I'm planning to continue work on this soon.

@codecov-io
Copy link

codecov-io commented Jan 4, 2020

Codecov Report

Merging #1697 into master will decrease coverage by 0.2%.
The diff coverage is 0%.

Impacted file tree graph

@@            Coverage Diff            @@
##           master   #1697      +/-   ##
=========================================
- Coverage    0.48%   0.28%   -0.21%     
=========================================
  Files          92      93       +1     
  Lines        7411    7464      +53     
  Branches     1128    1017     -111     
=========================================
- Hits           36      21      -15     
- Misses       7362    7435      +73     
+ Partials       13       8       -5
Impacted Files Coverage Δ
torchvision/models/detection/retinanet.py 0% <0%> (ø)
torchvision/io/video.py 1.15% <0%> (-18.38%) ⬇️
torchvision/models/quantization/googlenet.py 0% <0%> (ø) ⬆️
torchvision/models/densenet.py 0% <0%> (ø) ⬆️
torchvision/ops/boxes.py 0% <0%> (ø) ⬆️
torchvision/models/detection/image_list.py 0% <0%> (ø) ⬆️
torchvision/ops/ps_roi_align.py 0% <0%> (ø) ⬆️
torchvision/datasets/folder.py 0% <0%> (ø) ⬆️
torchvision/ops/roi_align.py 0% <0%> (ø) ⬆️
torchvision/ops/_register_onnx_ops.py 0% <0%> (ø) ⬆️
... and 39 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update e61538c...ebb741a. Read the comment docs.

@depthwise
Copy link

TBH I'm not entirely sure why not just port RetinaNet from FB's own Detectron2 (assuming TorchVision team is interested in something like that) - seems like it'd be much less work. I've been working with Detectron2's implementation for the past month or so, it's pretty good.

Torchvision and Detectron2 already share some of the code and where they do not, they at least share the conventions. I'm not a huge fan of Detectron2 leaking configuration details quite deep into the framework: I prefer to unwrap those early on, so that things are more composable.

Other than that, adding RetinaNet would be a largely clean, incremental change.

@hgaiser
Copy link
Contributor Author

hgaiser commented Jan 5, 2020

I haven't looked at the implementation of Detectron2, I will have a look when I have time.

@hgaiser
Copy link
Contributor Author

hgaiser commented Jan 10, 2020

TBH I'm not entirely sure why not just port RetinaNet from FB's own Detectron2 (assuming TorchVision team is interested in something like that) - seems like it'd be much less work. I've been working with Detectron2's implementation for the past month or so, it's pretty good.

My comments on the Detectron2 version:

  • There seems to be a tight integration with other parts of Detectron and their configuration file. It will be some work to decouple that correctly, but I think you already noticed that too.
  • Their RetinaNetHead is difficult to extend because of the choices in the design. My goal is to have a RetinaNetHead you can easily extend so that you can add other heads. This can be useful if you want to compute other things for each object (small example: the age of a person in a person detector).
  • Some parts, like the image preprocessing and ground truth matching, is already handled in torchvision in a different way. It doesn't make sense to port their methods when similar methods already exist.
  • I noticed they stick to the design of the original paper and use a sigmoid loss. Maybe we should do it too in torchvision, even though it conflicts with the classification losses in the other implementations, which all use softmax loss.

My overall conclusion is that it is a good example, but it would be more work to port it as closely as possible to torchvision, than it would be to implement retinanet in torchvision using the implementation of Detectron2 as inspiration.

Would this work for you @depthwise ?

@depthwise
Copy link

depthwise commented Jan 10, 2020

I've actually decoupled the config and made it more modular, but I've done it for a customer, so I can't share my work. And even if I could, it probably wouldn't be useful, since my version is now fixed-resolution and it's heavily customized for the task at hand.

Other changes I've implemented there:

  • Moved the regression encoder/decoder into a separate component
  • Moved the postprocessor (box decoder + nms) into a separate component
  • Rewrote FPN in a much simpler way that makes it easier to plug in other backbones, and makes it easier to replace it with BiFPN at some point
  • Rewrote trainer and eval to be able to easily obtain validation losses during the training run and to only save "best" checkpoints for each metric. I'm not sure how FB folks train without monitoring validation losses TBH. Detectron2's trainer makes this very difficult - validation losses are not computed during eval.

When it comes down to it, it was quite a bit of work, but at least when I put everything back together I didn't have to debug much.

The "neural" part of this model is pretty easy to do. It's the rest of the stuff that's a lot of work (ground truth assignment and eval mostly).

I think whatever you end up doing, you should probably aim for mostly the same design imperatives:

  • Make it easy to plug in other backbones (including "efficient" ones)
  • Make FPN implementation replaceable
  • Implement eval code which surfaces validation losses

But ultimately, project owners need to decide - I'm just a contributor.

@hgaiser hgaiser force-pushed the retinanet branch 2 times, most recently from 8303ff9 to daadf85 Compare January 10, 2020 12:38
@hgaiser
Copy link
Contributor Author

hgaiser commented Jan 10, 2020

I've actually decoupled the config and made it more modular, but I've done it for a customer, so I can't share my work. And even if I could, it probably wouldn't be useful, since my version is now fixed-resolution and it's heavily customized for the task at hand.

I understand, thank you for sharing your insights!

* Moved the regression encoder/decoder into a separate component
* Moved the postprocessor (box decoder + nms) into a separate component

Yes sounds like a good idea.

Rewrote FPN in a much simpler way that makes it easier to plug in other backbones, and makes it easier to replace it with BiFPN at some point

I like that idea, but maybe we shouldn't focus on that in this PR.

* Rewrote trainer and eval to be able to easily obtain validation losses during the training run and to only save "best" checkpoints for each metric. I'm not sure how FB folks train without monitoring validation losses TBH. Detectron2's trainer makes this very difficult - validation losses are not computed during eval.

It makes sense not to compute the validation loss during eval, since you usually have no need for it at that time. It could be useful to add a flag to compute this optionally though. Is there a need for this @fmassa ?

I think whatever you end up doing, you should probably aim for mostly the same design imperatives:

* Make it easy to plug in other backbones (including "efficient" ones)

* Make FPN implementation replaceable

* Implement eval code which surfaces validation losses

I 100% agree on the first two items, I intend to rely on the existing code for that modularity. If that is insufficiently modular then I propose to improve the modularity in future PRs. Regarding the third item, I would ask @fmassa to provide feedback on how the torchvision team views this. I have no problem adding it, but I would definitely make it optional (and disabled by default).

@depthwise
Copy link

RE validation losses, consider this: you're already doing the entire computation needed to produce them, with the exception of ground truth encoding (which is not needed for inference per se). Loss computation is a small, incremental cost in the grand scheme of things. And with any real-world dataset, the loss imbalance or divergence often gives valuable insight into how to tune the network or hyperparameters. I have no problem with making it optional, of course, so as long as there is an option to easily turn them on.

@fmassa
Copy link
Member

fmassa commented Jan 14, 2020

Hi,

Thanks for opening the PR to start this discussion!

Sorry for the delay in replying, I was on holidays for some and then had a number of things to address once I came back from holidays.

About @depthwise comments

Thanks for the feedback! Definitely very valuable!

There are a number of points in this thread that deserves attention, but I would first decouple it from this PR. A few points that have been mentioned:

  • Make it easy to plug in other backbones (including "efficient" ones)
  • Make it easy to plug in other backbones (including "efficient" ones)
  • Implement eval code which surfaces validation losses

Those points are independent on RetinaNet, and as such I would encourage discussing it in a separate issue(s).

About @hgaiser points

I decided not to inherit from GeneralizedRCNN for two reasons: it is a trivial implementation and it doesn't match with how RetinaNet works.

I'm ok with that. In maskrcnn-benchmark we did use the same GeneralizedRCNN, but let's avoid adding extra complexity here. The goal is to have as simple to understand implementations as possible.

I put the compute_loss methods in the heads (RetinaNetClassificationHead / RetinaNetRegressionHead) as they are tightly correlated.

At first I think I would have kept them as part of the RetinaNetHead instead. This way, the user can more easily replace the RetinaNetClassificationHead with something else.

But then, this makes it harder to add extra losses for the classification or regression, so in this sense I think that your proposal makes a lot of sense.

Let's just double-check that it wouldn't be too complicated for an user to pass their own RetinaNetClassificationHead implementation and that they will be able to get the loss inherited from the RetinaNetClassificationHead. I'd love that it would just be a matter of inheriting from RetinaNetClassificationHead and customizing the modules / forward in init (without carrying the default layers from RetinaNetClassificationHead).

I made a single nn.Module to represent the RetinaNet heads, so that you should be able to add different heads by making a (sub)class like RetinaNetHead. This can be useful if you want to train other things than just classification and bbox regression. I think there is some more work required to allow a variable number of heads (with filtering the detections mainly), but I don't want to worry about that for now :). Since this was an easy thing to do I already implemented the head with that concept in mind.

This sounds good to me

I left num_classes to include the background, however the RetinaNet paper says they predict without a background class (so using sigmoid instead of softmax). This shouldn't be an issue I suppose, but it is worth noting. I left it like this because it is in line with the other implementations in torchvision. Personally I prefer to classify without a background class and using sigmoid mainly because it allows you to do multiclass classification, which softmax does not.

That's a good question. I would have liked that the models in torchvision to return its elements in the same order, irrespective of the model implementation. So including the background is what I would do as well. We can always in the details of the implementation remove the class corresponding to the background, in order to apply sigmoid, so this shouldn't be a problem I think.

Currently rpn.RegionProposalNetwork is not usable in RetinaNet, and I think we shouldn't modify it to fit the use-case of RetinaNet either, but it does share a lot of the required functionality. I am thinking about how I can take some of the functionality out of rpn.RegionProposalNetwork, place it somewhere else so that both rpn.RegionProposalNetwork and RetinaNet can make use of it. The functionality is mainly the matching of predictions to ground truth and the decoding of regression values and anchors to bboxes.

Great point. In maskrcnn-benchmark we tried to share as much as possible between both implementations, but it ended up adding a few weird abstractions. In its current state, I think that it might be ok to just copy-paste the implementations from RegionProposalNetwork. Let's start by copy-pasting the necessary functions (shouldn't be too long), and if we see nice potential refactorings then we can do it.

About @hgaiser TODO list:

Implement focal loss (is it already somewhere in pytorch? Couldn't find it in the documentation page).

It's not, but let's maybe just take the Python implementation from maskrcnn-benchmark and use it instead, no need to port the CUDA kernels, as torchscript can optimize it fairly well already.

Use Smooth L1 loss for bbox regression, or use L1 like in Faster RCNN.

Hum, that should be tested. I would love to just use l1 loss for simplicity, if it doesn't decrease results significantly

Move some functionality for anchor matching out of rpn.RegionProposalNetwork since we can share the code.

Hum, anchor matching is pretty much 3 lines of code now

gt_boxes = targets_per_image["boxes"]
match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image)
matched_idxs = self.proposal_matcher(match_quality_matrix)
# get the targets corresponding GT for each proposal
# NB: need to clamp the indices because we can have a single
# GT in the image, and matched_idxs can be -2, which goes
# out of bounds
matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]
, so I'm not sure if sharing the code even further would be worth the added complexity. But I would love to see a proposal!

Implement functionality to decode bbox regression, similarly as with anchor matching, the goal is to share as much code as possible with rpn.RegionProposalNetwork.

I would maybe start with a baseline implementation which has some copy-paste, but which is simple to understand, currently the box decoding / post-processing is 30 lines of python. But again, would love to see a proposal!

Train resnet50_fpn on COCO.

Yes! Let us know if you would need help scaling up some training.

@hgaiser
Copy link
Contributor Author

hgaiser commented Jan 14, 2020

Thank you for your detailed feedback @fmassa ! Some update from my side, I started trying to create an instance of a RetinaNet module and testing some things like anchor generation. As far as I can tell the anchors are correctly generated but I want to spend some more time on that to verify they are created as I expect them to be created. Note that this is pretty much unrelated to RetinaNet, but more for my own understanding of the anchor generation code in torchvision :).

Those points are independent on RetinaNet, and as such I would encourage discussing it in a separate issue(s).

Agreed.

I'm ok with that. In maskrcnn-benchmark we did use the same GeneralizedRCNN, but let's avoid adding extra complexity here. The goal is to have as simple to understand implementations as possible.

👍

At first I think I would have kept them as part of the RetinaNetHead instead. This way, the user can more easily replace the RetinaNetClassificationHead with something else.

But then, this makes it harder to add extra losses for the classification or regression, so in this sense I think that your proposal makes a lot of sense.

Let's just double-check that it wouldn't be too complicated for an user to pass their own RetinaNetClassificationHead implementation and that they will be able to get the loss inherited from the RetinaNetClassificationHead. I'd love that it would just be a matter of inheriting from RetinaNetClassificationHead and customizing the modules / forward in init (without carrying the default layers from RetinaNetClassificationHead).

👍

That's a good question. I would have liked that the models in torchvision to return its elements in the same order, irrespective of the model implementation. So including the background is what I would do as well.

Yes, I also see the value in keeping the same interface across the different implementations, but it will deviate a bit from the implementation from the focal loss paper. I understand your preference is to keep it as a softmax loss for now?

We can always in the details of the implementation remove the class corresponding to the background, in order to apply sigmoid, so this shouldn't be a problem I think.

I'm not sure I follow. Do you mean that later, if we would decide to use sigmoid instead, it is easy to refactor? Or do you mean some workaround where we compute sigmoid losses but output with a background class as if it comes from softmax?

Great point. In maskrcnn-benchmark we tried to share as much as possible between both implementations, but it ended up adding a few weird abstractions. In its current state, I think that it might be ok to just copy-paste the implementations from RegionProposalNetwork. Let's start by copy-pasting the necessary functions (shouldn't be too long), and if we see nice potential refactorings then we can do it.

Alright, I like that approach.

It's not, but let's maybe just take the Python implementation from maskrcnn-benchmark and use it instead, no need to port the CUDA kernels, as torchscript can optimize it fairly well already.

Agreed, shouldn't be that computationally intensive anyway.

Hum, that should be tested. I would love to just use l1 loss for simplicity, if it doesn't decrease results significantly

Yes that would be nice. For now I propose to use whatever I find, so probably L1, but later on we should compare it to a smooth L1 implementation.

so I'm not sure if sharing the code even further would be worth the added complexity. But I would love to see a proposal!

I hadn't really gotten to the details of this part, it seemed more complex but I still have to look at that part more closely. I will work on this soon (probably in the next 2/3 weeks) and I will post my findings here. If it is as you say it is, then simply copy pasting seems fine.

I would maybe start with a baseline implementation which has some copy-paste, but which is simple to understand, currently the box decoding / post-processing is 30 lines of python. But again, would love to see a proposal!

Same as above.

Yes! Let us know if you would need help scaling up some training.

Yes that would probably be nice when the time comes. Currently I'm limited to a single 1080Ti for testing :)

@fmassa
Copy link
Member

fmassa commented Jan 14, 2020

Some update from my side, I started trying to create an instance of a RetinaNet module and testing some things like anchor generation. As far as I can tell the anchors are correctly generated but I want to spend some more time on that to verify they are created as I expect them to be created. Note that this is pretty much unrelated to RetinaNet, but more for my own understanding of the anchor generation code in torchvision :).

Sounds good!

I'm not sure I follow. Do you mean that later, if we would decide to use sigmoid instead, it is easy to refactor? Or do you mean some workaround where we compute sigmoid losses but output with a background class as if it comes from softmax?

Well, I believe we can compute the sigmoid losses using only the foreground classes, and, if needed, append an extra row of all-zeros so that the output of the model follows the same convention as the others? Does this make sense?

@hgaiser
Copy link
Contributor Author

hgaiser commented Jan 14, 2020

Well, I believe we can compute the sigmoid losses using only the foreground classes, and, if needed, append an extra row of all-zeros so that the output of the model follows the same convention as the others? Does this make sense?

Ah I see what you mean now. Hmm my gut feeling says it's a bad idea. The background class wouldn't be accurate anymore, it would always be zeros. If someone uses logic based on the value of the background score then this would break. Also, a background class to me indicates that softmax was used, so all scores should sum to 1, which wouldn't be the case anymore. This isn't necessarily dangerous, but may be confusing. I would choose either sigmoid and accept that it has a different output or use softmax and accept that it deviates from the original implementation.

@fmassa
Copy link
Member

fmassa commented Jan 14, 2020

Hum, after a second thought, the output of the model (during inference) is always the post-processed predictions, so as long as the class indices are correct, appending a background class or not shouldn't matter -- it will all be handled by the post-processing method of the RetinaNet class, right?

So from an external user-perspective, we can keep the same interface by using the sigmoid.
Or am I missing something?

@hgaiser
Copy link
Contributor Author

hgaiser commented Jan 14, 2020

Hum, after a second thought, the output of the model (during inference) is always the post-processed predictions, so as long as the class indices are correct, appending a background class or not shouldn't matter -- it will all be handled by the post-processing method of the RetinaNet class, right?

So from an external user-perspective, we can keep the same interface by using the sigmoid.
Or am I missing something?

What is the output of the post processed predictions? I assume it is three tensors for the bboxes, labels, scores? If that is the case then I agree, from an external user perspective the interface would be identical.

Then I agree with what you're suggesting: perform a sigmoid loss and before post processing append a zeros vector for the background so that the indices match with the indices from the other models. This is a bit of a workaround, but to the user it is the same as the other models and internally it follows the design of the paper. Not ideal, but it is a good compromise.

@fmassa
Copy link
Member

fmassa commented Jan 14, 2020

We can also just add a + 1 to the labels predicted by the model, in order to account for the missing background class -- this seems a bit cleaner to me.

@hgaiser
Copy link
Contributor Author

hgaiser commented Jan 14, 2020

We can also just add a + 1 to the labels predicted by the model, in order to account for the missing background class -- this seems a bit cleaner to me.

Haha yeah agreed. Depends a bit if the postprocessing does something with the background class, but otherwise that is cleaner. I will keep this in mind when looking at this part of the module.

@fmassa
Copy link
Member

fmassa commented Jan 29, 2020

@hgaiser let me know when you want me to have a closer look at your PR

Also, if you could rebase your PR on top of current master, we have fixed a number of issues with CI

@hgaiser
Copy link
Contributor Author

hgaiser commented Jan 29, 2020

Yes will do @fmassa , thank you. Progress has been a bit slow because I only work on this task one day a week. I have pushed some changes last week, but nothing we didn't discuss already. Next steps will be to work out the loss function and to work towards training a first model. I will update this PR when I get to that stage.

@hgaiser
Copy link
Contributor Author

hgaiser commented Feb 1, 2020

I have added the loss functions and worked out some other issues. The classification loss is frequently giving NaNs and I can't exactly figure out yet where, but I haven't looked at it enough, so I wouldn't say I'm stuck yet, but that's where I'm at.

The regression seems to be working though; after some training on a simple image (black background with a single white square in the middle) I'm seeing outputs like this (the predicted and target regression values for a single anchor) :

predicted
 tensor([ 0.0517,  0.3787, -0.1293,  0.5424], device='cuda:0',
       grad_fn=<SliceBackward>)
target
 tensor([ 0.0526,  0.3793, -0.1310,  0.5447], device='cuda:0')

Next step would be to fix the classification loss and to add the translation from network outputs to detections. If that's done I can start a training process on a simple dataset and if that works out well, start it on COCO.

@fmassa if you could, would you like to review the current state?

@hgaiser
Copy link
Contributor Author

hgaiser commented Feb 2, 2020

Yay, the first trained object from retinanet in torchvision (the white square was the object, the blue outline the detected bounding box) :

Image_screenshot_02 02 2020

It was already working before, I just had to reduce the loss to avoid exploding gradients.

Next steps: implement post processing so that it actually outputs proper detections with NMS and everything (this visualization was hacked together using the raw output) and then start training on a more realistic dataset, eventually training on COCO.

@hgaiser hgaiser force-pushed the retinanet branch 2 times, most recently from b41f88a to f787a9d Compare February 7, 2020 15:31
@hgaiser
Copy link
Contributor Author

hgaiser commented Feb 14, 2020

The latest commit seems to fix an issue, it is working on the PennFudan dataset:

Screenshot from 2020-02-14 17-27-00

@fmassa could you review the PR to see if there is anything that needs to change? I haven't yet looked into torchscript or anything like that yet, nor have I started training on COCO.

@hgaiser hgaiser requested a review from fmassa February 14, 2020 16:29
@hgaiser
Copy link
Contributor Author

hgaiser commented Mar 12, 2020

@fmassa could I kindly request a review on this PR?

@fmassa
Copy link
Member

fmassa commented Mar 12, 2020

Hi @hgaiser ,

Very sorry for the delay in reviewing it, I was working towards an ECCV submission (which we submitted last week).

I'm separating some time tomorrow to review this PR.

@fmassa
Copy link
Member

fmassa commented Oct 13, 2020

#2784 has been merged, thanks a lot for all your work @hgaiser !

@fmassa fmassa closed this Oct 13, 2020
@bw4sz
Copy link
Contributor

bw4sz commented Jan 5, 2021

@hgaiser my project depended on your keras-retinanet repo. I'm thinking about following the change to pytorch, since we have a number of research goals related to the repo (https://besjournals.onlinelibrary.wiley.com/doi/full/10.1111/2041-210X.13472).

weecology/DeepForest-pytorch#1

Let me know how I can help with the retinanet testing and usage. It will help me make the transition. Seems like the way the community is moving. Not clear to me where to make a 'help wanted?' issue given that its wrapped into this larger workflow. We will be attacking things like 1) multi-temporal object detection with weak attention. 2) Recreating the semi-supervised workflow from https://www.biorxiv.org/content/10.1101/2020.09.08.287839v1.abstract. 3) filtering anchor proposals based on physical constraints of our target objects (trees in airborne imagery: http://visualize.idtrees.org/).

@hgaiser
Copy link
Contributor Author

hgaiser commented Jan 5, 2021

Hey @bw4sz , that's great! I think most of the features have been merged successfully. One feature that I would like to add is the ability to extend the network with additional heads for custom predictions beyond just classifications. If you need this and are up for the task, I'm sure @fmassa is interested in a PR.

@fmassa
Copy link
Member

fmassa commented Jan 6, 2021

Also, one thing to keep in mind is that we might be refactoring a bit the current way the heads are implemented, so that it's easier to add new head architectures while keeping the same loss.

@bw4sz
Copy link
Contributor

bw4sz commented Jan 6, 2021 via email

@fmassa
Copy link
Member

fmassa commented Jan 11, 2021

@bw4sz yes, all you need is to feed an image normalized between 0-1 and the network will do the rest (scaling + mean normalization).

Also, you can copy-paste code in github by pasting the permalink of the lines you selected (select the lines and press "p" for the permalink, see attached picture for an example)
image

@ccl-private
Copy link

torchvision.models.detection.retinanet_resnet50_fpn inference slower than https://github.com/yhenon/pytorch-retinanet

@ccl-private
Copy link

why ?

@fmassa
Copy link
Member

fmassa commented Aug 12, 2021

@ccl-private we have optimized retinanet in #2828, which version of torchvision were you using? If you are still facing this issue, can you please open a new issue?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants