Merge pull request #115 from hrukalive/refactor-v2
Remove several warnings and improve floating-point representation on the progress bar
hrukalive authored Jul 16, 2023
2 parents 03c1bdf + cdb0357 commit 02eeb84
Showing 2 changed files with 17 additions and 9 deletions.
basics/base_task.py: 8 additions & 7 deletions
@@ -125,17 +125,18 @@ def _training_step(self, sample):
         """
         losses = self.run_model(sample)
         total_loss = sum(losses.values())
-        return total_loss, {**losses, 'batch_size': sample['size']}
+        return total_loss, {**losses, 'batch_size': float(sample['size'])}
 
     def training_step(self, sample, batch_idx, optimizer_idx=-1):
         total_loss, log_outputs = self._training_step(sample)
 
         # logs to progress bar
         self.log_dict(log_outputs, prog_bar=True, logger=False, on_step=True, on_epoch=False)
-        self.log('lr', self.lr_schedulers().get_lr()[0], prog_bar=True, logger=False, on_step=True, on_epoch=False)
+        self.log('lr', self.lr_schedulers().get_last_lr()[0], prog_bar=True, logger=False, on_step=True, on_epoch=False)
         # logs to tensorboard
-        tb_log = {f'tr/{k}': v for k, v in log_outputs.items()}
-        if self.global_step % self.trainer.log_every_n_steps == 0:
+        if self.global_step % hparams['log_interval'] == 0:
+            tb_log = {f'tr/{k}': v for k, v in log_outputs.items()}
+            tb_log['tr/lr'] = self.lr_schedulers().get_last_lr()[0]
             self.logger.log_metrics(tb_log, step=self.global_step)
 
         return total_loss
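
Two of the removed warnings live in this hunk. PyTorch's LR schedulers warn when `get_lr()` is called outside of `step()` and point to `get_last_lr()` as the proper accessor, and the `float(...)` cast on `batch_size` presumably keeps Lightning from complaining about a non-float scalar on the progress bar. A minimal standalone sketch of the scheduler half (the model, optimizer, and scheduler here are illustrative, not from this repo):

```python
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Warns: "To get the last learning rate computed by the scheduler,
# please use `get_last_lr()`."
lr_deprecated = scheduler.get_lr()[0]

# Warning-free: the learning rate computed at the most recent step().
lr = scheduler.get_last_lr()[0]
print(lr)  # 0.1
```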
@@ -188,7 +189,7 @@ def on_validation_epoch_end(self):
             self.skip_immediate_ckpt_save = True
             return
         metric_vals = {k: v.compute() for k, v in self.valid_metrics.items()}
-        self.log('val_loss', metric_vals['total_loss'], on_epoch=True, prog_bar=True, logger=False)
+        self.log('val_loss', metric_vals['total_loss'], on_epoch=True, prog_bar=True, logger=False, sync_dist=True)
         self.logger.log_metrics({f'val/{k}': v for k, v in metric_vals.items()}, step=self.global_step)
         for metric in self.valid_metrics.values():
             metric.reset()
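
Adding `sync_dist=True` removes Lightning's warning that an epoch-level metric logged under DDP should be synchronized across devices; without it, each rank would record only its own `val_loss`. Conceptually (an assumed reading of the flag's semantics, not code from this repo) it reduces the value across ranks before recording:

```python
import torch
import torch.distributed as dist

def synced_mean(value: torch.Tensor) -> torch.Tensor:
    # Roughly what sync_dist=True does for an epoch-level scalar:
    # average over all DDP ranks; a no-op in single-process runs.
    if dist.is_available() and dist.is_initialized():
        value = value.clone()
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        value /= dist.get_world_size()
    return value

print(synced_mean(torch.tensor(0.25)))  # tensor(0.2500) when not distributed
```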
@@ -318,7 +319,7 @@ def start(cls):
                     permanent_ckpt_interval=hparams['permanent_ckpt_interval'],
                     verbose=True
                 ),
-                LearningRateMonitor(logging_interval='step'),
+                # LearningRateMonitor(logging_interval='step'),
                 DsTQDMProgressBar(),
             ],
             logger=TensorBoardLogger(
@@ -330,7 +331,7 @@ def start(cls):
             val_check_interval=hparams['val_check_interval'] * hparams['accumulate_grad_batches'],
             # so this is global_steps
             check_val_every_n_epoch=None,
-            log_every_n_steps=hparams['log_interval'],
+            log_every_n_steps=1,
             max_steps=hparams['max_updates'],
             use_distributed_sampler=False,
             num_sanity_val_steps=hparams['num_sanity_val_steps'],
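
Both Trainer tweaks follow from the manual TensorBoard logging introduced in `training_step`: `LearningRateMonitor` becomes redundant once `tr/lr` is written by hand, and `log_every_n_steps` no longer controls the TensorBoard cadence because every `self.log` call passes `logger=False`. The effective schedule is the manual gate, sketched standalone below (the `log_interval` value is assumed for illustration, not taken from the repo's configs):

```python
hparams = {'log_interval': 100}  # illustrative value

def writes_tensorboard(global_step: int) -> bool:
    # Mirrors the gate added in training_step: only these steps reach the logger.
    return global_step % hparams['log_interval'] == 0

print([s for s in range(301) if writes_tensorboard(s)])  # [0, 100, 200, 300]
```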
utils/training_utils.py: 9 additions & 2 deletions
@@ -316,8 +316,15 @@ def get_metrics(self, trainer, model):
         items['steps'] = str(trainer.global_step)
         for k, v in items.items():
             if isinstance(v, float):
-                if 0.00001 <= v < 10:
-                    items[k] = f"{v:.5f}"
+                if 0.001 <= v < 10:
+                    items[k] = np.format_float_positional(v, unique=True, precision=5, trim='-')
+                elif 0.00001 <= v < 0.001:
+                    if len(np.format_float_positional(v, unique=True, precision=8, trim='-')) > 8:
+                        items[k] = np.format_float_scientific(v, precision=3, unique=True, min_digits=2, trim='-')
+                    else:
+                        items[k] = np.format_float_positional(v, unique=True, precision=5, trim='-')
+                elif v < 0.00001:
+                    items[k] = np.format_float_scientific(v, precision=3, unique=True, min_digits=2, trim='-')
         items.pop("v_num", None)
         return items
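
The old `f"{v:.5f}"` pads every value to five decimals and rounds anything below 1e-5 to `0.00000`. The new logic prints trimmed positional notation in [0.001, 10); in [1e-5, 0.001) it keeps the positional form only when it fits in eight characters (so `0.0005` stays positional while `0.00012345` goes scientific); below 1e-5 it is always scientific, so tiny losses remain visible. A standalone comparison with example values chosen for illustration (note: `min_digits` requires NumPy 1.21+):

```python
import numpy as np

# Old behavior: fixed five decimals.
print(f"{0.5:.5f}")    # 0.50000  (trailing-zero noise)
print(f"{3e-06:.5f}")  # 0.00000  (the value disappears)

# New behavior: trimmed positional form in [0.001, 10) ...
print(np.format_float_positional(0.5, unique=True, precision=5, trim='-'))  # 0.5
# ... and scientific once positional would underflow or grow too wide.
print(np.format_float_scientific(3e-06, precision=3, unique=True, min_digits=2, trim='-'))
print(np.format_float_scientific(0.00012345, precision=3, unique=True, min_digits=2, trim='-'))
```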
