Log wandb step using wandb native step arg in addition to the "step" key. #613

Merged · 4 commits · Oct 11, 2023
README.md (8 changes: 7 additions & 1 deletion)
@@ -415,11 +415,17 @@ When training a RN50 on YFCC the same hyperparameters as above are used, with th

Note that to use another model, like `ViT-B/32` or `RN50x4` or `RN50x16` or `ViT-B/16`, specify with `--model RN50x4`.

### Launch tensorboard:
### Logging

For tensorboard logging, run:
```bash
tensorboard --logdir=logs/tensorboard/ --port=7777
```

For wandb logging, we recommend looking at the `step` variable rather than `Step`, since the latter was not set properly in earlier versions of this codebase.
For older runs with models trained before https://github.com/mlfoundations/open_clip/pull/613, the `Step` variable should be ignored.
For newer runs trained after that PR, the two variables are identical.
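
As a minimal sketch of what this means in practice (assuming wandb is installed; the project name and loss values below are placeholders): passing wandb's native `step` argument populates the run's built-in step axis, while the extra `"step"` key is only kept as an ordinary logged metric for backwards compatibility.

```python
import wandb

# Hypothetical run; "open-clip-demo" is a placeholder project name.
run = wandb.init(project="open-clip-demo")

for step in range(3):
    loss = 1.0 / (step + 1)  # dummy value for illustration
    # step=step sets wandb's native step axis (shown as "Step" in the UI);
    # the "step" key is also logged as a plain metric for older dashboards.
    wandb.log({"train/loss": loss, "step": step}, step=step)

run.finish()
```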

## Evaluation / Zero-Shot

We recommend https://github.com/LAION-AI/CLIP_benchmark#how-to-use for systematic evaluation on 40 datasets.
src/training/train.py (35 changes: 23 additions & 12 deletions)
@@ -219,14 +219,17 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
             }
             log_data.update({name:val.val for name,val in losses_m.items()})
 
-            for name, val in log_data.items():
-                name = "train/" + name
-                if tb_writer is not None:
-                    tb_writer.add_scalar(name, val, step)
-                if args.wandb:
-                    assert wandb is not None, 'Please install wandb.'
-                    wandb.log({name: val, 'step': step})
+            log_data = {"train/" + name: val for name, val in log_data.items()}
+
+            if tb_writer is not None:
+                for name, val in log_data.items():
+                    tb_writer.add_scalar(name, val, step)
+
+            if args.wandb:
+                assert wandb is not None, 'Please install wandb.'
+                log_data['step'] = step # for backwards compatibility
+                wandb.log(log_data, step=step)
 
             # resetting batch / data time meters per log window
             batch_time_m.reset()
             data_time_m.reset()
@@ -317,19 +320,27 @@ def evaluate(model, data, epoch, args, tb_writer=None):
         + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
     )
 
+    log_data = {"val/" + name: val for name, val in metrics.items()}
+
     if args.save_logs:
-        for name, val in metrics.items():
-            if tb_writer is not None:
-                tb_writer.add_scalar(f"val/{name}", val, epoch)
+        if tb_writer is not None:
+            for name, val in log_data.items():
+                tb_writer.add_scalar(name, val, epoch)
 
         with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")
 
     if args.wandb:
         assert wandb is not None, 'Please install wandb.'
-        for name, val in metrics.items():
-            wandb.log({f"val/{name}": val, 'epoch': epoch})
+        if 'train' in data:
+            dataloader = data['train'].dataloader
+            num_batches_per_epoch = dataloader.num_batches // args.accum_freq
+            step = num_batches_per_epoch * epoch
+        else:
+            step = None
+        log_data['epoch'] = epoch
+        wandb.log(log_data, step=step)
 
     return metrics
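
As a rough illustration of the step value the new `evaluate` logging path derives above (the batch count and `accum_freq` below are made-up numbers, not values from this repo):

```python
# Made-up numbers to show how the eval step is derived: with 10_000 training
# batches per epoch and accum_freq=2 there are 5_000 optimizer steps per
# epoch, so validation metrics after 3 completed epochs land at step 15_000
# and line up with the training curves on wandb's native step axis.
num_batches = 10_000
accum_freq = 2
epoch = 3

num_batches_per_epoch = num_batches // accum_freq
step = num_batches_per_epoch * epoch
print(step)  # 15000
```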
