From 3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c Mon Sep 17 00:00:00 2001
From: Alexander Kolesnikov <47025656+akolesnikoff@users.noreply.github.com>
Date: Mon, 13 Nov 2023 14:48:10 +0100
Subject: [PATCH] update image_text proj and other misc updates (#74)

---
 README.md                                     |  62 ++-
 .../{lit_coco.py => siglip_lit_coco.py}       |  78 +--
 big_vision/input_pipeline.py                  |   4 +-
 .../models/proj/image_text/two_towers.py      |  20 +
 big_vision/models/vit.py                      |  43 +-
 big_vision/requirements.txt                   |   5 +-
 ...trastive.py => _deprecated_contrastive.py} |   0
 big_vision/trainers/proj/image_text/siglip.py | 505 ++++++++++++++++++
 big_vision/utils.py                           |   3 +-
 9 files changed, 633 insertions(+), 87 deletions(-)
 rename big_vision/configs/proj/image_text/{lit_coco.py => siglip_lit_coco.py} (56%)
 rename big_vision/trainers/proj/image_text/{contrastive.py => _deprecated_contrastive.py} (100%)
 create mode 100644 big_vision/trainers/proj/image_text/siglip.py

diff --git a/README.md b/README.md
index 80ae807..ef3885b 100644
--- a/README.md
+++ b/README.md
@@ -284,9 +284,9 @@ recommended. Below we provide instructions on how to do it.
 First, create some useful variables, which we be reused:
 
 ```
-export NAME="a name of the TPU deployment, e.g. my-tpu-machine"
-export ZONE="GCP geographical zone, e.g. europe-west4-a"
-export GS_BUCKET_NAME="Name of the storage bucket, e.g. my_bucket"
+export NAME=
+export ZONE=
+export GS_BUCKET_NAME=
 ```
 
 The following command line will create TPU VMs with 32 cores,
@@ -312,7 +312,11 @@ gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "bash b
 We recommend preparing `tfds` data locally as described above and then uploading
 the data to `Google Cloud` bucket. However, if you prefer, the datasets which
 do not require manual downloads can be prepared automatically using a TPU
-machine as described below.
+machine as described below. Note that TPU machines have only 100 GB of disk
+space, and multihost TPU slices do not allow for external disks to be attached
+in a write mode, so the instructions below may not work for preparing large
+datasets. As yet another alternative, we provide instructions
+[on how to prepare `tfds` data on CPU-only GCP machine](#preparing-tfds-data-on-a-standalone-gcp-cpu-machine).
 
 Specifically, the seven TFDS datasets used during evaluations will be generated
 under `~/tensorflow_datasets` on TPU machine with this command:
@@ -358,18 +362,64 @@ gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_D
 
 ## FSDP training.
 
 `big_vision` supports flexible parameter and model sharding strategies.
-Currently, we support the popular sharding strategy, name FSDP, via a simple config change, see [this config example](big_vision/configs/transfer.py).
-For example, to run FSDP finetuning of a pretrained ViT-L model, run the following command (possibly adjusting batch size depending on your hardware):
+Currently, we support a popular FSDP sharding via a simple config change, see [this config example](big_vision/configs/transfer.py).
+For example, to run FSDP finetuning of a pretrained ViT-L model, run the following command (possibly adjusting batch size depending on your hardware):
 
 ```
 gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/transfer.py:model=vit-i21k-augreg-l/16,dataset=oxford_iiit_pet,crop=resmall_crop,fsdp=True,batch_size=256 --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03"
 ```
 
+## Image-text training with SigLIP.
+
+A minimal example that uses public `coco` captions data:
+
+```
+gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.trainers.proj.image_text.siglip --config big_vision/configs/proj/image_text/siglip_lit_coco.py --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%Y-%m-%d_%H%M'`"
+```
+
+
+
 ## Sometimes useful gcloud commands
 
 - Destroy the TPU machines: `gcloud compute tpus tpu-vm delete $NAME --zone $ZONE`
 - Remove all big_vision-related folders on all hosts: `gcloud compute tpus tpu-vm ssh $NAME --zone $ZONE --worker=all --command 'rm -rf ~/big_vision ~/bv_venv'`
 
+## Preparing `tfds` data on a standalone GCP CPU machine.
+
+First create a new machine and a disk (feel free to adjust exact machine type and disk settings/capacity):
+
+```
+export NAME_CPU_HOST=
+export NAME_DISK=
+gcloud compute instances create $NAME_CPU_HOST --machine-type c3-standard-22 --zone $ZONE --image-family ubuntu-2204-lts --image-project ubuntu-os-cloud
+gcloud compute disks create $NAME_DISK --size 1000GB --zone $ZONE --type pd-balanced
+```
+
+Now attach the disk to the newly created machine:
+
+```
+gcloud compute instances attach-disk $NAME_CPU_HOST --disk $NAME_DISK --zone $ZONE
+```
+
+Next, `ssh` to the machine `gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE` and
+[follow instructions to format and mount the disk](https://cloud.google.com/compute/docs/disks/format-mount-disk-linux).
+Let's assume it was mounted to `/mnt/disks/tfds`.
+
+Almost there, now clone and set up `big_vision`:
+
+```
+gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "git clone https://github.com/google-research/big_vision.git && cd big_vision && sh big_vision/run_tpu.sh"
+```
+
+Finally, prepare the dataset (e.g. `coco_captions`) using the utility script and
+copy the result to your Google Cloud bucket:
+
+```
+gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "cd big_vision && TFDS_DATA_DIR=/mnt/disks/tfds/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.tools.download_tfds_datasets coco_captions"
+gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "rm -rf /mnt/disks/tfds/tensorflow_datasets/downloads && gsutil cp -r /mnt/disks/tfds/tensorflow_datasets gs://$GS_BUCKET_NAME"
+```
+
+
 # ViT baseline
 
 We provide a well-tuned ViT-S/16 baseline in the config file named
diff --git a/big_vision/configs/proj/image_text/lit_coco.py b/big_vision/configs/proj/image_text/siglip_lit_coco.py
similarity index 56%
rename from big_vision/configs/proj/image_text/lit_coco.py
rename to big_vision/configs/proj/image_text/siglip_lit_coco.py
index dd73a2a..d22dd9d 100644
--- a/big_vision/configs/proj/image_text/lit_coco.py
+++ b/big_vision/configs/proj/image_text/siglip_lit_coco.py
@@ -13,34 +13,13 @@
 # limitations under the License.
# pylint: disable=line-too-long -r"""Trains a LiT model as in https://arxiv.org/abs/2111.07991 - -IMPORTANT NOTE: This config uses coco_captions for demonstration purposes. As of -6/17/22 neither YFCC100M nor CC12M are available in TFDS. We're working on -publishing these datasets to allow for full replication of the numbers reported -in the paper. - -Published models: - -https://github.com/google-research/vision_transformer#lit-models - -Colab to load public LiT models: -https://colab.research.google.com/github/google-research/vision_transformer/blob/main/lit.ipynb - -gs://vit_models/lit/LiT-B16B.npz - 72.07% i1k 0shot -gs://vit_models/lit/LiT-L16L.npz - 75.68% i1k 0shot - missing in publication +r"""Minimal SigLIP (https://arxiv.org/abs/2303.15343) example. Example training: -big_vision.trainers.proj.image_text.contrastive \ - --config big_vision/configs/proj/image_text/lit_coco.py \ - --workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'` - -Example evaluation: - -big_vision.tools.eval_only \ - --config big_vision/configs/proj/image_text/lit_coco.py:txt=bert_base,img_head,img=B/16,init=gs://vit_models/lit/LiT-B16B.npz \ - --workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'` +big_vision.trainers.proj.image_text.siglip \ + --config big_vision/configs/proj/image_text/lit_coco.py:batch_size=512 \ + --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%Y-%m-%d_%H%M'` """ import big_vision.configs.common as bvcc @@ -52,14 +31,14 @@ def get_config(arg=None): """The base configuration.""" arg = bvcc.parse_arg( arg, res=224, runlocal=False, token_len=16, txt='bert_base', img='B/16', - init='', img_head=False) + init='', img_head=False, batch_size=512) img_name, img_init = common.inits[arg.img] txt_name, txt_init = common.inits[arg.txt] config = ConfigDict() config.input = {} config.input.data = dict(name='coco_captions', split='train') - config.input.batch_size = 4096 if not arg.runlocal else 32 + config.input.batch_size = arg.batch_size if not arg.runlocal else 32 config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50 config.total_steps = 5_000 if not arg.runlocal else 1 @@ -78,11 +57,6 @@ def get_config(arg=None): f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)' f'|flatten|{tokenizer("captions/text")}|keep("image", "labels")' ) - pp_eval = ( - f'decode|resize({arg.res})|value_range(-1,1)' - f'|flatten|{tokenizer("captions/text")}' - '|keep("image", "labels")' - ) config.pp_modules = [ 'ops_general', 'ops_image', 'ops_text', 'proj.flaxformer.bert_ops'] @@ -114,6 +88,7 @@ def get_config(arg=None): config.model.temperature_init = 10.0 dim = {'B': 768, 'L': 1024}[arg.img[0]] config.model.out_dim = (dim if arg.img_head else None, dim) # (image_out_dim, text_out_dim) + config.model.bias_init = -2.71 if txt_name == 'base': config.optax_name = 'scale_by_adam' @@ -130,48 +105,11 @@ def get_config(arg=None): config.grad_clip_norm = 1.0 - # Eval section (Both few-shot and zero-shot) - eval_common = dict( - type='proj.image_text.contrastive', - use_global_batch=True, - log_steps=500 if not arg.runlocal else 5, - ) config.evals = {} - sub = '[:4]' if arg.runlocal else '' - config.evals.val = { - **eval_common, - 'data': dict(name=config.input.data.name, split=f'val{sub}'), - 'pp_fn': pp_eval, - } - config.evals.coco = { - **eval_common, - 'data': dict(name='coco_captions', split=f'val{sub}'), - 'pp_fn': ( - f'decode|resize({arg.res})|value_range(-1,1)' - f'|flatten|{tokenizer("captions/text")}|keep("image", "labels")'), - } - config.evals.imagenet = { - 
**eval_common, - 'data': dict(name='imagenet2012', split=f'validation{sub}'), - 'pp_fn': ( - f'decode|resize({arg.res})|value_range(-1,1)' - '|clip_i1k_label_names' - f'|{tokenizer("labels")}|keep("image", "labels")'), - } - - config.evals.disclf = {} - config.evals.disclf.pp_img = f'resize({arg.res})|value_range(-1,1)' - config.evals.disclf.pp_txt = tokenizer('texts') - config.evals.disclf.type = 'proj.image_text.discriminative_classifier' - config.evals.disclf.prefix = 'z/0shot/' - config.evals.disclf.log_steps = eval_common['log_steps'] config.evals.retrieval_coco = common.get_coco( pp_img=f'resize({arg.res})|value_range(-1, 1)', pp_txt=tokenizer('texts'), - log_steps=config.evals.disclf.log_steps, + log_steps=1000, ) - config.seed = 0 - config.l = config.m = 0 - return config diff --git a/big_vision/input_pipeline.py b/big_vision/input_pipeline.py index d9e1f45..53d911f 100644 --- a/big_vision/input_pipeline.py +++ b/big_vision/input_pipeline.py @@ -222,7 +222,7 @@ def _shard(x): sharding = NamedSharding(mesh, P("devices")) local_ds = mesh.local_devices - x = np.asarray(memoryview(x)) # No-copy: http://shortn/_KM5whIEtWI + x = np.asarray(memoryview(x)) # No-copy: http://(internal link) xs = jax.device_put(np.split(x, len(local_ds), axis=0), local_ds) global_shape = (x.shape[0] * jax.process_count(), *x.shape[1:]) @@ -237,7 +237,7 @@ def _shard(x): def shard_and_put(x, shard=True, put=True): - x = np.asarray(memoryview(x)) # No-copy conversion: http://shortn/_KM5whIEtWI + x = np.asarray(memoryview(x)) # No-copy conversion: http://(internal link) if shard: x = einops.rearrange(x, "(d l) ... -> d l ...", d=jax.local_device_count()) if shard and put: # Only works for pmap (for now). diff --git a/big_vision/models/proj/image_text/two_towers.py b/big_vision/models/proj/image_text/two_towers.py index 71466ef..706b9c7 100644 --- a/big_vision/models/proj/image_text/two_towers.py +++ b/big_vision/models/proj/image_text/two_towers.py @@ -92,6 +92,9 @@ def __call__(self, image, text=None, **kw): def load(init_params, init_files, model_cfg, img_load_kw={}, txt_load_kw={}): # pylint: disable=dangerous-default-value """Loads both towers, `init_files` is now a dict with `img` and `txt` keys.""" + if isinstance(init_files, str): + init_files = VANITY_NAMES.get(init_files, init_files) + if isinstance(init_files, str): # A shortcut for a single file checkpoint of a two_towers model. if "bias_init" in model_cfg.keys(): @@ -132,3 +135,20 @@ def load(init_params, init_files, model_cfg, img_load_kw={}, txt_load_kw={}): # f"a typo. 
Here it is: {init_files}") return restored_params + + +# Shortcut names for some canonical paper checkpoints: +VANITY_NAMES = { + # pylint: disable=line-too-long + # SigLIP image encoder checkpoints from https://arxiv.org/abs/2303.15343 + "SigLIP B/16 224": "gs://big_vision/siglip/webli_en_b16_224_63724782.npz", + "SigLIP B/16 256": "gs://big_vision/siglip/webli_en_b16_256_60500360.npz", + "SigLIP B/16 384": "gs://big_vision/siglip/webli_en_b16_384_68578854.npz", + "SigLIP B/16 512": "gs://big_vision/siglip/webli_en_b16_512_68580893.npz", + "SigLIP L/16 256": "gs://big_vision/siglip/webli_en_l16_256_60552751.npz", + "SigLIP L/16 384": "gs://big_vision/siglip/webli_en_l16_384_63634585.npz", + "SigLIP So400m/14 224": "gs://big_vision/siglip/webli_en_so400m_224_57633886.npz", + "SigLIP So400m/14 384": "gs://big_vision/siglip/webli_en_so400m_384_58765454.npz", + "SigLIP B/16-i18n 256": "gs://big_vision/siglip/webli_i18n_b16_256_66117334.npz", + # pylint: enable=line-too-long +} diff --git a/big_vision/models/vit.py b/big_vision/models/vit.py index 649bbd1..de0ca77 100644 --- a/big_vision/models/vit.py +++ b/big_vision/models/vit.py @@ -379,23 +379,45 @@ def stack(*values): return params_scan +def scan_to_pyloop(params_scan): + """Converts a lax.scan ViT checkpoint to a python for-loop based one.""" + # See comment in pyloop_to_scan. + + params_scan = jax.tree_map(lambda x: x, params_scan) # Structural copy + t = params_scan["Transformer"] + + # Find out how many encoderblocks there are + depth = len(t["encoderblock"]["LayerNorm_0"]["bias"]) + + # Create that many encoderblocks, each with their slice of their sub-pytree. + for lyr in range(depth): + block = jax.tree_map(lambda x, lyr=lyr: x[lyr], t["encoderblock"]) + t[f"encoderblock_{lyr}"] = block + + del t["encoderblock"] + return params_scan + + def load(init_params, init_file, model_cfg, dont_load=()): # pylint: disable=invalid-name because we had to CamelCase above. """Load init from checkpoint, both old model and this one. +Hi-res posemb.""" - del model_cfg - init_file = VANITY_NAMES.get(init_file, init_file) restored_params = utils.load_params(init_file) restored_params = fix_old_checkpoints(restored_params) - if init_params and "encoderblock" in init_params["Transformer"]: + # Detect attempts to load non-scan checkpoint into scan model. + if (model_cfg.get("scan") and + "encoderblock" not in restored_params["Transformer"]): restored_params = pyloop_to_scan(restored_params) - # TODO: detect and convert the other way around too. + if (not model_cfg.get("scan") + and "encoderblock" in restored_params["Transformer"]): + restored_params = scan_to_pyloop(restored_params) # possibly use the random init for some of the params (such as, the head). restored_params = common.merge_params(restored_params, init_params, dont_load) # resample posemb if needed. + # TODO: Take this from model_cfg to avoid need for init_params. 
if init_params and "pos_embedding" in init_params: restored_params["pos_embedding"] = resample_posemb( old=restored_params["pos_embedding"], @@ -406,7 +428,6 @@ def load(init_params, init_file, model_cfg, dont_load=()): # pylint: disable=in # Shortcut names for some canonical paper checkpoints: VANITY_NAMES = { - # pylint: disable=line-too-long # pylint: disable=line-too-long # Recommended models from https://arxiv.org/abs/2106.10270 # Many more models at https://github.com/google-research/vision_transformer @@ -437,6 +458,16 @@ def load(init_params, init_file, model_cfg, dont_load=()): # pylint: disable=in "deit3_L_224_21k": "gs://big_vision/zoo/deit3/bv_deit_3_large_224_21k.npz", "deit3_L_384_1k": "gs://big_vision/zoo/deit3/bv_deit_3_large_384_1k.npz", "deit3_L_384_21k": "gs://big_vision/zoo/deit3/bv_deit_3_large_384_21k.npz", - # pylint: disable=line-too-long + + # SigLIP image encoder checkpoints from https://arxiv.org/abs/2303.15343 + "SigLIP B/16 224": "gs://big_vision/siglip/webli_en_b16_224_63724782.npz:img", + "SigLIP B/16 256": "gs://big_vision/siglip/webli_en_b16_256_60500360.npz:img", + "SigLIP B/16 384": "gs://big_vision/siglip/webli_en_b16_384_68578854.npz:img", + "SigLIP B/16 512": "gs://big_vision/siglip/webli_en_b16_512_68580893.npz:img", + "SigLIP L/16 256": "gs://big_vision/siglip/webli_en_l16_256_60552751.npz:img", + "SigLIP L/16 384": "gs://big_vision/siglip/webli_en_l16_384_63634585.npz:img", + "SigLIP So400m/14 224": "gs://big_vision/siglip/webli_en_so400m_224_57633886.npz:img", + "SigLIP So400m/14 384": "gs://big_vision/siglip/webli_en_so400m_384_58765454.npz:img", + "SigLIP B/16-i18n 256": "gs://big_vision/siglip/webli_i18n_b16_256_66117334.npz:img", # pylint: enable=line-too-long } diff --git a/big_vision/requirements.txt b/big_vision/requirements.txt index 2f5cc50..5d50f92 100644 --- a/big_vision/requirements.txt +++ b/big_vision/requirements.txt @@ -1,12 +1,13 @@ +numpy>=1.26 absl-py -clu +git+https://github.com/google/CommonLoopUtils einops flax optax git+https://github.com/google/flaxformer git+https://github.com/akolesnikoff/panopticapi.git@mute overrides -tensorflow +tensorflow-cpu tfds-nightly tensorflow-addons tensorflow-text diff --git a/big_vision/trainers/proj/image_text/contrastive.py b/big_vision/trainers/proj/image_text/_deprecated_contrastive.py similarity index 100% rename from big_vision/trainers/proj/image_text/contrastive.py rename to big_vision/trainers/proj/image_text/_deprecated_contrastive.py diff --git a/big_vision/trainers/proj/image_text/siglip.py b/big_vision/trainers/proj/image_text/siglip.py new file mode 100644 index 0000000..eac5af5 --- /dev/null +++ b/big_vision/trainers/proj/image_text/siglip.py @@ -0,0 +1,505 @@ +# Copyright 2023 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trainer for "Sigmoid Loss for Language Image Pre-Training". + +SigLIP (https://arxiv.org/abs/2303.15343) + +TODO: implement chunked version with shard_map. 
+""" +# pylint: disable=consider-using-from-import +# pylint: disable=logging-fstring-interpolation + +import functools +import importlib +import multiprocessing.pool +import os + +from absl import app +from absl import flags +from absl import logging +import big_vision.evaluators.common as eval_common +import big_vision.input_pipeline as input_pipeline +import big_vision.optax as bv_optax +import big_vision.sharding as bv_sharding +import big_vision.utils as u +from clu import parameter_overview +import flax +import flax.linen as nn +import jax +from jax.experimental import mesh_utils +from jax.experimental import multihost_utils +from jax.experimental.array_serialization import serialization as array_serial +import jax.numpy as jnp +from ml_collections import config_flags +import numpy as np +import optax +import tensorflow as tf + +from tensorflow.io import gfile + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. +jax.config.parse_flags_with_absl() +# Transfer guard will fail the program whenever that data between a host and +# a device is transferred implicitly. This often catches subtle bugs that +# cause slowdowns and memory fragmentation. Explicit transfers are done +# with jax.device_put and jax.device_get. +jax.config.update("jax_transfer_guard", "disallow") +# Fixes design flaw in jax.random that may cause unnecessary d2d comms. +jax.config.update("jax_threefry_partitionable", True) + + +def main(argv): + del argv + + jax.distributed.initialize() + + # Make sure TF does not touch GPUs. + tf.config.set_visible_devices([], "GPU") + + # Consistent device order is important to ensure correctness of various train + # loop components, such as input pipeline, update step, evaluators. We use + # jax utils to infer device order that will be used throughout the program. + devices = mesh_utils.create_device_mesh((jax.device_count(),)) + +################################################################################ +# # +# Set up logging # +# # +################################################################################ + + # Set up work directory and print welcome message. + config = flags.FLAGS.config + workdir = flags.FLAGS.workdir + logging.info( + f"\u001b[33mHello from process {jax.process_index()} holding " + f"{jax.local_device_count()}/{jax.device_count()} devices and " + f"writing to workdir {workdir}.\u001b[0m") + + save_ckpt_path = None + if workdir: # Always create if requested, even if we may not write into it. + gfile.makedirs(workdir) + save_ckpt_path = os.path.join(workdir, "checkpoint.bv") + + # The pool is used to perform misc operations such as logging in async way. + pool = multiprocessing.pool.ThreadPool() + + # Here we register preprocessing ops from modules listed on `pp_modules`. + for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]): + importlib.import_module(f"big_vision.pp.{m}") + + # Setup up logging and experiment manager. 
+ xid, wid = -1, -1 + fillin = lambda s: s + def info(s, *a): + logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a) + def write_note(note): + if jax.process_index() == 0: + info("%s", note) + + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + +################################################################################ +# # +# Input Pipeline # +# # +################################################################################ + + write_note("Initializing train dataset...") + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) + + train_ds, ntrain_img = input_pipeline.training(config.input) + + total_steps = u.steps("total", config, ntrain_img, batch_size) + def get_steps(name, default=ValueError, cfg=config): + return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default) + + u.chrono.inform(total_steps=total_steps, global_bs=batch_size, + steps_per_epoch=ntrain_img / batch_size, + measure=mw.measure, write_note=write_note) + + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) + + # Start input pipeline as early as possible. + n_prefetch = config.get("prefetch_to_device", 1) + train_iter = input_pipeline.start_global(train_ds, devices, n_prefetch) + +################################################################################ +# # +# Create Model & Optimizer # +# # +################################################################################ + + write_note("Creating model...") + model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") + model = model_mod.Model(**config.get("model", {})) + + def init(rng): + bs = batch_size // jax.device_count() + image_size = tuple(train_ds.element_spec["image"].shape[1:]) + no_image = jnp.zeros((bs,) + image_size, jnp.float32) + text_size = tuple(train_ds.element_spec["labels"].shape[1:]) + no_text = jnp.zeros((bs,) + text_size, jnp.int32) + params = flax.core.unfreeze(model.init(rng, no_image, no_text))["params"] + + # Set bias in the head to a low value, such that loss is small initially. + if "init_head_bias" in config: + params["head"]["bias"] = jnp.full_like(params["head"]["bias"], + config["init_head_bias"]) + + return params + + # This seed makes the Jax part of things (like model init) deterministic. + # However, full training still won't be deterministic, for example due to the + # tf.data pipeline not being deterministic even if we would set TF seed. + # See (internal link) for a fun read on what it takes. + rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0))) + + write_note("Inferring parameter shapes...") + rng, rng_init = jax.random.split(rng) + params_shape = jax.eval_shape(init, rng_init) + + write_note("Inferring optimizer state shapes...") + tx, sched_fns = bv_optax.make(config, params_shape, sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) + opt_shape = jax.eval_shape(tx.init, params_shape) + # We jit this, such that the arrays are created on the CPU, not device[0]. 
+ sched_fns_cpu = [u.jit_cpu()(sched_fn) for sched_fn in sched_fns] + + if jax.process_index() == 0: + num_params = sum(np.prod(p.shape) for p in jax.tree_leaves(params_shape)) + mw.measure("num_params", num_params) + +################################################################################ +# # +# Shard & Transfer # +# # +################################################################################ + + # Currently we support a simple 1D mesh only, this may change in the future. + # 1D mesh is sufficient to implement data-parallel training, ZERO2 and + # fully-sharded data-parallel (fsdp) training. + write_note("Creating device mesh...") + mesh = jax.sharding.Mesh(devices, ("data",)) + repl_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + + write_note("Inferring shardings...") + params_sharding = bv_sharding.infer_sharding( + params_shape, mesh, axis_name="data", + strategy=config.get("param_sharding", "replicated"), + extra_strategy_args=config.get("param_sharding_args", {})) + opt_sharding = bv_sharding.infer_sharding( + opt_shape, mesh, axis_name="data", + strategy=config.get("optim_sharding", "replicated"), + extra_strategy_args=config.get("optim_sharding_args", {})) + + write_note("Transferring train_state to devices...") + # RNG is always replicated + rng_init = u.reshard(rng_init, repl_sharding) + + # Parameters and the optimizer are now global (distributed) jax arrays. + params = jax.jit(init, out_shardings=params_sharding)(rng_init) + opt = jax.jit(tx.init, out_shardings=opt_sharding)(params) + + rng, rng_loop = jax.random.split(rng, 2) + rng_loop = u.reshard(rng_loop, repl_sharding) + del rng # not used anymore, so delete it. + + # At this point we have everything we need to form a train state. It contains + # all the parameters that are passed and updated by the main training step. + train_state_sharding = { + "params": params_sharding, "opt": opt_sharding, "rng": repl_sharding} + train_state = { + "params": params, "opt": opt, "rng": rng_loop} + del params, opt, rng_loop # Delete to avoid memory leak or accidental reuse. + + write_note("Logging parameter overview...") + parameter_overview.log_parameter_overview( + train_state["params"], msg="Init params", + include_stats="global", jax_logging_process=0) + +################################################################################ +# # +# Update Step # +# # +################################################################################ + + @functools.partial( + jax.jit, + donate_argnums=(0,), + out_shardings=(train_state_sharding, repl_sharding)) + def update_fn(train_state, batch): + """Update step.""" + + images, labels = batch["image"], batch["labels"] + + rng = train_state["rng"] + assert "mixup" not in config, "Mixup is not supported for SigLIP." + + # Get device-specific loss rng. + rng, rng_model = jax.random.split(rng, 2) + + def loss_fn(params): + zimg, ztxt, extras = model.apply( + {"params": params}, images, labels, + train=True, rngs={"dropout": rng_model}) + logits = jnp.dot(zimg, ztxt.T) + logits = logits * extras["t"] + extras["b"] + eye = jnp.eye(zimg.shape[0]) + + # Standard sigmoid computes everything twice, once assuming positive + # labels and once assuming negative ones. 
But here we know exactly where + # to find positives (on "me" diagonal) and negatives (everywhere else), + # so compute each one's loss only once: + m1_diag1 = -jnp.ones_like(logits) + 2 * eye + loglik = jax.nn.log_sigmoid(m1_diag1 * logits) + + # Normalize by npos per column, but that's one, so just sum. + nll = -jnp.sum(loglik, axis=-1) + + # NOTE: same as concat'ing me/ot along axis -1 above. + l = jnp.mean(nll) + + return l + + params, opt = train_state["params"], train_state["opt"] + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, opt = tx.update(grads, opt, params) + params = optax.apply_updates(params, updates) + + measurements = {"training_loss": loss} + gs = jax.tree_leaves(bv_optax.replace_frozen(config.schedule, grads, 0.)) + measurements["l2_grads"] = jnp.sqrt(sum([jnp.sum(g * g) for g in gs])) + ps = jax.tree_leaves(params) + measurements["l2_params"] = jnp.sqrt(sum([jnp.sum(p * p) for p in ps])) + us = jax.tree_leaves(updates) + measurements["l2_updates"] = jnp.sqrt(sum([jnp.sum(u * u) for u in us])) + + return {"params": params, "opt": opt, "rng": rng}, measurements + +################################################################################ +# # +# Load Checkpoint # +# # +################################################################################ + + # Decide how to initialize training. The order is important. + # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. + # 2. Resume from a previous checkpoint, e.g. start a cooldown training job. + # 3. Initialize model from something, e,g, start a fine-tuning job. + # 4. Train from scratch. + resume_ckpt_path = None + if save_ckpt_path and gfile.exists(f"{save_ckpt_path}-LAST"): + resume_ckpt_path = save_ckpt_path + elif config.get("resume"): + resume_ckpt_path = fillin(config.resume) + + ckpt_mngr = None + if save_ckpt_path or resume_ckpt_path: + ckpt_mngr = array_serial.GlobalAsyncCheckpointManager() + + if resume_ckpt_path: + write_note(f"Resuming training from checkpoint {resume_ckpt_path}...") + shardings = { + **train_state_sharding, + "chrono": jax.tree_map(lambda _: repl_sharding, + u.chrono.save()), + } + loaded = u.load_checkpoint_ts( + resume_ckpt_path, tree=shardings, shardings=shardings) + train_state = {key: loaded[key] for key in train_state.keys()} + + u.chrono.load(jax.device_get(loaded["chrono"])) + del loaded + elif config.get("model_init"): + write_note(f"Initialize model from {config.model_init}...") + # TODO: when updating the `load` API soon, do pass and request the + # full `train_state` from it. Examples where useful: VQVAE, BN. + train_state["params"] = model_mod.load( + train_state["params"], config.model_init, config.get("model"), + **config.get("model_load", {})) + + # load has the freedom to return params not correctly sharded. Think of for + # example ViT resampling position embedings on CPU as numpy arrays. + train_state["params"] = u.reshard( + train_state["params"], train_state_sharding["params"]) + + parameter_overview.log_parameter_overview( + train_state["params"], msg="restored params", + include_stats="global", jax_logging_process=0) + + +################################################################################ +# # +# Setup Evals # +# # +################################################################################ + + # We do not jit/pmap this function, because it is passed to evaluator that + # does it later. We output as many intermediate tensors as possible for + # maximal flexibility. 
Later `jit` will prune out things that are not needed. + def eval_logits_fn(train_state, batch): + zimg, ztxt, out = model.apply( + {"params": train_state["params"]}, + batch.get("image", None), batch.get("labels", None)) + return zimg, ztxt, out + + def eval_loss_fn(train_state, batch): + logits, _ = model.apply({"params": train_state["params"]}, batch["image"]) + loss_fn = getattr(u, config.get("loss", "sigmoid_xent")) + return { + "loss": loss_fn(logits=logits, labels=batch["labels"], reduction=False) + } + + eval_fns = { + "predict": eval_logits_fn, + "loss": eval_loss_fn, + } + + # Only initialize evaluators when they are first needed. + @functools.lru_cache(maxsize=None) + def evaluators(): + return eval_common.from_config( + config, eval_fns, + lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"), + lambda key, cfg: get_steps(key, default=None, cfg=cfg), + devices, + ) + + # At this point we need to know the current step to see whether to run evals. + write_note("Inferring the first step number...") + first_step_device = bv_optax.get_count(train_state["opt"], jittable=True) + first_step = int(jax.device_get(first_step_device)) + u.chrono.inform(first_step=first_step) + + # Note that training can be pre-empted during the final evaluation (i.e. + # just after the final checkpoint has been written to disc), in which case we + # want to run the evals. + if first_step in (total_steps, 0): + write_note("Running initial or final evals...") + mw.step_start(first_step) + for (name, evaluator, _, prefix) in evaluators(): + if config.evals[name].get("skip_first") and first_step != total_steps: + continue + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, nn.logical_axis_rules([("act_batch", "data")]): + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", value) + +################################################################################ +# # +# Train Loop # +# # +################################################################################ + + prof = None # Keeps track of start/stop of profiler state. + + write_note("Starting training loop, compiling the first step...") + for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): + mw.step_start(step) + + with jax.profiler.StepTraceAnnotation("train_step", step_num=step): + with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1): + with mesh, nn.logical_axis_rules([("act_batch", "data")]): + train_state, measurements = update_fn(train_state, batch) + + # On the first host, let's always profile a handful of early steps. 
+ if jax.process_index() == 0: + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or u.chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", + sched_fn_cpu(u.put_cpu(step - 1))) + measurements = jax.device_get(measurements) + for name, value in measurements.items(): + mw.measure(name, value) + u.chrono.tick(step) + if not np.isfinite(measurements["training_loss"]): + raise RuntimeError(f"The loss became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + + # Checkpoint saving + keep_ckpt_steps = get_steps("keep_ckpt", None) or total_steps + if save_ckpt_path and ( + (keep := u.itstime(step, keep_ckpt_steps, total_steps, first=False)) + or u.itstime(step, get_steps("ckpt", None), total_steps, first=True) + ): + u.chrono.pause(wait_for=train_state) + + # Copy because we add extra stuff to the checkpoint. + ckpt = {**train_state} + + # To save chrono state correctly and safely in a multihost setup, we + # broadcast the state to all hosts and convert it to a global array. + with jax.transfer_guard("allow"): + chrono_ckpt = multihost_utils.broadcast_one_to_all(u.chrono.save()) + chrono_shardings = jax.tree_map(lambda _: repl_sharding, chrono_ckpt) + ckpt = ckpt | {"chrono": u.reshard(chrono_ckpt, chrono_shardings)} + + u.save_checkpoint_ts(ckpt_mngr, ckpt, save_ckpt_path, step, keep) + u.chrono.resume() + + for (name, evaluator, log_steps, prefix) in evaluators(): + if u.itstime(step, log_steps, total_steps, first=False, last=True): + u.chrono.pause(wait_for=train_state) + u.chrono.tick(step) # Record things like epoch number, core hours etc. + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, nn.logical_axis_rules([("act_batch", "data")]): + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", jax.device_get(value)) + u.chrono.resume() + mw.step_end() + + # Always give a chance to stop the profiler, no matter how things ended. + # TODO: can we also do this when dying of an exception like OOM? + if jax.process_index() == 0 and prof is not None: + u.startstop_prof(prof) + + # Last note needs to happen before the pool's closed =) + write_note(f"Done!\n{u.chrono.note}") + + pool.close() + pool.join() + mw.close() + + # Make sure all hosts stay up until the end of main. + u.sync() + + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/utils.py b/big_vision/utils.py index 8955488..23aabc3 100644 --- a/big_vision/utils.py +++ b/big_vision/utils.py @@ -210,7 +210,8 @@ def load_params(ckpt, **kw): else: # Here we're now loading new-style tensorstore checkpoints. # We can be a more efficient and load params and `key` only right away. - checkpoint = load_checkpoint_ts(ckpt, regex=f"params/{key}/.*") + regex = f"params/{key}/.*" if key else "params/.*" + checkpoint = load_checkpoint_ts(ckpt, regex=regex) params = checkpoint["params"] if key is not None:
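For reference, here is a minimal, self-contained JAX sketch of the pairwise sigmoid loss that the new `big_vision/trainers/proj/image_text/siglip.py` trainer computes inside `update_fn`. The names `zimg`, `ztxt`, `t`, and `b` mirror the trainer above; the `siglip_loss` helper and the random test inputs are illustrative assumptions, not part of the patch.

```
# Illustrative sketch only; mirrors the loss_fn in siglip.py above.
import jax
import jax.numpy as jnp


def siglip_loss(zimg, ztxt, t, b):
  """Mean pairwise sigmoid loss over a batch of image/text embeddings.

  zimg, ztxt: L2-normalized embeddings of shape [batch, dim].
  t, b: learned temperature and bias scalars (the config above initializes
  them via `temperature_init` and `bias_init`).
  """
  logits = jnp.dot(zimg, ztxt.T) * t + b          # [batch, batch] pair logits
  # +1 on the diagonal (matching pairs), -1 everywhere else (non-matching).
  m1_diag1 = 2 * jnp.eye(zimg.shape[0]) - jnp.ones_like(logits)
  loglik = jax.nn.log_sigmoid(m1_diag1 * logits)  # log-prob of correct label per pair
  nll = -jnp.sum(loglik, axis=-1)                 # sum over all pairs in a row
  return jnp.mean(nll)


if __name__ == "__main__":
  # Random, L2-normalized toy embeddings just to exercise the function.
  key_i, key_t = jax.random.split(jax.random.PRNGKey(0))
  zimg = jax.random.normal(key_i, (8, 16))
  ztxt = jax.random.normal(key_t, (8, 16))
  zimg /= jnp.linalg.norm(zimg, axis=-1, keepdims=True)
  ztxt /= jnp.linalg.norm(ztxt, axis=-1, keepdims=True)
  print(siglip_loss(zimg, ztxt, t=10.0, b=-2.71))  # scalar loss
```

The negative bias used here matches the trainer's `init_head_bias` comment: starting the many off-diagonal (negative) pairs at confidently-negative logits keeps the initial loss small.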