diff --git a/mpa/modules/models/heads/custom_cls_head.py b/mpa/modules/models/heads/custom_cls_head.py index c805d4ce..e6e52324 100644 --- a/mpa/modules/models/heads/custom_cls_head.py +++ b/mpa/modules/models/heads/custom_cls_head.py @@ -2,6 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 # +import torch +import torch.nn.functional as F + from mmcls.models.builder import HEADS from mmcls.models.heads import LinearClsHead from .non_linear_cls_head import NonLinearClsHead @@ -83,6 +86,17 @@ def loss(self, cls_score, gt_label, feature=None): losses['loss'] = loss return losses + def simple_test(self, img): + """Test without augmentation.""" + cls_score = self.fc(img) + if isinstance(cls_score, list): + cls_score = sum(cls_score) / float(len(cls_score)) + if torch.onnx.is_in_onnx_export(): + return cls_score + pred = F.softmax(cls_score, dim=1) if cls_score is not None else None + + return self.post_process(pred) + def forward_train(self, x, gt_label): cls_score = self.fc(x) losses = self.loss(cls_score, gt_label, feature=x) diff --git a/mpa/modules/models/heads/custom_hierarchical_linear_cls_head.py b/mpa/modules/models/heads/custom_hierarchical_linear_cls_head.py index 0a59bdd5..9b43738b 100644 --- a/mpa/modules/models/heads/custom_hierarchical_linear_cls_head.py +++ b/mpa/modules/models/heads/custom_hierarchical_linear_cls_head.py @@ -115,13 +115,17 @@ def simple_test(self, img): for i in range(self.hierarchical_info['num_multiclass_heads']): multiclass_logit = cls_score[:, self.hierarchical_info['head_idx_to_logits_range'][i][0]: self.hierarchical_info['head_idx_to_logits_range'][i][1]] - multiclass_logit = torch.softmax(multiclass_logit, dim=1) + if not torch.onnx.is_in_onnx_export(): + multiclass_logit = torch.softmax(multiclass_logit, dim=1) multiclass_logits.append(multiclass_logit) multiclass_pred = torch.cat(multiclass_logits, dim=1) if multiclass_logits else None if self.compute_multilabel_loss: multilabel_logits = cls_score[:, self.hierarchical_info['num_single_label_classes']:] - multilabel_pred = torch.sigmoid(multilabel_logits) if multilabel_logits is not None else None + if not torch.onnx.is_in_onnx_export(): + multilabel_pred = torch.sigmoid(multilabel_logits) if multilabel_logits is not None else None + else: + multilabel_pred = multilabel_logits if multiclass_pred is not None: pred = torch.cat([multiclass_pred, multilabel_pred], axis=1) else: diff --git a/mpa/modules/models/heads/custom_multi_label_linear_cls_head.py b/mpa/modules/models/heads/custom_multi_label_linear_cls_head.py index b29648a6..530ae1b6 100644 --- a/mpa/modules/models/heads/custom_multi_label_linear_cls_head.py +++ b/mpa/modules/models/heads/custom_multi_label_linear_cls_head.py @@ -81,9 +81,9 @@ def simple_test(self, img): cls_score = self.fc(img) * self.scale if isinstance(cls_score, list): cls_score = sum(cls_score) / float(len(cls_score)) - pred = torch.sigmoid(cls_score) if cls_score is not None else None if torch.onnx.is_in_onnx_export(): - return pred + return cls_score + pred = torch.sigmoid(cls_score) if cls_score is not None else None pred = list(pred.detach().cpu().numpy()) return pred diff --git a/mpa/modules/models/heads/custom_multi_label_non_linear_cls_head.py b/mpa/modules/models/heads/custom_multi_label_non_linear_cls_head.py index 755635da..3253f9f0 100644 --- a/mpa/modules/models/heads/custom_multi_label_non_linear_cls_head.py +++ b/mpa/modules/models/heads/custom_multi_label_non_linear_cls_head.py @@ -105,9 +105,9 @@ def simple_test(self, img): cls_score = self.classifier(img) * self.scale if isinstance(cls_score, list): cls_score = sum(cls_score) / float(len(cls_score)) - pred = torch.sigmoid(cls_score) if cls_score is not None else None if torch.onnx.is_in_onnx_export(): - return pred + return cls_score + pred = torch.sigmoid(cls_score) if cls_score is not None else None pred = list(pred.detach().cpu().numpy()) return pred diff --git a/mpa/modules/models/heads/non_linear_cls_head.py b/mpa/modules/models/heads/non_linear_cls_head.py index 7ce46d97..047d3d0b 100644 --- a/mpa/modules/models/heads/non_linear_cls_head.py +++ b/mpa/modules/models/heads/non_linear_cls_head.py @@ -76,9 +76,9 @@ def simple_test(self, img): cls_score = self.classifier(img) if isinstance(cls_score, list): cls_score = sum(cls_score) / float(len(cls_score)) - pred = F.softmax(cls_score, dim=1) if cls_score is not None else None if torch.onnx.is_in_onnx_export(): - return pred + return cls_score + pred = F.softmax(cls_score, dim=1) if cls_score is not None else None pred = list(pred.detach().cpu().numpy()) return pred