
Models with no ZIG? #57

Open · iamanigeeit opened this issue Mar 1, 2024 · 19 comments


@iamanigeeit
Contributor

Hi @tianyic ,

I am trying to use OTO on speech models (FastSpeech2) and rewrote parts of the model to make sure all the PyTorch ops are supported in ONNX.

However, I found that nothing was pruned. When I run

oto = OTO(model, dummy_input=dummy_input)
optimizers = [oto.hesso(**args.optim_conf)]

I get
hesso.total_num_groups = 0
Target redundant groups per period: [0]

Does this mean there are no zero-invariant groups in the model? This is strange, because there are conv layers in the transformer encoder/decoder. Reference code

Any help appreciated, thanks!

@tianyic
Owner

tianyic commented Mar 1, 2024

@iamanigeeit

It could be resolved. To facilitate troubleshooting, could you share the rendered pruning dependency graph or the ONNX model you edited? We actually met a similar case when employing OTO on StyleGANv2. The root cause there was that all basic modules were customized, including StyleGAN's customized linear and conv operators, which are not covered by OTO's operator list. The issue was resolved after rewriting all the customized conv and linear layers back to standard ones. But your root cause may be different.

@iamanigeeit
Contributor Author

@tianyic Thanks for the reply. By pruning dependency graph, you mean I should call oto.build_dot right?

ONNX doesn't support STFT and l1_loss, so I used the STFT outputs as the model inputs and changed l1_loss to MSE loss. Must all the operators in my model be in the OTO operator list?
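
For reference, the loss substitution was roughly like this (illustrative shapes, not the exact FastSpeech2 code):

import torch
import torch.nn.functional as F

pred, target = torch.rand(2, 80, 100), torch.rand(2, 80, 100)  # illustrative mel-spectrogram shapes
# loss = F.l1_loss(pred, target)   # did not export to ONNX in my setup
loss = F.mse_loss(pred, target)    # exports fine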

@tianyic
Owner

tianyic commented Mar 2, 2024

You can use

oto.visualize(view=False, out_dir=PATH, display_params=True)

to render the pruning dependency graph. pygraphviz or graphviz may need to be installed. Please go through the visualization section in the README.

The nn.Modules in the target DNN do not all need to be included in the OTO operator list, yet the DNN needs to be composed of the basic modules or composed modules shown in the OTO operator list. Of course, people can add new basic modules too.

@iamanigeeit
Contributor Author

iamanigeeit commented Mar 3, 2024

I have almost 9000 nodes... oto.visualize outputs an empty PDF, and I am still waiting for $ dot graph.gv/ESPnetTTSModel_pruning_dependency to finish rendering after 1 hour.

Is there a way to analyse the dependency file directly?

@tianyic
Owner

tianyic commented Mar 3, 2024

@iamanigeeit

It could be resolved. We meet similar issues when tackling new DNNs with complicated structures. Indeed, graphviz has trouble computing the graph layout once there are more than about 5000 nodes.

Please consider my following suggestions that are frequently used by us to support new DNNs.

  • Feed just one or two layers, not the whole DNN, into the sanity check and rendering. For example, we use 4 layers of LLAMA during the sanity check. This dramatically reduces the node and edge count. Once the sanity check over the smaller counterpart passes, go through the whole DNN. (A rough sketch of this is shown after the code below.)

  • Some modern DNNs are constructed by composing multiple complicated architectures together. We also prefer decomposing the whole DNN into multiple modules and feeding each module into its own OTO instance. For example, for an encoder-decoder DNN, we have conducted multiple experiments as below:

oto_encoder = OTO(model.encoder, dummy_input_encoder)
oto_encoder.visualize()
optimizer_encoder = oto_encoder.some_optimizer()

oto_decoder = OTO(model.decoder, dummy_input_decoder)
oto_decoder.visualize()
optimizer_decoder = oto_decoder.some_optimizer()
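
For the first suggestion, a rough sketch could look like the following. It assumes the encoder keeps its layers in an nn.ModuleList named encoders and that each layer takes a single tensor; the attribute names, shapes and forward signature are illustrative and need to be adjusted to your actual model.

import torch
import torch.nn as nn
from only_train_once import OTO

class SmallCounterpart(nn.Module):
    # Wrap only the first two encoder layers for a quick sanity check.
    def __init__(self, full_model):
        super().__init__()
        self.layers = nn.ModuleList(full_model.encoder.encoders[:2])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

small_model = SmallCounterpart(model)
dummy_input_small = torch.rand(2, 100, 384)  # illustrative (batch, time, feature) shape
oto_small = OTO(small_model, dummy_input=dummy_input_small)
oto_small.visualize(view=False, out_dir='./cache', display_params=True)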

@iamanigeeit
Contributor Author

iamanigeeit commented Mar 5, 2024

@tianyic Thanks for the suggestions.

I re-ran it with batch size 2 for the dummy input and got this graph.

It looks like the entire graph is one group and every node has dashed lines, meaning unprunable. What should I do?

@tianyic
Owner

tianyic commented Mar 5, 2024

@iamanigeeit

Good to know that you got the graph, which helps troubleshooting a lot. I feel this DNN (a transformer) should be smoothly supported after properly adding (at most) two operators in operator.py.

Please consider my following suggestions.

  • Double check that all variables are trainable (requires_grad=True) before feeding the model into OTO. For example, in the pretrained Yolov5, all variables are initialized as non-trainable, so there are no prunable node groups at all if an OTO instance is created upon it. (A quick check for this is sketched after this list.)

  • One operator needs to be added, the conv, which could be prunable. Please check the DNN's module list and add the corresponding module into the basic module list. Since this operator looks similar to linear, you could try directly mapping it to LinearOTO at first.

BASIC_MODULES = {
    'ConvTranspose2d': ConvTranspose2dOTO,
    'Conv2d': Conv2dOTO,
    'ModulatedConv2d': Conv2dOTO, # For StyleGANv2
    'EqualLinear': LinearOTO, # For StyleGANv2
    'Linear': LinearOTO,
    'BatchNorm2d': BatchNormOTO,
    'InstanceNorm2d': InstanceNormOTO,
    'GroupNorm': GroupNormOTO,
    'Embedding': EmbeddingOTO,
    
    'LlamaRMSNorm': LayerNormOTO,
    'LayerNorm': LayerNormOTO,
    
    'PReLU': PReLUOTO,
     # Add the module for conv here; key is the module class name (case-sensitive), value is the operator in OTO
     'moduleForConv': LinearOTO
}
  • Similarly for the attention layer, add the (module name, operator) pair into the composed module list. You could try BaseMultiHeadAttentionOTO at first.
COMPOSED_MODULES = {
    'LlamaAttention': LlamaAttentionOTO,
    'SelfAttention': BaseMultiHeadAttentionOTO,
    'BertAttention': BertAttentionOTO,
    'PhiMHA': PhiAttentionOTO,
    'LoraLinear': LoraLinearOTO,
    'LoraEmbedding': LoraEmbeddingOTO,

    # Multihead-Attention for TTS
    'dummyTTSAttention': BaseMultiHeadAttentionOTO
}
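For the first suggestion above, a quick check before creating the OTO instance could be a minimal sketch like this (not part of OTO):

# Confirm every parameter is trainable before feeding the model into OTO.
frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
if frozen:
    print(f"{len(frozen)} non-trainable parameters, e.g. {frozen[:5]}")
    for p in model.parameters():
        p.requires_grad_(True)  # only if training all parameters is intended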

Please share the re-rendered pruning dependency graph after the above steps, or let me know if you have questions.

Lastly, OTO requires the DNN to be composed of the modules listed in operator.py. If the basic nn.Modules in the DNN are not covered yet, we need to add them into the operator list.

@tianyic
Owner

tianyic commented Mar 5, 2024

I would expect that after completing the above suggestions, the pruning dependency graph would have multiple node groups filled with solid color, i.e., prunable.

@iamanigeeit
Contributor Author

iamanigeeit commented Mar 14, 2024

@tianyic Thank you for the suggestions! I can see different colours on the graph now.
ESPnetTTSModel_pruning_dependency.pdf

However, I am getting repeated parameters:

super(HESSO, self).__init__(params, defaults)
...
  File "/home/perry/miniconda3/envs/s5/lib/python3.10/site-packages/torch/optim/optimizer.py", line 893, in add_param_group
ValueError: some parameters appear in more than one parameter group

There are 3 categories of duplicate params: (1) in conv layers, (2) in norm layers and (3) attention linear_out.

(1) Repeated conv params:

tts.pitch_predictor.conv.0.2.weight
tts.pitch_predictor.conv.0.2.bias
tts.pitch_predictor.conv.0.2.weight
tts.pitch_predictor.conv.0.2.bias
tts.pitch_predictor.conv.0.0.weight
tts.pitch_predictor.conv.0.0.bias
...

Does this have to do with setting Conv1d => LinearOTO? Maybe I should adapt Conv2d to Conv1d?

(2) Repeated LayerNorm params:

tts.decoder.encoders.0.norm1.weight
tts.decoder.encoders.0.norm1.bias
tts.decoder.encoders.0.norm1.weight
tts.decoder.encoders.0.norm1.bias

(3) Repeated MultiHeadAttention params (only in linear_out, not in q, k, v)

tts.decoder.encoders.0.self_attn.linear_out.weight
tts.decoder.encoders.0.self_attn.linear_out.bias
tts.decoder.encoders.0.self_attn.linear_out.weight
tts.decoder.encoders.0.self_attn.linear_out.bias
...

@iamanigeeit
Contributor Author

I have skipped the duplicate params by changing
https://github.com/tianyic/only_train_once/blob/466aa9d31c19786d8a7aa10701a7b87f655931c0/only_train_once/optimizer/hesso.py#L62

to

        groups_to_add = []
        existing_params = set()
        for param_group in params:
            p_names = []
            ps = []
            op_names = []
            p_transform = []
            for name, p, op, transform in zip(
                    param_group['p_names'], param_group['params'],
                    param_group['op_names'], param_group['p_transform']
            ):
                if name in existing_params:
                    print(f'Found duplicate param: {name}, ignoring')
                else:
                    p_names.append(name)
                    ps.append(p)
                    op_names.append(op)
                    p_transform.append(transform)
                    existing_params.add(name)
            if p_names:
                group_to_add = dict(param_group)
                group_to_add['p_names'] = p_names
                group_to_add['params'] = ps
                group_to_add['op_names'] = op_names
                group_to_add['p_transform'] = p_transform
                groups_to_add.append(group_to_add)
        super(HESSO, self).__init__(groups_to_add, defaults)

This does make OTO run, but I don't know if it's correct or not.

@tianyic
Owner

tianyic commented Mar 14, 2024

@iamanigeeit

Glad to hear that you have got some node groups as prunable. Please consider my following suggestions or comments.

  • You are right. It is better to adapt Conv2d to Conv1d. When doing so, to avoid unnecessary difficulties, please discard the lines related to grouped conv2d and just treat it as groups=1. To validate whether the operator works, I highly recommend creating a demonet consisting of conv1d, then proceeding with a sanity check, e.g., test_convtranspose_in_case2 and demonet_convtranspose_in_case2. (A rough sketch of such a demonet follows after this list.)

  • Regarding shared parameters, there exist several ways to tackle weight-shared cases, yet I need more information to provide the most appropriate solution. Based on the pruning dependency graph, I only find a unique instance of each presented parameter name. For example, there is only one tts.pitch_predictor.conv.0.2.weight in ESPnetTTSModel_pruning_dependency.pdf. If so, I would expect the parameter groups to have disjoint parameters; otherwise, one parameter should appear in multiple node groups. Could you provide more information regarding how the OTO instance is created and where the weight sharing appears?
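
For the demonet, something along these lines would be enough (a minimal sketch; it assumes the Conv1d module has already been mapped to an operator in BASIC_MODULES):

import torch
import torch.nn as nn
from only_train_once import OTO

class DemoNetConv1d(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(32, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv2(torch.relu(self.conv1(x)))

demo_model = DemoNetConv1d()
dummy_input = torch.rand(2, 16, 100)  # (batch, channels, time)
oto_demo = OTO(demo_model, dummy_input=dummy_input)
oto_demo.visualize(view=False, out_dir='./cache', display_params=True)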

@iamanigeeit
Contributor Author

iamanigeeit commented Mar 15, 2024

@tianyic

I simply load the text-to-speech model and pass the dummy input to OTO:

from only_train_once import OTO
import pickle
dummy_input_path = args.optim_conf.pop("dummy_input_path")
with open(dummy_input_path, "rb") as f:
    dummy_input = pickle.load(f)
for name, tensor in dummy_input.items():
    dummy_input[name] = tensor[0:2].to("cuda")  # let batch size = 2 to make it work
oto = OTO(model, dummy_input=dummy_input)
optimizers = [oto.hesso(**args.optim_conf)]

I have added at https://github.com/tianyic/only_train_once/blob/466aa9d31c19786d8a7aa10701a7b87f655931c0/only_train_once/__init__.py#L46-L49

            self.visualize(view=False, out_dir="exp/fs2_oto/graph_3.gv", display_params=True)
            params = self._graph.get_param_groups()
            import pickle
            with open('params.pkl', 'wb') as f:
                pickle.dump(list(params), f)
            exit()

Graph: ESPnetTTSModel_pruning_dependency.pdf
Params: params_no_tensor.txt (I have removed the tensors to make the file smaller, it should just work with pickle.load)

The duplicate p_names are not reflected in the graph.

I am leaving on a trip soon but will add Conv1d and try again when I get back. Thanks for being so helpful!

@tianyic
Owner

tianyic commented Mar 16, 2024

@iamanigeeit

Take your time and enjoy your trip.

Regarding conv1d operator support, that sounds great. I would appreciate it if you could create a pull request to add that operator into the repo once you make it work. Thanks!

Regarding weight sharing, due to a string decoding issue, params_no_tensor.txt shows some noisy strings on my end that prevent me from reading it clearly.

Anyway, if the weight sharing is caused by one nn.Module being called multiple times during the forward pass, we typically use OTO on the specific sub-modules rather than the full one. Meanwhile, it is better to ensure that the optimizers' parameter lists together cover the full model's parameters.

Here is one rough example that may help

class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = XX
        self.decoder = XX

    def forward(self, x):
        # encoder is called twice in the forward pass, i.e., its weights are shared
        return self.decoder(self.encoder(self.encoder(x)))

oto_encoder = OTO(model.encoder, dummy_input_encoder)
optimizer_encoder = oto_encoder.hesso()

oto_decoder = OTO(model.decoder, dummy_input_decoder)
optimizer_decoder = oto_decoder.hesso()

# optimizer_encoder and optimizer_decoder cover all parameters in FullModel.

# Training as normal but via two optimizers
optimizer_encoder.step()
optimizer_decoder.step()

# After training
oto_encoder.construct_subnet()
oto_decoder.construct_subnet()
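
To double check the coverage mentioned above, a rough sanity check could be:

# Verify that optimizer_encoder and optimizer_decoder together cover every
# parameter of the full model.
covered = set()
for opt in (optimizer_encoder, optimizer_decoder):
    for group in opt.param_groups:
        covered.update(id(p) for p in group['params'])
missing = [name for name, p in model.named_parameters() if id(p) not in covered]
print('Parameters not covered by any optimizer:', missing)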

@iamanigeeit
Contributor Author

@tianyic

I don't understand what you mean by

please discard the lines related to grouped conv2d and just treat it as groups=1.

Do you mean I should do this?

    def set_num_groups(self):
        self.num_groups = 1

But this looks correct because param.shape[0] is the number of filters for 1-D convolution.

        for param_name in self.name_to_param:
            param = self.name_to_param[param_name]
            self.num_groups = max(self.num_groups, param.shape[0])

Actually I don't think there is any difference between Conv1dOTO and Conv2dOTO except for compute_flops.

@tianyic
Owner

tianyic commented Apr 1, 2024

@iamanigeeit

We are on the same page, yet may be referring to different groups. I am referring to grouped conv; see the groups argument of torch's Conv2d. The num_groups in Conv2dOTO is the number of ZIGs.
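
To illustrate the distinction with plain torch (nothing OTO-specific):

import torch.nn as nn

# Grouped convolution: the groups argument splits the channels into
# independent groups (here 4). These are the lines I suggested discarding
# when adapting Conv2dOTO to Conv1d.
grouped_conv = nn.Conv1d(32, 64, kernel_size=3, groups=4)

# Standard convolution, i.e. groups=1, which is the case to support first.
plain_conv = nn.Conv1d(32, 64, kernel_size=3)

# num_groups in Conv2dOTO is unrelated: it counts zero-invariant groups (ZIGs),
# which for a standard conv corresponds to the number of output filters,
# e.g. plain_conv.weight.shape[0] == 64 here.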

@iamanigeeit
Contributor Author

@tianyic

I've implemented Conv1D / Conv3D as an extension of Conv2D, but it looks like the problem is with LayerNorm and MultiHeadAttention.

The repeated params are:

# LayerNorm
tts.pitch_predictor.conv.{0,1,2,3,4}.2.{weight,bias}
tts.energy_predictor.conv.{0,1}.2.{weight,bias}
tts.duration_predictor.conv.{0,1}.2.{weight,bias}
tts.{encoder,decoder}.encoders.{0,1,2,3}.norm{1,2}.{weight,bias}
tts.{encoder,decoder}.after_norm.{weight,bias}

# MultiHeadAttention output linear layer
tts.{encoder,decoder}.encoders.{0,1,2,3}.self_attn.linear_out.{weight,bias}

I am not sure about the LayerNorm and am debugging through the graph init to see what's happening. However, for MultiHeadAttention, I think I need to subclass it because there is a final linear layer. But I don't know what to do there.

Regarding weight sharing, due to a string decoding issue, params_no_tensor.txt shows some noisy strings on my end that prevent me from reading it clearly.

I didn't explain properly... that was a pickle file renamed to .txt, as GitHub does not allow .pkl attachments. Here it is in zip form (load it with pickle.load):
params_no_tensor.pkl.zip

Minor issue:

https://github.com/tianyic/only_train_once/blob/786a2033a5f0730e8298bba5a4e1e6c09777b687/only_train_once/graph/graph.py#L852-L853

I have nodes where node.input_shape = [] and it causes an error at node.input_shape[0]. Can I just skip it?

@tianyic
Owner

tianyic commented Apr 23, 2024

@iamanigeeit

Happy to hear that you implemented Conv1d and Conv3d, that is great. I will look into your case after completing some business matters, and loop back soon.

@tianyic
Owner

tianyic commented Apr 23, 2024

I see.

Regarding this weight sharing case, I suggest directly overriding the modules that have the repeated params. For example,

given repeated_weight and repeated_bias (the shared tensors):

model.encoder.encoders.0.self_attn.linear_out = nn.Linear(**kwargs)
model.encoder.encoders.0.self_attn.linear_out.weight.data.copy_(repeated_weight.data)
model.encoder.encoders.0.self_attn.linear_out.bias.data.copy_(repeated_bias.data)

model.encoder.encoders.1.self_attn.linear_out = nn.Linear(**kwargs)
model.encoder.encoders.1.self_attn.linear_out.weight.data.copy_(repeated_weight.data)
model.encoder.encoders.1.self_attn.linear_out.bias.data.copy_(repeated_bias.data)

Afterwards, the issue could be resolved, though the parameter sizes might slightly increase before pruning.

Another way to resolve it is to prune the key, query, value matrices across the different layers all together, so that the shared linear_out keeps a consistent shape. But this may introduce more engineering work, at least to manipulate the param_groups in the optimizer.

Regarding FLOPs, yes, please skip that, compute_flops is optional.

@iamanigeeit
Contributor Author

@tianyic

Sorry I can't understand what you mean... should I just copy out all the module data before the graph init and put them back?

Also, apologies for the confusion -- there are no shared weights, only repeated module classes. The encoder and decoder are the same architecture (TransformerEncoder) but instantiated separately.

encoder = TransformerEncoder(
                idim=idim,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=eunits,
                num_blocks=elayers,
                input_layer=encoder_input_layer,
                dropout_rate=transformer_enc_dropout_rate,
                positional_dropout_rate=transformer_enc_positional_dropout_rate,
                attention_dropout_rate=transformer_enc_attn_dropout_rate,
                pos_enc_class=pos_enc_class,
                normalize_before=encoder_normalize_before,
                concat_after=encoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            )
decoder = TransformerEncoder(
                idim=0,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=dunits,
                num_blocks=dlayers,
                input_layer=None,
                dropout_rate=transformer_dec_dropout_rate,
                positional_dropout_rate=transformer_dec_positional_dropout_rate,
                attention_dropout_rate=transformer_dec_attn_dropout_rate,
                pos_enc_class=pos_enc_class,
                normalize_before=decoder_normalize_before,
                concat_after=decoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            )

I see that you have LlamaAttention and BertAttention, so maybe I also have to create my own operator?

PS: I emailed you, maybe we can discuss on Teams?
