merge train/val_step and batch_processor into run_iter
ZhiyuanChen committed Sep 14, 2020
1 parent a0cc5a8 commit 1ffafc4
Showing 2 changed files with 25 additions and 26 deletions.
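Editor's note: the refactor routes every batch through a single run_iter dispatcher, so the train and val loops stop duplicating the batch_processor / train_step / val_step branching. Below is a minimal sketch of the model-side contract run_iter dispatches to; ToyModel and its tensor shapes are illustrative assumptions, not part of this commit. When no batch_processor is supplied, the model must expose train_step/val_step, each returning a dict whose optional 'log_vars' and 'num_samples' keys feed the runner's log_buffer.

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Illustrative model satisfying the interface run_iter dispatches to."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def train_step(self, data_batch, optimizer, **kwargs):
        x, y = data_batch
        loss = nn.functional.mse_loss(self.fc(x), y)
        # run_iter requires a dict; 'log_vars' and 'num_samples' are
        # forwarded to the runner's log_buffer when present.
        return dict(loss=loss,
                    log_vars=dict(loss=loss.item()),
                    num_samples=x.size(0))

    def val_step(self, data_batch, optimizer, **kwargs):
        # Same contract as train_step; the runner's val() already wraps
        # this call in torch.no_grad().
        return self.train_step(data_batch, optimizer, **kwargs)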
5 changes: 5 additions & 0 deletions mmcv/runner/base_runner.py
@@ -98,6 +98,7 @@ def __init__(self,
         self.optimizer = optimizer
         self.logger = logger
         self.meta = meta
+        self.train_mode = False

         # create work_dir
         if mmcv.is_str(work_dir):
@@ -172,6 +173,10 @@ def max_iters(self):
         """int: Maximum training iterations."""
         return self._max_iters

+    @abstractmethod
+    def run_iter(self, data_batch, **kwargs):
+        pass
+
     @abstractmethod
     def train(self):
         pass
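Editor's note: with run_iter promoted to an abstract method, every runner subclass must now implement it alongside the other abstract methods on BaseRunner (train, val, run, save_checkpoint in this version). A hypothetical minimal subclass, sketched here to show the contract — MinimalRunner and its stub bodies are assumptions for illustration, not mmcv code:

from mmcv.runner import BaseRunner


class MinimalRunner(BaseRunner):
    """Hypothetical subclass showing the new abstract contract."""

    def run_iter(self, data_batch, **kwargs):
        # The contract implied by the diff: consume one data_batch,
        # store a dict in self.outputs, optionally update log_buffer.
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs

    def train(self, data_loader, **kwargs):
        ...  # loop over data_loader, calling self.run_iter(data_batch)

    def val(self, data_loader, **kwargs):
        ...

    def run(self, data_loaders, workflow, **kwargs):
        ...

    def save_checkpoint(self, out_dir, **kwargs):
        ...

Keeping the per-batch dispatch in one overridable method means a differently shaped runner (iteration-based rather than epoch-based, say) only has to swap the outer loop, not the branching.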
46 changes: 20 additions & 26 deletions mmcv/runner/epoch_based_runner.py
@@ -17,29 +17,34 @@ class EpochBasedRunner(BaseRunner):
     This runner trains models epoch by epoch.
     """

+    def run_iter(self, data_batch, **kwargs):
+        if self.batch_processor is not None:
+            outputs = self.batch_processor(
+                self.model, data_batch, train_mode=self.train_mode, **kwargs)
+        elif self.train_mode:
+            outputs = self.model.train_step(data_batch, self.optimizer,
+                                            **kwargs)
+        else:
+            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
+        if not isinstance(outputs, dict):
+            raise TypeError('"batch_processor()" or "model.train_step()"'
+                            ' and "model.val_step()" must return a dict')
+        if 'log_vars' in outputs:
+            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+        self.outputs = outputs
+
     def train(self, data_loader, **kwargs):
         self.model.train()
         self.mode = 'train'
+        self.train_mode = True
         self.data_loader = data_loader
         self._max_iters = self._max_epochs * len(self.data_loader)
         self.call_hook('before_train_epoch')
         time.sleep(2)  # Prevent possible deadlock during epoch transition
         for i, data_batch in enumerate(self.data_loader):
             self._inner_iter = i
             self.call_hook('before_train_iter')
-            if self.batch_processor is None:
-                outputs = self.model.train_step(data_batch, self.optimizer,
-                                                **kwargs)
-            else:
-                outputs = self.batch_processor(
-                    self.model, data_batch, train_mode=True, **kwargs)
-            if not isinstance(outputs, dict):
-                raise TypeError('"batch_processor()" or "model.train_step()"'
-                                ' must return a dict')
-            if 'log_vars' in outputs:
-                self.log_buffer.update(outputs['log_vars'],
-                                       outputs['num_samples'])
-            self.outputs = outputs
+            self.run_iter(data_batch, **kwargs)
             self.call_hook('after_train_iter')
             self._iter += 1

@@ -49,26 +49,15 @@ def train(self, data_loader, **kwargs):
     def val(self, data_loader, **kwargs):
         self.model.eval()
         self.mode = 'val'
+        self.train_mode = False
         self.data_loader = data_loader
         self.call_hook('before_val_epoch')
         time.sleep(2)  # Prevent possible deadlock during epoch transition
         for i, data_batch in enumerate(self.data_loader):
             self._inner_iter = i
             self.call_hook('before_val_iter')
             with torch.no_grad():
-                if self.batch_processor is None:
-                    outputs = self.model.val_step(data_batch, self.optimizer,
-                                                  **kwargs)
-                else:
-                    outputs = self.batch_processor(
-                        self.model, data_batch, train_mode=False, **kwargs)
-                if not isinstance(outputs, dict):
-                    raise TypeError('"batch_processor()" or "model.val_step()"'
-                                    ' must return a dict')
-                if 'log_vars' in outputs:
-                    self.log_buffer.update(outputs['log_vars'],
-                                           outputs['num_samples'])
-                self.outputs = outputs
+                self.run_iter(data_batch, **kwargs)
             self.call_hook('after_val_iter')

         self.call_hook('after_val_epoch')
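Editor's note: an end-to-end sketch of the merged path, assuming the runner API of this vintage of mmcv and the ToyModel sketched above; the work_dir, workflow values, and hyperparameters are arbitrary. Both phases of the workflow now funnel through run_iter, with self.train_mode selecting train_step or val_step.

import logging

import torch
from torch.utils.data import DataLoader, TensorDataset

from mmcv.runner import EpochBasedRunner

model = ToyModel()
runner = EpochBasedRunner(
    model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    work_dir='./work_dir',
    logger=logging.getLogger())

data = TensorDataset(torch.randn(8, 2), torch.randn(8, 1))
loader = DataLoader(data, batch_size=4)

# No hooks are registered, so this only exercises the forward path;
# train() sets train_mode=True, val() sets it to False, and both defer
# the per-batch work to the shared run_iter.
runner.run([loader, loader], [('train', 1), ('val', 1)], max_epochs=1)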
