Skip to content

Commit

Permalink
Merge pull request #443 from IBM/jaionet-compatibility-fix-script
Browse files Browse the repository at this point in the history
fix model for backwards compatibility
  • Loading branch information
Joao-L-S-Almeida authored Feb 20, 2025
2 parents d388a95 + 6267453 commit 66bb0f6
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions examples/scripts/fix_backwards_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/usr/bin/env python3
import argparse
import os
import torch

"""
usage:
python3 ./scripts/fix_backwards_copatibility.py <file>
"""

homedir = os.path.expanduser('~')
cwd = os.getcwd()

# Get filename of checkpoint or model file to correct
if __name__ == '__main__':
# Parse command-line arguments
parser = argparse.ArgumentParser(description='Convert model for backwards compatibility on terratorch versions 0.99 or higher')
parser.add_argument('file',
action='store',
metavar='INPUT_FILE',
type=str,
help='Checkpoint file or model to be corrected for backwards compatibility on terratorch versions 0.99 or higher')
arg = parser.parse_args()

# Input file
path_in = arg.file
print('path in:', path_in)
path_out = (arg.file).split('.')[0]+'_Fixed.'+(arg.file).split('.')[1]
print('path out:', path_out)

state_dict = torch.load(path_in, map_location=torch.device('cpu'))
state_dict_renamed = {}

for k, v in state_dict.items():
# remove the module. part
if k == 'state_dict':
state_dict_renamed[k] = {}
for k1, v1 in v.items():
if 'model.encoder.' in k1:
state_dict_renamed[k][k1.replace('model.encoder.', 'model.encoder._timm_module.')] = v1
else:
state_dict_renamed[k][k1] = v1
else:
state_dict_renamed[k] = v


torch.save(state_dict_renamed, path_out)

0 comments on commit 66bb0f6

Please sign in to comment.