diff --git a/timm/models/levit.py b/timm/models/levit.py
index 16186cae7..577fc5f2d 100644
--- a/timm/models/levit.py
+++ b/timm/models/levit.py
@@ -763,17 +763,18 @@ def checkpoint_filter_fn(state_dict, model):
     # filter out attn biases, should not have been persistent
     state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}
 
-    D = model.state_dict()
-    out_dict = {}
-    for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
-        if va.ndim == 4 and vb.ndim == 2:
-            vb = vb[:, :, None, None]
-        if va.shape != vb.shape:
-            # head or first-conv shapes may change for fine-tune
-            assert 'head' in ka or 'stem.conv1.linear' in ka
-        out_dict[ka] = vb
-
-    return out_dict
+    # NOTE: old weight conversion code, disabled
+    # D = model.state_dict()
+    # out_dict = {}
+    # for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
+    #     if va.ndim == 4 and vb.ndim == 2:
+    #         vb = vb[:, :, None, None]
+    #     if va.shape != vb.shape:
+    #         # head or first-conv shapes may change for fine-tune
+    #         assert 'head' in ka or 'stem.conv1.linear' in ka
+    #     out_dict[ka] = vb
+
+    return state_dict
 
 
 model_cfgs = dict(
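For reference, a minimal sketch of what checkpoint_filter_fn effectively does after this patch, based only on the lines visible in this hunk (any code above line 763 is assumed unchanged, and the helper name below is mine, not timm's): the old key-by-key weight conversion is disabled, so the filter just drops the non-persistent 'attention_bias_idxs' buffers and passes every other checkpoint tensor through unmodified.

    # Illustrative sketch only, not the full timm function.
    def strip_attention_bias_idxs(state_dict):
        # Drop 'attention_bias_idxs' buffers that should not have been persistent;
        # all other tensors now pass through with no reshaping or shape checks.
        return {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}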