diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py
index 349279de6027..a4184a2c3968 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py
@@ -369,9 +369,16 @@ def _add_adapters_to_component(self, component, layer_cfg, adapter_name_keys):
                             in_features=layer_cfg.ffn_hidden_size // layer_cfg.tensor_model_parallel_size
                         )
                     else:
-                        cfg = InfusedAdapterConfig(
-                            in_features=layer_cfg.hidden_size // layer_cfg.tensor_model_parallel_size
-                        )
+                        if layer_cfg.get('kv_channels', None):
+                            cfg = InfusedAdapterConfig(
+                                in_features=layer_cfg.kv_channels
+                                * layer_cfg.num_attention_heads
+                                // layer_cfg.tensor_model_parallel_size
+                            )
+                        else:
+                            cfg = InfusedAdapterConfig(
+                                in_features=layer_cfg.hidden_size // layer_cfg.tensor_model_parallel_size
+                            )
                     module.add_adapter(name=adapter_key, cfg=cfg)

     def _component_state_dict(self, component_name, component, adapter_name_keys):
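
For illustration only, a standalone sketch of the in_features selection this patch introduces. The helper name build_infused_adapter_cfg, the plain-dict layer_cfg, and the stand-in InfusedAdapterConfig dataclass are hypothetical; the real config class lives in NeMo's adapter modules.

from dataclasses import dataclass


@dataclass
class InfusedAdapterConfig:
    # Hypothetical stand-in for NeMo's InfusedAdapterConfig.
    in_features: int


def build_infused_adapter_cfg(layer_cfg: dict) -> InfusedAdapterConfig:
    """Pick in_features from kv_channels when set, else fall back to hidden_size."""
    tp = layer_cfg['tensor_model_parallel_size']
    if layer_cfg.get('kv_channels', None):
        # Per-partition attention width: kv_channels * num_attention_heads split across TP ranks.
        in_features = layer_cfg['kv_channels'] * layer_cfg['num_attention_heads'] // tp
    else:
        # Without an explicit kv_channels, the attention width equals hidden_size.
        in_features = layer_cfg['hidden_size'] // tp
    return InfusedAdapterConfig(in_features=in_features)


# Example where kv_channels * num_attention_heads differs from hidden_size:
print(build_infused_adapter_cfg(
    {'kv_channels': 128, 'num_attention_heads': 12, 'hidden_size': 768, 'tensor_model_parallel_size': 2}
).in_features)  # 768, whereas hidden_size // tp would give 384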