Merge pull request #417 from IBM/fix/check/load
Adjusting the weights keys when necessary
romeokienzler authored Feb 21, 2025
2 parents 66bb0f6 + 930f195 commit a5d7a9d
Showing 6 changed files with 154 additions and 12 deletions.
5 changes: 5 additions & 0 deletions terratorch/cli_tools.py
@@ -64,6 +64,8 @@

logger = logging.getLogger("terratorch")

from terratorch.utils import remove_unexpected_prefix

def flatten(list_of_lists):
    return list(itertools.chain.from_iterable(list_of_lists))

@@ -496,6 +498,9 @@ def __init__(
        weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        if "state_dict" in weights:
            weights = weights["state_dict"]
        # Remove a residual (timm-related) prefix left by older checkpoints.
        weights = remove_unexpected_prefix(weights)
        weights = {k.replace("model.", ""): v for k, v in weights.items() if k.startswith("model.")}
        self.model.model.load_state_dict(weights)

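For reference, the last two lines of this hunk are what map a Lightning task checkpoint back onto the bare model: the task saves weights under a leading `model.` prefix, which is stripped before calling `load_state_dict` (the timm-era `_timm_module` segment is handled by `remove_unexpected_prefix`, added in terratorch/utils.py further down). A minimal sketch of the idea, using invented keys:

```python
import torch

# Hypothetical keys, shaped like what a Lightning task would save.
weights = {
    "model.encoder.stem.weight": torch.zeros(1),
    "model.decoder.head.weight": torch.zeros(1),
    "criterion.weight": torch.zeros(1),  # not a model weight, dropped below
}

weights = {k.replace("model.", ""): v for k, v in weights.items() if k.startswith("model.")}
print(sorted(weights))  # ['decoder.head.weight', 'encoder.stem.weight']
```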
36 changes: 35 additions & 1 deletion terratorch/models/backbones/prithvi_swin.py
@@ -97,6 +97,21 @@ def weights_are_swin_implementation(state_dict: dict[str, torch.Tensor]):
            return True
    return False

# Identify the prefix, if any, prepended to the checkpoint keys by
# comparing them with the keys of the model's own state dict.
def identify_prefix(state_dict, model):

    state_dict_ = model.state_dict()

    prefix = list(state_dict.keys())[0].replace(list(state_dict_.keys())[0], "")

    return prefix

# Replace "_" with "." in the stage prefix when necessary.
def adapt_prefix(key):
    if key.startswith("stages_"):
        key = key.replace("stages_", "stages.")
    return key
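As a rough illustration of what these two helpers do, with invented key names (`identify_prefix` simply diffs the first key of the checkpoint against the first key of the model, so it assumes both state dicts are consistently ordered):

```python
# Hypothetical first keys of each state dict: the checkpoint carries an
# extra "encoder." prefix in front of the model's own key.
ckpt_first = "encoder.patch_embed.proj.weight"
model_first = "patch_embed.proj.weight"
prefix = ckpt_first.replace(model_first, "")  # -> "encoder."  (what identify_prefix returns)

# In the loop below, that prefix is stripped from each bias-table key and the
# "stages_" separator is normalized to "stages." (what adapt_prefix does):
table_key = "encoder.stages_0.blocks.0.attn.relative_position_bias_table"
table_key_ = table_key.replace(prefix, "")   # -> "stages_0.blocks.0.attn.relative_position_bias_table"
table_key_ = table_key_.replace("stages_", "stages.")
# -> "stages.0.blocks.0.attn.relative_position_bias_table"
```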

def checkpoint_filter_fn(state_dict: dict[str, torch.Tensor], model: torch.nn.Module, pretrained_bands, model_bands):
"""convert patch embedding weight from manual patchify + linear proj to conv"""
@@ -134,9 +149,27 @@ def checkpoint_filter_fn(state_dict: dict[str, torch.Tensor], model: torch.nn.Mo
            state_dict[k] = v

    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]

    # Sometimes the checkpoint keys carry an unexpected prefix that must be removed.
    prefix = identify_prefix(state_dict, model)

    for table_key in relative_position_bias_table_keys:

        # The checkpoints can sometimes contain unexpected prefixes.
        # TODO Guarantee that it will not happen in the future.
        if prefix:
            table_key_ = table_key.replace(prefix, "")
        else:
            table_key_ = table_key

        # Unexpectedly, the separator after "stages" can be either "_" or ".";
        # we enforce ".". TODO Standardize it.
        table_key_ = adapt_prefix(table_key_)

        table_pretrained = state_dict[table_key]
-        table_current = model.state_dict()[table_key]
+        table_current = model.state_dict()[table_key_]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
@@ -190,6 +223,7 @@ def _create_swin_mmseg_transformer(
    def checkpoint_filter_wrapper_fn(state_dict, model):
        return checkpoint_filter_fn(state_dict, model, pretrained_bands, model_bands)

    # TODO Totally remove the usage of timm for Swin in the future.
    # When the pretrained configuration is not available in HF, we shift to
    # pretrained=False
    try:
5 changes: 3 additions & 2 deletions terratorch/models/backbones/prithvi_vit.py
@@ -117,7 +117,7 @@ def checkpoint_filter_fn_vit(

    state_dict = clean_dict

-    state_dict = select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands)
+    state_dict = select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands, encoder_only=True)

    return state_dict

@@ -153,7 +153,7 @@ def checkpoint_filter_fn_mae(

    state_dict = clean_dict

-    state_dict = select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands)
+    state_dict = select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands, encoder_only=False)

    return state_dict

@@ -214,6 +214,7 @@ def _create_prithvi(
        # Load model from checkpoint
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        state_dict = checkpoint_filter_wrapper_fn(state_dict, model, pretrained_bands, model_bands)

        loaded_keys = model.load_state_dict(state_dict, strict=False)
        if loaded_keys.missing_keys:
            logger.warning(f"Missing keys in ckpt_path {ckpt_path}: {loaded_keys.missing_keys}")
89 changes: 82 additions & 7 deletions terratorch/models/backbones/select_patch_embed_weights.py
@@ -18,9 +18,74 @@ def patch_embed_weights_are_compatible(model_patch_embed: torch.Tensor, checkpoi
    checkpoint_shape = [checkpoint_patch_embed.shape[i] for i in range(len(checkpoint_patch_embed.shape)) if i != 1]
    return model_shape == checkpoint_shape

def get_state_dict(state_dict):

    def search_state_dict(keys):
        key = 0
        for k in keys:
            if k.endswith("state_dict"):
                key = k
                break
        return key

    state_dict_key = search_state_dict(state_dict.keys())

    if state_dict_key:
        return state_dict[state_dict_key]
    else:
        return state_dict

def get_common_prefix(keys):

    keys_big_list = []

    keys = list(keys)
    keys.pop(-1)

    for k in keys:
        keys_big_list.append(set(k.split(".")))
    prefix_list = set.intersection(*keys_big_list)

    if len(prefix_list) > 1:
        prefix = ".".join(prefix_list)
    else:
        prefix = prefix_list.pop()

    return prefix + "."

def get_proj_key(state_dict, encoder_only=True, return_prefix=False):

    proj_key = None

    for key in state_dict.keys():
        if key.endswith('patch_embed.proj.weight') or key.endswith('patch_embed.projection.weight'):
            proj_key = key
            break

    if return_prefix and proj_key:
        if encoder_only:
            for sufix in ['patch_embed.proj.weight', 'patch_embed.projection.weight']:
                if proj_key.endswith(sufix):
                    prefix = proj_key.replace(sufix, "")
                    break
        else:
            prefix = get_common_prefix(state_dict.keys())
    else:
        prefix = None

    return proj_key, prefix

def remove_prefixes(state_dict, prefix):
    new_state_dict = {}
    for k, v in state_dict.items():
        new_state_dict[k.replace(prefix, "")] = v
    return new_state_dict
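A rough illustration of how `get_proj_key` and `remove_prefixes` are meant to interact, using made-up keys (not taken from a real checkpoint):

```python
# Hypothetical checkpoint where every key sits under a "model.encoder." prefix.
state_dict = {
    "model.encoder.patch_embed.proj.weight": None,
    "model.encoder.blocks.0.attn.qkv.weight": None,
    "model.encoder.norm.weight": None,
}

# encoder_only=True: the prefix is whatever precedes the patch-embed key.
proj_key, prefix = get_proj_key(state_dict, encoder_only=True, return_prefix=True)
# proj_key == "model.encoder.patch_embed.proj.weight"
# prefix   == "model.encoder."

# remove_prefixes then rewrites every key so it matches the bare model.
print(sorted(remove_prefixes(state_dict, prefix)))
# ['blocks.0.attn.qkv.weight', 'norm.weight', 'patch_embed.proj.weight']
```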

def select_patch_embed_weights(
-    state_dict: dict, model: nn.Module, pretrained_bands: list[HLSBands | int | OpticalBands| SARBands], model_bands: list[HLSBands | int | OpticalBands| SARBands], proj_key: str | None = None
-) -> dict:
+    state_dict: dict, model: nn.Module, pretrained_bands: list[HLSBands | int | OpticalBands| SARBands], model_bands: list[HLSBands | int | OpticalBands| SARBands],
+    proj_key: str | None = None, encoder_only:bool=True) -> dict:

    """Filter out the patch embedding weights according to the bands being used.
    If a band exists in the pretrained_bands, but not in model_bands, drop it.
    If a band exists in model_bands, but not pretrained_bands, randomly initialize those weights.
@@ -38,18 +103,25 @@ def select_patch_embed_weights(
"""
if (type(pretrained_bands) == type(model_bands)) | (type(pretrained_bands) == int) | (type(model_bands) == int):

state_dict = get_state_dict(state_dict)
prefix = None # we expect no prefix will be necessary in principle

if proj_key is None:
# Search for patch embedding weight in state dict
for key in state_dict.keys():
if key.endswith('patch_embed.proj.weight') or key.endswith('patch_embed.projection.weight'):
proj_key = key
break
proj_key, prefix = get_proj_key(state_dict, return_prefix=True, encoder_only=encoder_only)
if proj_key is None or proj_key not in state_dict:
raise Exception("Could not find key for patch embed weight in state_dict.")

patch_embed_weight = state_dict[proj_key]

temp_weight = model.state_dict()[proj_key].clone()
# It seems `proj_key` can have different names for
# the checkpoint and the model instance
proj_key_, _ = get_proj_key(model.state_dict(), encoder_only=encoder_only)

if proj_key_:
temp_weight = model.state_dict()[proj_key_].clone()
else:
temp_weight = model.state_dict()[proj_key].clone()

# only do this if the patch size and tubelet size match. If not, start with random weights
if patch_embed_weights_are_compatible(temp_weight, patch_embed_weight):
@@ -68,4 +140,7 @@

        state_dict[proj_key] = temp_weight

        if prefix:
            state_dict = remove_prefixes(state_dict, prefix)

    return state_dict
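When `encoder_only=False` (as the MAE filter in prithvi_vit.py passes), the prefix is instead inferred by `get_common_prefix`, which keeps the dotted component shared by all keys except the last one. A small, invented example of that path:

```python
# Hypothetical MAE-style checkpoint where everything sits under one wrapper module.
keys = [
    "mae.encoder.patch_embed.proj.weight",
    "mae.encoder.norm.bias",
    "mae.decoder.blocks.0.scale",
    "mae.mask_token",  # the last key is dropped before the intersection is taken
]
print(get_common_prefix(keys))  # -> "mae."
```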
15 changes: 13 additions & 2 deletions terratorch/registry/timm_registry.py
@@ -5,9 +5,9 @@
from torch import nn

from terratorch.registry import BACKBONE_REGISTRY
from terratorch.utils import remove_unexpected_prefix


-class TimmBackboneWrapper(nn.Module):
+class TimmBackboneWrapper_(nn.Module):
    def __init__(self, timm_module: nn.Module) -> None:
        super().__init__()
        self._timm_module = timm_module
@@ -22,6 +22,17 @@ def out_channels(self):
    def forward(self, *args, **kwargs) -> list[torch.Tensor]:
        return self._timm_module(*args, **kwargs)

class TimmBackboneWrapper(nn.Module):
    def __init__(self, timm_module: nn.Module) -> None:
        super().__init__()
        self._modules.update(timm_module._modules)
        self._out_channels = timm_module.feature_info.channels()
        # for backwards compatibility with versions before necks
        self.prepare_features_for_image_model = getattr(timm_module, "prepare_features_for_image_model", lambda x: x)
        self.forward = timm_module.forward

    @property
    def out_channels(self):
        return self._out_channels

class TimmRegistry(Set):
"""Registry wrapper for timm"""
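The state-dict consequence of this rewrite is what motivates the key adjustments in this PR: the old wrapper registered the timm model as a `_timm_module` child, so every checkpoint key gained a `_timm_module.` segment, while the new wrapper adopts the timm submodules directly and keeps the original key names. A toy sketch (the backbone below is a stand-in, not a real timm model):

```python
from torch import nn

class ToyBackbone(nn.Module):  # stand-in for a timm backbone
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, 3)

old_style = nn.Module()
old_style._timm_module = ToyBackbone()             # how the old wrapper registered it
print(list(old_style.state_dict()))                # ['_timm_module.stem.weight', '_timm_module.stem.bias']

new_style = nn.Module()
new_style._modules.update(ToyBackbone()._modules)  # how the new wrapper adopts submodules
print(list(new_style.state_dict()))                # ['stem.weight', 'stem.bias']
```

Older checkpoints saved with the `_timm_module.` keys are the ones `remove_unexpected_prefix` (below) rewrites on load.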
16 changes: 16 additions & 0 deletions terratorch/utils.py
@@ -68,3 +68,19 @@ def compute_float_mask_statistics(dataloader: DataLoader) -> dict[str, float]:
    variance = sum_squared / n_data
    std = math.sqrt(variance)
    return {"mean": mean, "std": std}

# TODO remove this in a future release
def remove_unexpected_prefix(state_dict):
    state_dict_ = {}
    for k, v in state_dict.items():
        keys = k.split(".")
        if "_timm_module" in keys:
            index = keys.index("_timm_module")
            keys.pop(index)
            k_ = ".".join(keys)
        else:
            k_ = k
        state_dict_[k_] = v
    return state_dict_
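A quick sanity check of the helper on an invented key:

```python
ckpt = {"encoder._timm_module.stages.0.downsample.weight": 0}
print(remove_unexpected_prefix(ckpt))
# {'encoder.stages.0.downsample.weight': 0}
```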

