Skip to content

Commit

Permalink
Add train_step unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyajie committed Sep 25, 2021
1 parent 0ecffa1 commit d496e1d
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion tests/test_models/test_segmentors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,25 @@ def _segmentor_forward_train_test(segmentor):
imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True)
assert isinstance(losses, dict)

# Test train_step
data_batch = dict(
img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
outputs = segmentor.train_step(data_batch, None)
assert isinstance(outputs, dict)
assert 'loss' in outputs
assert 'log_vars' in outputs
assert 'num_samples' in outputs

# Test val_step
with torch.no_grad():
segmentor.eval()
data_batch = dict(
img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
outputs = segmentor.val_step(data_batch)
outputs = segmentor.val_step(data_batch, None)
assert isinstance(outputs, dict)
assert 'loss' in outputs
assert 'log_vars' in outputs
assert 'num_samples' in outputs

# Test forward simple test
with torch.no_grad():
Expand Down

0 comments on commit d496e1d

Please sign in to comment.