Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support regular expression in the mapping arg of copy_model_state #6917

Merged
merged 11 commits into from
Sep 12, 2023

Conversation

KumoLiu
Copy link
Contributor

@KumoLiu KumoLiu commented Aug 31, 2023

Part of #6552.

Description

After PR #6835, we have added copy_model_args in the load API which can help us update the state_dict flexibly.
https://github.com/KumoLiu/MONAI/blob/93a149a611b66153cf804b31a7b36a939e2e593a/monai/bundle/scripts.py#L397

Given this issue, we need to be able to filter the model's weights flexibly.
In copy_model_state, we already have a "mapping" arg, the filter will be more flexible if we can support regular expression in the mapping. This PR mainly added the support for regular expression for "mapping" arg.

In the example in this issue, after this PR, we can do something like:

exclude_vars = "encoder.mask_token|encoder.norm.weight|encoder.norm.bias|out.conv.conv.weight|out.conv.conv.bias"
mapping={"encoder.layers(.*).0.0.": "swinViT.layers(.*).0."}
dst_dict, updated_keys, unchanged_keys = copy_model_state(
       model, ssl_weights, exclude_vars=exclude_vars, mapping=mapping
)

Additionally, based on the comments of Eric here, I totally agree, we could add a handler to make the pipeline easier to implement, but perhaps this task is no need to set as a "BundleTodo" for MONAIv1.3 but as an enhancement for MONAI near future.
What do you think? @ericspod @wyli @Nic-Ma

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

KumoLiu and others added 2 commits August 31, 2023 14:29
Signed-off-by: KumoLiu <yunl@nvidia.com>
monai/networks/utils.py Outdated Show resolved Hide resolved
@KumoLiu KumoLiu requested review from ericspod and Nic-Ma August 31, 2023 10:31
@KumoLiu
Copy link
Contributor Author

KumoLiu commented Sep 8, 2023

According to Wenqi's comments and #6552 (comment), we can leave mapping doesn't support regular expression here, but add a filter_func in copy_model_state which defaults to None.
For filter_func it doesn't matter how it is implemented inside, the important thing is just to return a pair of keys and values that will need to be filtered, and return None otherwise.

Since copy_model_state is also used in CheckpointLoader, we can also easily pass our filter_func into it to help filter weights, researchers can also provide some common-used filters for users to use.

checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0]

def filter_swinunetr(k, v):
    if k in [
        "out.conv.conv.weight",
        "out.conv.conv.bias",
    ]:
        return None

    if k[:8] == "encoder.":
        if k[8:19] == "patch_embed":
            new_key = "swinViT." + k[8:]
        else:
            new_key = "swinViT." + k[8:18] + k[20:]

        return new_key, v
    else:
        return None

dst_dict, updated_keys, unchanged_keys = copy_model_state(
       model, ssl_weights, filter_func=filter_swinunetr
)

If you think this works, I can continue with this PR. Thanks.

This reverts commit ed0ea9b.

Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: KumoLiu <yunl@nvidia.com>
@KumoLiu KumoLiu marked this pull request as ready for review September 12, 2023 06:30
@KumoLiu KumoLiu requested a review from wyli September 12, 2023 06:30
Signed-off-by: KumoLiu <yunl@nvidia.com>
Copy link
Contributor

@wyli wyli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good to me, the test case show the usage,cc @ericspod @Nic-Ma @vikashg please share any further comments

Signed-off-by: KumoLiu <yunl@nvidia.com>
monai/networks/nets/swin_unetr.py Show resolved Hide resolved
@wyli
Copy link
Contributor

wyli commented Sep 12, 2023

/build

@wyli wyli enabled auto-merge (squash) September 12, 2023 19:26
@wyli wyli merged commit 392c5c1 into Project-MONAI:dev Sep 12, 2023
@KumoLiu KumoLiu deleted the update-copy-model-state branch September 13, 2023 03:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants