Merge pull request #2265 from f4str/pt-object-detector-param
Update `channels_first` parameter and docstring for PyTorch Object Detectors
beat-buesser authored Sep 12, 2023
2 parents ad0a546 + 98016c6 commit e84835f
Showing 3 changed files with 17 additions and 17 deletions.
18 changes: 9 additions & 9 deletions art/estimators/object_detection/pytorch_faster_rcnn.py
@@ -38,17 +38,17 @@
 
 class PyTorchFasterRCNN(PyTorchObjectDetector):
     """
-    This class implements a model-specific object detector using Faster-RCNN and PyTorch following the input and output
+    This class implements a model-specific object detector using Faster R-CNN and PyTorch following the input and output
     formats of torchvision.
     """
 
     def __init__(
         self,
-        model: Optional["torchvision.models.detection.fasterrcnn_resnet50_fpn"] = None,
+        model: Optional["torchvision.models.detection.FasterRCNN"] = None,
         input_shape: Tuple[int, ...] = (-1, -1, -1),
         optimizer: Optional["torch.optim.Optimizer"] = None,
         clip_values: Optional["CLIP_VALUES_TYPE"] = None,
-        channels_first: Optional[bool] = False,
+        channels_first: Optional[bool] = True,
         preprocessing_defences: Union["Preprocessor", List["Preprocessor"], None] = None,
         postprocessing_defences: Union["Postprocessor", List["Postprocessor"], None] = None,
         preprocessing: "PREPROCESSING_TYPE" = None,
@@ -63,13 +63,13 @@ def __init__(
         """
         Initialization.
 
-        :param model: Faster-RCNN model. The output of the model is `List[Dict[Tensor]]`, one for each input image. The
-                      fields of the Dict are as follows:
+        :param model: Faster R-CNN model. The output of the model is `List[Dict[str, torch.Tensor]]`, one for
+                      each input image. The fields of the Dict are as follows:
 
-                      - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values \
-                        between 0 and H and 0 and W
-                      - labels (Int64Tensor[N]): the predicted labels for each image
-                      - scores (Tensor[N]): the scores or each prediction
+                      - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and
+                        0 <= y1 < y2 <= H.
+                      - labels [N]: the labels for each image.
+                      - scores [N]: the scores of each prediction.
         :param input_shape: The shape of one input sample.
         :param optimizer: The optimizer for training the classifier.
         :param clip_values: Tuple of the form `(min, max)` of floats or `np.ndarray` representing the minimum and
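For context, a minimal sketch of how the updated signature is exercised; the model choice, clip values, and input size below are illustrative assumptions, not taken from this diff:

```python
import numpy as np
import torchvision

from art.estimators.object_detection import PyTorchFasterRCNN

# Any torchvision FasterRCNN instance now satisfies the widened type hint,
# not only fasterrcnn_resnet50_fpn.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

detector = PyTorchFasterRCNN(
    model=model,
    clip_values=(0, 255),
    channels_first=True,  # new default; shown explicitly for emphasis
)

# channels_first=True means inputs are expected in NCHW layout.
x = np.zeros((2, 3, 416, 416), dtype=np.float32)
predictions = detector.predict(x)  # one Dict of boxes/labels/scores per image
```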
14 changes: 7 additions & 7 deletions art/estimators/object_detection/pytorch_object_detector.py
@@ -52,7 +52,7 @@ def __init__(
         input_shape: Tuple[int, ...] = (-1, -1, -1),
         optimizer: Optional["torch.optim.Optimizer"] = None,
         clip_values: Optional["CLIP_VALUES_TYPE"] = None,
-        channels_first: Optional[bool] = False,
+        channels_first: Optional[bool] = True,
         preprocessing_defences: Union["Preprocessor", List["Preprocessor"], None] = None,
         postprocessing_defences: Union["Postprocessor", List["Postprocessor"], None] = None,
         preprocessing: "PREPROCESSING_TYPE" = None,
@@ -67,13 +67,13 @@ def __init__(
         """
         Initialization.
 
-        :param model: Object detection model. The output of the model is `List[Dict[Tensor]]`, one for each input
-                      image. The fields of the Dict are as follows:
+        :param model: Object detection model. The output of the model is `List[Dict[str, torch.Tensor]]`, one for
+                      each input image. The fields of the Dict are as follows:
 
-                      - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values
-                        between 0 and H and 0 and W
-                      - labels (Int64Tensor[N]): the predicted labels for each image
-                      - scores (Tensor[N]): the scores or each prediction
+                      - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and
+                        0 <= y1 < y2 <= H.
+                      - labels [N]: the labels for each image.
+                      - scores [N]: the scores of each prediction.
         :param input_shape: The shape of one input sample.
         :param optimizer: The optimizer for training the classifier.
         :param clip_values: Tuple of the form `(min, max)` of floats or `np.ndarray` representing the minimum and
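To make the docstring's contract concrete, here is a small sketch of the prediction format for a single hypothetical 640x480 image (W=640, H=480) with two detections; all values are made up for illustration:

```python
import numpy as np

predictions = [
    {
        # [N, 4] boxes in [x1, y1, x2, y2] format
        "boxes": np.array(
            [[10.0, 20.0, 200.0, 300.0],
             [50.0, 60.0, 400.0, 450.0]], dtype=np.float32
        ),
        "labels": np.array([1, 18], dtype=np.int64),         # [N] class indices
        "scores": np.array([0.98, 0.76], dtype=np.float32),  # [N] confidences
    }
]

# The coordinate constraints stated in the updated docstring:
for x1, y1, x2, y2 in predictions[0]["boxes"]:
    assert 0 <= x1 < x2 <= 640 and 0 <= y1 < y2 <= 480
```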
2 changes: 1 addition & 1 deletion art/estimators/object_detection/pytorch_yolo.py
@@ -148,7 +148,7 @@ def __init__(
         """
         Initialization.
 
-        :param model: Object detection model wrapped as demonstrated in examples/get_started_yolo.py.
+        :param model: YOLO v3 or v5 model wrapped as demonstrated in examples/get_started_yolo.py.
                       The output of the model is `List[Dict[str, torch.Tensor]]`, one for each input image.
                       The fields of the Dict are as follows:
 
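The wrapping contract that examples/get_started_yolo.py demonstrates roughly amounts to the sketch below; the class name and the `compute_loss` callable are hypothetical stand-ins for whichever YOLO implementation is used, and are not part of this diff:

```python
import torch


class YoloWrapper(torch.nn.Module):
    """Hypothetical wrapper: loss dict in training mode, raw detections in eval mode."""

    def __init__(self, model, compute_loss):
        super().__init__()
        self.model = model
        self.compute_loss = compute_loss  # loss function of the underlying YOLO library

    def forward(self, x, targets=None):
        if self.training:
            outputs = self.model(x)
            loss, _ = self.compute_loss(outputs, targets, self.model)
            # PyTorchYolo reads the losses it needs from this dictionary.
            return {"loss_total": loss}
        # Eval mode: return the model's detections, which PyTorchYolo converts
        # into List[Dict[str, torch.Tensor]] with boxes, labels, and scores.
        return self.model(x)
```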
