-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PyTorch Object Detector Training Implementation and Prediction Update #2067
PyTorch Object Detector Training Implementation and Prediction Update #2067
Conversation
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Codecov Report
📣 This organization is not using Codecov’s GitHub App Integration. We recommend you install it so Codecov can continue to function properly for your repositories. Learn more @@ Coverage Diff @@
## dev_1.14.0 #2067 +/- ##
==============================================
+ Coverage 77.23% 80.70% +3.47%
==============================================
Files 294 294
Lines 26212 26322 +110
Branches 4797 4827 +30
==============================================
+ Hits 20244 21244 +1000
+ Misses 4822 3914 -908
- Partials 1146 1164 +18
|
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @f4str Thank you very much for your pull request. I have added a few review comments, what do you think?
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
# Apply postprocessing | ||
predictions = self._apply_postprocessing(preds=results_list, fit=False) | ||
|
||
return predictions # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Apply postprocessing | |
predictions = self._apply_postprocessing(preds=results_list, fit=False) | |
return predictions # type: ignore | |
return predictions |
# Apply postprocessing | ||
predictions = self._apply_postprocessing(preds=results_list, fit=False) | ||
|
||
return predictions # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Apply postprocessing | |
predictions = self._apply_postprocessing(preds=results_list, fit=False) | |
return predictions # type: ignore | |
return predictions |
# Apply postprocessing | ||
predictions = self._apply_postprocessing(preds=results_list, fit=False) | ||
|
||
return predictions # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Apply postprocessing | |
predictions = self._apply_postprocessing(preds=results_list, fit=False) | |
return predictions # type: ignore | |
return predictions |
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@f4str Looks good to me! Thank you very much.
Description
Implementation of the
fit
method for thePyTorchObjectDetector
,PyTorchFasterRCNN
, andPyTorchYolo
object detectors. This allows these models to be trained using ART. The notebook innotebooks/poisoning_attack_bad_det_rma.ipynb
has been updated to demonstrate training these models. This is a partial implementation of #2058 as thefit
method needs to also be implemented for the TensorFlow object detectors.The
predict
method in thePyTorchObjectDetector
andPyTorchYolo
classes has been rewritten to implement the following new features:torch.no_grad()
scope is used like thePyTorchClassifier
to prevent extraneous gradients accumulating in the model which causes slow down.channels_first
parameter is now used to determine whether the input needs to be transformed. The default behavior remains the same.Additionally, the unit tests for the
test_pytorch_object_detector.py
andtest_pytorch_faster_rcnn.py
were rewritten in pytest with a new test added for model training.Type of change
Please check all relevant options.
Testing
Please describe the tests that you ran to verify your changes. Consider listing any relevant details of your test configuration.
PyTorchObjectDetector
(rewritten in pytest)PyTorchFasterRCNN
(rewritten in pytest)PyTorchYolo
Test Configuration:
Checklist