Skip to content

Commit

Permalink
update image_text proj and other misc updates (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
akolesnikoff authored Nov 13, 2023
1 parent 5ef56fb commit 3b8e5ab
Show file tree
Hide file tree
Showing 9 changed files with 633 additions and 87 deletions.
62 changes: 56 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,9 @@ recommended. Below we provide instructions on how to do it.
First, create some useful variables, which we be reused:

```
export NAME="a name of the TPU deployment, e.g. my-tpu-machine"
export ZONE="GCP geographical zone, e.g. europe-west4-a"
export GS_BUCKET_NAME="Name of the storage bucket, e.g. my_bucket"
export NAME=<a name of the TPU deployment, e.g. my-tpu-machine>
export ZONE=<GCP geographical zone, e.g. europe-west4-a>
export GS_BUCKET_NAME=<Name of the storage bucket, e.g. my_bucket>
```

The following command line will create TPU VMs with 32 cores,
Expand All @@ -312,7 +312,11 @@ gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "bash b
We recommend preparing `tfds` data locally as described above and then uploading
the data to `Google Cloud` bucket. However, if you prefer, the datasets which
do not require manual downloads can be prepared automatically using a TPU
machine as described below.
machine as described below. Note that TPU machines have only 100 GB of disk
space, and multihost TPU slices do not allow for external disks to be attached
in a write mode, so the instructions below may not work for preparing large
datasets. As yet another alternative, we provide instructions
[on how to prepare `tfds` data on CPU-only GCP machine](#preparing-tfds-data-on-a-standalone-gcp-cpu-machine).

Specifically, the seven TFDS datasets used during evaluations will be generated
under `~/tensorflow_datasets` on TPU machine with this command:
Expand Down Expand Up @@ -358,18 +362,64 @@ gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_D
## FSDP training.

`big_vision` supports flexible parameter and model sharding strategies.
Currently, we support the popular sharding strategy, name FSDP, via a simple config change, see [this config example](big_vision/configs/transfer.py).
For example, to run FSDP finetuning of a pretrained ViT-L model, run the following command (possibly adjusting batch size depending on your hardware):
Currently, we support a popular FSDP sharding via a simple config change, see [this config example](big_vision/configs/transfer.py).
For example, to run FSDP finetuning of a pretrained ViT-L model, run the following command (possible adjusting batch size depending on your hardware):

```
gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/transfer.py:model=vit-i21k-augreg-l/16,dataset=oxford_iiit_pet,crop=resmall_crop,fsdp=True,batch_size=256 --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03"
```

## Image-text training with SigLIP.

A minimal example that uses public `coco` captions data:

```
gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.trainers.proj.image_text.siglip --config big_vision/configs/proj/image_text/siglip_lit_coco.py --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%Y-%m-%d_%H%M'`"
```



## Sometimes useful gcloud commands

- Destroy the TPU machines: `gcloud compute tpus tpu-vm delete $NAME --zone $ZONE`
- Remove all big_vision-related folders on all hosts: `gcloud compute tpus tpu-vm ssh $NAME --zone $ZONE --worker=all --command 'rm -rf ~/big_vision ~/bv_venv'`

## Preparing `tfds` data on a standalone GCP CPU machine.

First create a new machine and a disk (feel free to adjust exact machine type and disk settings/capacity):

```
export NAME_CPU_HOST=<A name of a CPU-only machine>
export NAME_DISK=<A name of a disk>
gcloud compute instances create $NAME_CPU_HOST --machine-type c3-standard-22 --zone $ZONE --image-family ubuntu-2204-lts --image-project ubuntu-os-cloud
gcloud compute disks create $NAME_DISK --size 1000GB --zone $ZONE --type pd-balanced
```

Now attach the disk to the newly create machine:

```
gcloud compute instances attach-disk $NAME_CPU_HOST --disk $NAME_DISK --zone $ZONE
```

Next, `ssh` to the machine `gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE` and
[follow instructions to format and mount the disk](https://cloud.google.com/compute/docs/disks/format-mount-disk-linux).
Let's assume it was mounted to `/mnt/disks/tfds`.

Almost there, now clone and set up `big_vision`:

```
gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "git clone https://github.com/google-research/big_vision.git && cd big_vision && sh big_vision/run_tpu.sh"
```

Finally, prepare the dataset (e.g. `coco_captions`) using the utility script and
copy the result to you google cloud bucket:

```
gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "cd big_vision && TFDS_DATA_DIR=/mnt/disks/tfds/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.tools.download_tfds_datasets coco_captions"
gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "rm -rf /mnt/disks/tfds/tensorflow_datasets/downloads && gsutil cp -r /mnt/disks/tfds/tensorflow_datasets gs://$GS_BUCKET_NAME"
```


# ViT baseline

We provide a well-tuned ViT-S/16 baseline in the config file named
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,13 @@
# limitations under the License.

# pylint: disable=line-too-long
r"""Trains a LiT model as in https://arxiv.org/abs/2111.07991
IMPORTANT NOTE: This config uses coco_captions for demonstration purposes. As of
6/17/22 neither YFCC100M nor CC12M are available in TFDS. We're working on
publishing these datasets to allow for full replication of the numbers reported
in the paper.
Published models:
https://github.com/google-research/vision_transformer#lit-models
Colab to load public LiT models:
https://colab.research.google.com/github/google-research/vision_transformer/blob/main/lit.ipynb
gs://vit_models/lit/LiT-B16B.npz - 72.07% i1k 0shot
gs://vit_models/lit/LiT-L16L.npz - 75.68% i1k 0shot - missing in publication
r"""Minimal SigLIP (https://arxiv.org/abs/2303.15343) example.
Example training:
big_vision.trainers.proj.image_text.contrastive \
--config big_vision/configs/proj/image_text/lit_coco.py \
--workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'`
Example evaluation:
big_vision.tools.eval_only \
--config big_vision/configs/proj/image_text/lit_coco.py:txt=bert_base,img_head,img=B/16,init=gs://vit_models/lit/LiT-B16B.npz \
--workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'`
big_vision.trainers.proj.image_text.siglip \
--config big_vision/configs/proj/image_text/lit_coco.py:batch_size=512 \
--workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%Y-%m-%d_%H%M'`
"""

import big_vision.configs.common as bvcc
Expand All @@ -52,14 +31,14 @@ def get_config(arg=None):
"""The base configuration."""
arg = bvcc.parse_arg(
arg, res=224, runlocal=False, token_len=16, txt='bert_base', img='B/16',
init='', img_head=False)
init='', img_head=False, batch_size=512)
img_name, img_init = common.inits[arg.img]
txt_name, txt_init = common.inits[arg.txt]
config = ConfigDict()

config.input = {}
config.input.data = dict(name='coco_captions', split='train')
config.input.batch_size = 4096 if not arg.runlocal else 32
config.input.batch_size = arg.batch_size if not arg.runlocal else 32
config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50

config.total_steps = 5_000 if not arg.runlocal else 1
Expand All @@ -78,11 +57,6 @@ def get_config(arg=None):
f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)'
f'|flatten|{tokenizer("captions/text")}|keep("image", "labels")'
)
pp_eval = (
f'decode|resize({arg.res})|value_range(-1,1)'
f'|flatten|{tokenizer("captions/text")}'
'|keep("image", "labels")'
)
config.pp_modules = [
'ops_general', 'ops_image', 'ops_text', 'proj.flaxformer.bert_ops']

Expand Down Expand Up @@ -114,6 +88,7 @@ def get_config(arg=None):
config.model.temperature_init = 10.0
dim = {'B': 768, 'L': 1024}[arg.img[0]]
config.model.out_dim = (dim if arg.img_head else None, dim) # (image_out_dim, text_out_dim)
config.model.bias_init = -2.71

if txt_name == 'base':
config.optax_name = 'scale_by_adam'
Expand All @@ -130,48 +105,11 @@ def get_config(arg=None):

config.grad_clip_norm = 1.0

# Eval section (Both few-shot and zero-shot)
eval_common = dict(
type='proj.image_text.contrastive',
use_global_batch=True,
log_steps=500 if not arg.runlocal else 5,
)
config.evals = {}
sub = '[:4]' if arg.runlocal else ''
config.evals.val = {
**eval_common,
'data': dict(name=config.input.data.name, split=f'val{sub}'),
'pp_fn': pp_eval,
}
config.evals.coco = {
**eval_common,
'data': dict(name='coco_captions', split=f'val{sub}'),
'pp_fn': (
f'decode|resize({arg.res})|value_range(-1,1)'
f'|flatten|{tokenizer("captions/text")}|keep("image", "labels")'),
}
config.evals.imagenet = {
**eval_common,
'data': dict(name='imagenet2012', split=f'validation{sub}'),
'pp_fn': (
f'decode|resize({arg.res})|value_range(-1,1)'
'|clip_i1k_label_names'
f'|{tokenizer("labels")}|keep("image", "labels")'),
}

config.evals.disclf = {}
config.evals.disclf.pp_img = f'resize({arg.res})|value_range(-1,1)'
config.evals.disclf.pp_txt = tokenizer('texts')
config.evals.disclf.type = 'proj.image_text.discriminative_classifier'
config.evals.disclf.prefix = 'z/0shot/'
config.evals.disclf.log_steps = eval_common['log_steps']
config.evals.retrieval_coco = common.get_coco(
pp_img=f'resize({arg.res})|value_range(-1, 1)',
pp_txt=tokenizer('texts'),
log_steps=config.evals.disclf.log_steps,
log_steps=1000,
)

config.seed = 0
config.l = config.m = 0

return config
4 changes: 2 additions & 2 deletions big_vision/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _shard(x):
sharding = NamedSharding(mesh, P("devices"))
local_ds = mesh.local_devices

x = np.asarray(memoryview(x)) # No-copy: http://shortn/_KM5whIEtWI
x = np.asarray(memoryview(x)) # No-copy: http://(internal link)
xs = jax.device_put(np.split(x, len(local_ds), axis=0), local_ds)

global_shape = (x.shape[0] * jax.process_count(), *x.shape[1:])
Expand All @@ -237,7 +237,7 @@ def _shard(x):


def shard_and_put(x, shard=True, put=True):
x = np.asarray(memoryview(x)) # No-copy conversion: http://shortn/_KM5whIEtWI
x = np.asarray(memoryview(x)) # No-copy conversion: http://(internal link)
if shard:
x = einops.rearrange(x, "(d l) ... -> d l ...", d=jax.local_device_count())
if shard and put: # Only works for pmap (for now).
Expand Down
20 changes: 20 additions & 0 deletions big_vision/models/proj/image_text/two_towers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def __call__(self, image, text=None, **kw):

def load(init_params, init_files, model_cfg, img_load_kw={}, txt_load_kw={}): # pylint: disable=dangerous-default-value
"""Loads both towers, `init_files` is now a dict with `img` and `txt` keys."""
if isinstance(init_files, str):
init_files = VANITY_NAMES.get(init_files, init_files)

if isinstance(init_files, str):
# A shortcut for a single file checkpoint of a two_towers model.
if "bias_init" in model_cfg.keys():
Expand Down Expand Up @@ -132,3 +135,20 @@ def load(init_params, init_files, model_cfg, img_load_kw={}, txt_load_kw={}): #
f"a typo. Here it is: {init_files}")

return restored_params


# Shortcut names for some canonical paper checkpoints:
VANITY_NAMES = {
# pylint: disable=line-too-long
# SigLIP image encoder checkpoints from https://arxiv.org/abs/2303.15343
"SigLIP B/16 224": "gs://big_vision/siglip/webli_en_b16_224_63724782.npz",
"SigLIP B/16 256": "gs://big_vision/siglip/webli_en_b16_256_60500360.npz",
"SigLIP B/16 384": "gs://big_vision/siglip/webli_en_b16_384_68578854.npz",
"SigLIP B/16 512": "gs://big_vision/siglip/webli_en_b16_512_68580893.npz",
"SigLIP L/16 256": "gs://big_vision/siglip/webli_en_l16_256_60552751.npz",
"SigLIP L/16 384": "gs://big_vision/siglip/webli_en_l16_384_63634585.npz",
"SigLIP So400m/14 224": "gs://big_vision/siglip/webli_en_so400m_224_57633886.npz",
"SigLIP So400m/14 384": "gs://big_vision/siglip/webli_en_so400m_384_58765454.npz",
"SigLIP B/16-i18n 256": "gs://big_vision/siglip/webli_i18n_b16_256_66117334.npz",
# pylint: enable=line-too-long
}
43 changes: 37 additions & 6 deletions big_vision/models/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,23 +379,45 @@ def stack(*values):
return params_scan


def scan_to_pyloop(params_scan):
"""Converts a lax.scan ViT checkpoint to a python for-loop based one."""
# See comment in pyloop_to_scan.

params_scan = jax.tree_map(lambda x: x, params_scan) # Structural copy
t = params_scan["Transformer"]

# Find out how many encoderblocks there are
depth = len(t["encoderblock"]["LayerNorm_0"]["bias"])

# Create that many encoderblocks, each with their slice of their sub-pytree.
for lyr in range(depth):
block = jax.tree_map(lambda x, lyr=lyr: x[lyr], t["encoderblock"])
t[f"encoderblock_{lyr}"] = block

del t["encoderblock"]
return params_scan


def load(init_params, init_file, model_cfg, dont_load=()): # pylint: disable=invalid-name because we had to CamelCase above.
"""Load init from checkpoint, both old model and this one. +Hi-res posemb."""
del model_cfg

init_file = VANITY_NAMES.get(init_file, init_file)
restored_params = utils.load_params(init_file)

restored_params = fix_old_checkpoints(restored_params)

if init_params and "encoderblock" in init_params["Transformer"]:
# Detect attempts to load non-scan checkpoint into scan model.
if (model_cfg.get("scan") and
"encoderblock" not in restored_params["Transformer"]):
restored_params = pyloop_to_scan(restored_params)
# TODO: detect and convert the other way around too.
if (not model_cfg.get("scan")
and "encoderblock" in restored_params["Transformer"]):
restored_params = scan_to_pyloop(restored_params)

# possibly use the random init for some of the params (such as, the head).
restored_params = common.merge_params(restored_params, init_params, dont_load)

# resample posemb if needed.
# TODO: Take this from model_cfg to avoid need for init_params.
if init_params and "pos_embedding" in init_params:
restored_params["pos_embedding"] = resample_posemb(
old=restored_params["pos_embedding"],
Expand All @@ -406,7 +428,6 @@ def load(init_params, init_file, model_cfg, dont_load=()): # pylint: disable=in

# Shortcut names for some canonical paper checkpoints:
VANITY_NAMES = {
# pylint: disable=line-too-long
# pylint: disable=line-too-long
# Recommended models from https://arxiv.org/abs/2106.10270
# Many more models at https://github.com/google-research/vision_transformer
Expand Down Expand Up @@ -437,6 +458,16 @@ def load(init_params, init_file, model_cfg, dont_load=()): # pylint: disable=in
"deit3_L_224_21k": "gs://big_vision/zoo/deit3/bv_deit_3_large_224_21k.npz",
"deit3_L_384_1k": "gs://big_vision/zoo/deit3/bv_deit_3_large_384_1k.npz",
"deit3_L_384_21k": "gs://big_vision/zoo/deit3/bv_deit_3_large_384_21k.npz",
# pylint: disable=line-too-long

# SigLIP image encoder checkpoints from https://arxiv.org/abs/2303.15343
"SigLIP B/16 224": "gs://big_vision/siglip/webli_en_b16_224_63724782.npz:img",
"SigLIP B/16 256": "gs://big_vision/siglip/webli_en_b16_256_60500360.npz:img",
"SigLIP B/16 384": "gs://big_vision/siglip/webli_en_b16_384_68578854.npz:img",
"SigLIP B/16 512": "gs://big_vision/siglip/webli_en_b16_512_68580893.npz:img",
"SigLIP L/16 256": "gs://big_vision/siglip/webli_en_l16_256_60552751.npz:img",
"SigLIP L/16 384": "gs://big_vision/siglip/webli_en_l16_384_63634585.npz:img",
"SigLIP So400m/14 224": "gs://big_vision/siglip/webli_en_so400m_224_57633886.npz:img",
"SigLIP So400m/14 384": "gs://big_vision/siglip/webli_en_so400m_384_58765454.npz:img",
"SigLIP B/16-i18n 256": "gs://big_vision/siglip/webli_i18n_b16_256_66117334.npz:img",
# pylint: enable=line-too-long
}
5 changes: 3 additions & 2 deletions big_vision/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
numpy>=1.26
absl-py
clu
git+https://github.com/google/CommonLoopUtils
einops
flax
optax
git+https://github.com/google/flaxformer
git+https://github.com/akolesnikoff/panopticapi.git@mute
overrides
tensorflow
tensorflow-cpu
tfds-nightly
tensorflow-addons
tensorflow-text
Expand Down
Loading

0 comments on commit 3b8e5ab

Please sign in to comment.