Skip to content

Commit

Permalink
Make the mean teacher algorithm account for distributed training (#2729)
Browse files Browse the repository at this point in the history
* make mean_teacher consider distributed training

* align with pre-commit

* re-enable test case

* move tensor to the current device instead of cuda

* apply comment
  • Loading branch information
eunwoosh authored Dec 20, 2023
1 parent 6a55742 commit 34e0ec2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from mmdet.core.mask.structures import BitmapMasks
from mmdet.models import DETECTORS, build_detector
from mmdet.models.detectors import BaseDetector
from torch import distributed as dist

from otx.utils.logger import get_logger

Expand Down Expand Up @@ -182,22 +183,29 @@ def forward_train(
pseudo_bboxes, pseudo_labels, pseudo_masks, pseudo_ratio = self.generate_pseudo_labels(
teacher_outputs, device=current_device, img_meta=ul_img_metas, **kwargs
)
if self.filter_empty_annotations:
non_empty = [bool(len(i)) for i in pseudo_labels]
pseudo_bboxes = [pb for i, pb in enumerate(pseudo_bboxes) if non_empty[i]]
pseudo_labels = [pl for i, pl in enumerate(pseudo_labels) if non_empty[i]]
pseudo_masks = [pm for i, pm in enumerate(pseudo_masks) if non_empty[i]]
ul_img_metas = [im for i, im in enumerate(ul_img_metas) if non_empty[i]]
ul_img = ul_img[non_empty]
else:
non_empty = [True]
if self.visualize:
self._visual_online(ul_img, pseudo_bboxes, pseudo_labels)
non_empty = [bool(len(i)) for i in pseudo_labels] if self.filter_empty_annotations else [True]
get_unlabeled_loss = pseudo_ratio >= self.min_pseudo_label_ratio and any(non_empty)

if dist.is_initialized():
reduced_get_unlabeled_loss = torch.tensor(int(get_unlabeled_loss)).to(current_device)
dist.all_reduce(reduced_get_unlabeled_loss)
if dont_have_to_train := not get_unlabeled_loss and reduced_get_unlabeled_loss > 0:
get_unlabeled_loss = True
non_empty[0] = True

losses.update(ps_ratio=torch.tensor([pseudo_ratio], device=current_device))

# Unsupervised loss
# Compute only if min_pseudo_label_ratio is reached
if pseudo_ratio >= self.min_pseudo_label_ratio and any(non_empty):
if get_unlabeled_loss:
if self.filter_empty_annotations:
pseudo_bboxes = [pb for i, pb in enumerate(pseudo_bboxes) if non_empty[i]]
pseudo_labels = [pl for i, pl in enumerate(pseudo_labels) if non_empty[i]]
pseudo_masks = [pm for i, pm in enumerate(pseudo_masks) if non_empty[i]]
ul_img_metas = [im for i, im in enumerate(ul_img_metas) if non_empty[i]]
ul_img = ul_img[non_empty]
if self.visualize:
self._visual_online(ul_img, pseudo_bboxes, pseudo_labels)
if self.bg_loss_weight >= 0.0:
self.model_s.bbox_head.bg_loss_weight = self.bg_loss_weight
if self.model_t.with_mask:
Expand All @@ -214,7 +222,10 @@ def forward_train(
if ul_loss_name.startswith("loss_"):
ul_loss = ul_losses[ul_loss_name]
target_loss = ul_loss_name.split("_")[-1]
if self.unlabeled_loss_weights[target_loss] == 0:
if dist.is_initialized():
if dont_have_to_train:
self.unlabeled_loss_weights[target_loss] = 0
elif self.unlabeled_loss_weights[target_loss] == 0:
continue
self._update_unlabeled_loss(losses, ul_loss, ul_loss_name, self.unlabeled_loss_weights[target_loss])
return losses
Expand Down
2 changes: 0 additions & 2 deletions tests/e2e/cli/detection/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,6 @@ def test_otx_eval(self, template, tmp_dir_path):
def test_otx_multi_gpu_train_semisl(self, template, tmp_dir_path):
if not (Path(template.model_template_path).parent / "semisl").is_dir():
pytest.skip(f"Semi-SL training type isn't available for {template.name}")
if template.name == "ResNeXt101-ATSS":
pytest.skip(f"Issue#2705: multi-gpu training e2e test failure for {template.name}")
tmp_dir_path = tmp_dir_path / "detection/test_multi_gpu_semisl"
args_semisl_multigpu = copy.deepcopy(args_semisl)
args_semisl_multigpu["--gpus"] = "0,1"
Expand Down

0 comments on commit 34e0ec2

Please sign in to comment.