From e3077991f4f06cb5a3d9172a65833e96055b2eec Mon Sep 17 00:00:00 2001
From: Adi Renduchintala <108822655+arendu@users.noreply.github.com>
Date: Mon, 3 Oct 2022 19:35:19 -0700
Subject: [PATCH] [bug_fix] kv_channels is used when available (#5066)

* fix bug s.t. kv_channels is used when available

Signed-off-by: arendu

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: arendu

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../language_modeling/megatron_t5_adapter_model.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py
index 349279de6027..a4184a2c3968 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py
@@ -369,9 +369,16 @@ def _add_adapters_to_component(self, component, layer_cfg, adapter_name_keys):
                             in_features=layer_cfg.ffn_hidden_size // layer_cfg.tensor_model_parallel_size
                         )
                     else:
-                        cfg = InfusedAdapterConfig(
-                            in_features=layer_cfg.hidden_size // layer_cfg.tensor_model_parallel_size
-                        )
+                        if layer_cfg.get('kv_channels', None):
+                            cfg = InfusedAdapterConfig(
+                                in_features=layer_cfg.kv_channels
+                                * layer_cfg.num_attention_heads
+                                // layer_cfg.tensor_model_parallel_size
+                            )
+                        else:
+                            cfg = InfusedAdapterConfig(
+                                in_features=layer_cfg.hidden_size // layer_cfg.tensor_model_parallel_size
+                            )
                     module.add_adapter(name=adapter_key, cfg=cfg)
 
     def _component_state_dict(self, component_name, component, adapter_name_keys):
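
Editor's note (not part of the patch): below is a minimal, self-contained sketch of the
dimension logic this change corrects. The plain dict and the numeric values are
hypothetical stand-ins for the OmegaConf layer_cfg used in NeMo; only the arithmetic
mirrors the patch.

    # Per tensor-parallel partition, the attention key/value projection width is
    #   kv_channels * num_attention_heads // tensor_model_parallel_size.
    # This equals hidden_size // tensor_model_parallel_size only when
    # kv_channels == hidden_size // num_attention_heads, which not every
    # model config satisfies -- hence the bug the patch fixes.

    def adapter_in_features(cfg: dict) -> int:
        """Pick in_features for an infused adapter (sketch, not the NeMo API)."""
        if cfg.get('kv_channels'):
            # Post-fix path: honor an explicitly configured kv_channels.
            return cfg['kv_channels'] * cfg['num_attention_heads'] // cfg['tensor_model_parallel_size']
        # Fallback (the old, unconditional behavior): assume the key/value
        # width equals hidden_size.
        return cfg['hidden_size'] // cfg['tensor_model_parallel_size']

    # Hypothetical config where kv_channels * num_attention_heads != hidden_size:
    cfg = {'hidden_size': 2048, 'num_attention_heads': 32, 'kv_channels': 128,
           'tensor_model_parallel_size': 2}
    assert adapter_in_features(cfg) == 2048  # 128 * 32 // 2, matches the K/V width
    assert cfg['hidden_size'] // cfg['tensor_model_parallel_size'] == 1024  # old value, mismatched

With such a config, the pre-fix code would size the adapter from hidden_size and
produce a shape mismatch against the key/value projections, while the patched
branch derives the width from kv_channels when it is set.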