Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support regular expression in the mapping arg of copy_model_state #6917

Merged
merged 11 commits on
Sep 12, 2023
12 changes: 7 additions & 5 deletions monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,11 +531,13 @@ def copy_model_state(
updated_keys.append(dst_key)
for s in mapping if mapping else {}:
dst_key = f"{dst_prefix}{mapping[s]}"
if dst_key in dst_dict and dst_key not in to_skip:
if dst_dict[dst_key].shape != src_dict[s].shape:
warnings.warn(f"Param. shape changed from {dst_dict[dst_key].shape} to {src_dict[s].shape}.")
dst_dict[dst_key] = src_dict[s]
updated_keys.append(dst_key)
src_keys = sorted({s_key for s_key in src_dict if s_key not in to_skip and re.compile(s).search(s_key)})
dst_keys = sorted({d_key for d_key in dst_dict if d_key not in to_skip and re.compile(dst_key).search(d_key)})
for _src_key, _dst_key in zip(src_keys, dst_keys):
wyli marked this conversation as resolved.
Show resolved Hide resolved
if dst_dict[_dst_key].shape != src_dict[_src_key].shape:
warnings.warn(f"Param. shape changed from {dst_dict[_dst_key].shape} to {src_dict[_src_key].shape}.")
dst_dict[_dst_key] = src_dict[_src_key]
updated_keys.append(_dst_key)

updated_keys = sorted(set(updated_keys))
unchanged_keys = sorted(set(all_keys).difference(updated_keys))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_copy_model_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_set_map_across(self, device_0, device_1):
model_two.to(device_1)
# test weight map
model_dict, ch, unch = copy_model_state(
model_one, model_two, mapping={"layer_1.weight": "layer.weight", "layer_1.bias": "layer_1.weight"}
model_one, model_two, mapping={"layer_1.weight": "^layer.weight", "layer_1.bias": "layer_1.weight"}
)
model_one.load_state_dict(model_dict)
x = np.random.randn(4, 10)
Expand Down