Log wandb step using wandb native step arg in addition to the "step" key. (#613)

* wandb step fix

* backwards compat fix

* update wandb calls

* update readme
gabrielilharco authored Oct 11, 2023
1 parent 0142d27 commit 4ccb752
Showing 2 changed files with 30 additions and 13 deletions.
8 changes: 7 additions & 1 deletion README.md
@@ -415,11 +415,17 @@ When training a RN50 on YFCC the same hyperparameters as above are used, with th
 Note that to use another model, like `ViT-B/32` or `RN50x4` or `RN50x16` or `ViT-B/16`, specify with `--model RN50x4`.
 
-### Launch tensorboard:
+### Logging
+
+For tensorboard logging, run:
 ```bash
 tensorboard --logdir=logs/tensorboard/ --port=7777
 ```
 
+For wandb logging, we recommend looking at the `step` variable instead of `Step`, since the latter was not properly set in earlier versions of this codebase.
+For older runs with models trained before https://github.com/mlfoundations/open_clip/pull/613, the `Step` variable should be ignored.
+For newer runs, after that PR, the two variables are the same.
+
 ## Evaluation / Zero-Shot
 
 We recommend https://github.com/LAION-AI/CLIP_benchmark#how-to-use for systematic evaluation on 40 datasets.
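
The README note above refers to the consolidated logging pattern this commit introduces. As a minimal sketch (not part of the diff; the project name and metric are hypothetical), logging with wandb's native `step` argument while keeping the legacy `"step"` key looks like this:

```python
import wandb

run = wandb.init(project="open_clip-demo")  # hypothetical project name
for step in range(10):
    log_data = {"train/loss": 1.0 / (step + 1)}  # hypothetical metric
    log_data["step"] = step          # legacy "step" key, kept for backwards compatibility
    wandb.log(log_data, step=step)   # native step arg drives wandb's "Step" axis
run.finish()
```

Because the native `step` argument and the logged `"step"` key receive the same value, the `step` and `Step` variables coincide for runs made after this PR.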
35 changes: 23 additions & 12 deletions src/training/train.py
@@ -231,14 +231,17 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
             }
             log_data.update({name:val.val for name,val in losses_m.items()})
 
-            for name, val in log_data.items():
-                name = "train/" + name
-                if tb_writer is not None:
-                    tb_writer.add_scalar(name, val, step)
-                if args.wandb:
-                    assert wandb is not None, 'Please install wandb.'
-                    wandb.log({name: val, 'step': step})
+            log_data = {"train/" + name: val for name, val in log_data.items()}
+
+            if tb_writer is not None:
+                for name, val in log_data.items():
+                    tb_writer.add_scalar(name, val, step)
+
+            if args.wandb:
+                assert wandb is not None, 'Please install wandb.'
+                log_data['step'] = step  # for backwards compatibility
+                wandb.log(log_data, step=step)
 
             # resetting batch / data time meters per log window
             batch_time_m.reset()
             data_time_m.reset()
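
Why this hunk fixes wandb's `Step`: a `wandb.log` call made without an explicit `step=` argument auto-increments wandb's internal step counter, so the old per-metric loop advanced `Step` once per metric rather than once per training step. A sketch of the difference (assuming wandb's documented auto-increment behavior; project name and metric values are hypothetical):

```python
import wandb

wandb.init(project="demo")  # hypothetical project

# Old pattern: three separate calls without step= land on three consecutive
# native "Step" values (0, 1, 2), even though all belong to training step 7.
for name, val in {"loss": 0.5, "lr": 1e-4, "scale": 100.0}.items():
    wandb.log({f"train/{name}": val, "step": 7})

# New pattern: one consolidated call pins every metric to the same native step.
wandb.log({"train/loss": 0.4, "train/lr": 1e-4, "train/scale": 100.0, "step": 8}, step=8)
```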
@@ -329,19 +332,27 @@ def evaluate(model, data, epoch, args, tb_writer=None):
+ "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
)

log_data = {"val/" + name: val for name, val in metrics.items()}

if args.save_logs:
for name, val in metrics.items():
if tb_writer is not None:
tb_writer.add_scalar(f"val/{name}", val, epoch)
if tb_writer is not None:
for name, val in log_data.items():
tb_writer.add_scalar(name, val, epoch)

with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
f.write(json.dumps(metrics))
f.write("\n")

if args.wandb:
assert wandb is not None, 'Please install wandb.'
for name, val in metrics.items():
wandb.log({f"val/{name}": val, 'epoch': epoch})
if 'train' in data:
dataloader = data['train'].dataloader
num_batches_per_epoch = dataloader.num_batches // args.accum_freq
step = num_batches_per_epoch * epoch
else:
step = None
log_data['epoch'] = epoch
wandb.log(log_data, step=step)

return metrics
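
For concreteness, here is the eval-step arithmetic from the `args.wandb` branch above, worked with hypothetical numbers:

```python
# Hypothetical values; mirrors the computation in evaluate() above.
num_batches = 1000        # dataloader.num_batches for the train split
accum_freq = 2            # args.accum_freq (gradient accumulation factor)
epoch = 3                 # evaluating after the 3rd epoch

num_batches_per_epoch = num_batches // accum_freq  # 500 optimizer steps per epoch
step = num_batches_per_epoch * epoch               # 1500
print(step)  # eval metrics for epoch 3 are logged at train step 1500
```

Passing this `step` to `wandb.log` keeps validation metrics aligned with the training-step axis; when no train split is available, `step=None` falls back to wandb's default auto-increment behavior.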

