Skip to content

Commit

Permalink
modify vanilla_model functio to support parametrization
Browse files Browse the repository at this point in the history
  • Loading branch information
franckma31 committed Dec 9, 2024
1 parent c6bf80f commit 96e5249
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions deel/torchlip/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ def vanilla_model(model: nn.Module):
model (nn.Module): Lipschitz neural network
"""
for n, module in model.named_children():
if len(list(module.children())) > 0:
# compound module, go inside it
vanilla_model(module)

if isinstance(module, LipschitzModule):
# simple module
setattr(model, n, module.vanilla_export())
elif len(list(module.children())) > 0:
# compound module, go inside it
vanilla_model(module)



class _LipschitzCoefMultiplication(nn.Module):
Expand Down

0 comments on commit 96e5249

Please sign in to comment.