Skip to content

Commit

Permalink
style:make style
Browse files Browse the repository at this point in the history
  • Loading branch information
RUFFY-369 committed Oct 19, 2024
1 parent cf3408b commit 8438666
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
22 changes: 14 additions & 8 deletions src/transformers/models/propainter/convert_propainter_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

import argparse
import os
import re

import numpy as np
import re
import torch
from datasets import load_dataset

Expand All @@ -38,9 +38,9 @@
(r"cnet", r"context_network"),
(r"update_block", r"update_block"),
(r"module\.(fnet|cnet|update_block)", r"optical_flow_model.\1"),
(r'layer(\d+)\.(\d+)', lambda m: f"resblocks.{(int(m.group(1)) - 1) * 2 + int(m.group(2))}"),
(r'convc', 'conv_corr'),
(r'convf', 'conv_flow')
(r"layer(\d+)\.(\d+)", lambda m: f"resblocks.{(int(m.group(1)) - 1) * 2 + int(m.group(2))}"),
(r"convc", "conv_corr"),
(r"convf", "conv_flow"),
]

rename_rules_flow_completion = [
Expand All @@ -51,14 +51,20 @@
(r"decoder2", r"flow_completion_net.decoder2"),
(r"upsample", r"flow_completion_net.upsample"),
(r"mid_dilation", r"flow_completion_net.intermediate_dilation"),
(r"feat_prop_module\.deform_align\.backward_", r"flow_completion_net.feature_propagation_module.deform_align.backward_"),
(r"feat_prop_module\.deform_align\.forward_", r"flow_completion_net.feature_propagation_module.deform_align.forward_"),
(
r"feat_prop_module\.deform_align\.backward_",
r"flow_completion_net.feature_propagation_module.deform_align.backward_",
),
(
r"feat_prop_module\.deform_align\.forward_",
r"flow_completion_net.feature_propagation_module.deform_align.forward_",
),
(r"feat_prop_module\.backbone\.backward_", r"flow_completion_net.feature_propagation_module.backbone.backward_"),
(r"feat_prop_module\.backbone\.forward_", r"flow_completion_net.feature_propagation_module.backbone.forward_"),
(r"feat_prop_module\.fusion", r"flow_completion_net.feature_propagation_module.fusion"),
(r"edgeDetector\.projection", r"flow_completion_net.edgeDetector.projection"),
(r"edgeDetector\.mid_layer", r"flow_completion_net.edgeDetector.intermediate_layer"),
(r"edgeDetector\.out_layer", r"flow_completion_net.edgeDetector.out_layer")
(r"edgeDetector\.out_layer", r"flow_completion_net.edgeDetector.out_layer"),
]

rename_rules_inpaint_generator = [
Expand All @@ -68,7 +74,7 @@
(r"sc", r"inpaint_generator.soft_comp"),
(r"feat_prop_module\.", r"inpaint_generator.feature_propagation_module."),
(r"transformers\.transformer\.", r"inpaint_generator.transformers.transformer."),
(r"norm", r"layer_norm")
(r"norm", r"layer_norm"),
]


Expand Down
4 changes: 3 additions & 1 deletion src/transformers/models/propainter/modeling_propainter.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ def __init__(self, config: ProPainterConfig, output_dim: int = 128, norm_fn: str
ProPainterResidualBlock(config, in_channel, num_channels, norm_fn, stride),
ProPainterResidualBlock(config, num_channels, num_channels, norm_fn, stride=1),
]
for in_channel, num_channels, stride in zip(config.in_channels, config.channels, config.multi_level_conv_stride)
for in_channel, num_channels, stride in zip(
config.in_channels, config.channels, config.multi_level_conv_stride
)
]
# using itertools makes flattening a little faster :)
self.resblocks = nn.ModuleList(list(itertools.chain.from_iterable(self.resblocks)))
Expand Down

0 comments on commit 8438666

Please sign in to comment.