diff --git a/hf_olmo/configuration_olmo.py b/hf_olmo/configuration_olmo.py
index 5bef2bd11..39e5851c8 100644
--- a/hf_olmo/configuration_olmo.py
+++ b/hf_olmo/configuration_olmo.py
@@ -36,6 +36,27 @@ def num_hidden_layers(self):
     def hidden_size(self):
         return self.d_model
 
+    @property
+    def effective_n_kv_heads(self) -> int:
+        if self.n_kv_heads is None:
+            if self.multi_query_attention is True:
+                return 1
+            else:
+                return self.n_heads
+        else:
+            if self.multi_query_attention is None:
+                return self.n_kv_heads
+            if self.multi_query_attention:
+                n_kv_heads_should_be = 1
+            else:
+                n_kv_heads_should_be = self.n_heads
+            if self.n_kv_heads == n_kv_heads_should_be:
+                return n_kv_heads_should_be
+            else:
+                raise OLMoConfigurationError(
+                    "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
+                )
+
 
 # Register the config class so that it is available for transformer pipelines, auto-loading etc.
 # OLMo is integrated directly in transformers from v4.40.0 onwards, but the version in transformers
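
For reference, here is a minimal standalone sketch of the resolution rules the new property implements. It uses a hypothetical `Cfg` dataclass rather than the real `OLMoConfig`, so the branches can be exercised without installing hf_olmo; the field names `n_heads`, `n_kv_heads`, and `multi_query_attention` are taken from the patch above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Cfg:
    # Hypothetical stand-in for the relevant OLMoConfig fields.
    n_heads: int
    n_kv_heads: Optional[int] = None
    multi_query_attention: Optional[bool] = None

    @property
    def effective_n_kv_heads(self) -> int:
        # Same resolution logic as the property added in the patch,
        # with the error type swapped for a plain ValueError.
        if self.n_kv_heads is None:
            return 1 if self.multi_query_attention else self.n_heads
        if self.multi_query_attention is None:
            return self.n_kv_heads
        expected = 1 if self.multi_query_attention else self.n_heads
        if self.n_kv_heads == expected:
            return expected
        raise ValueError(
            "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
        )


# Neither option set: plain multi-head attention, one KV head per query head.
assert Cfg(n_heads=16).effective_n_kv_heads == 16
# Multi-query attention collapses to a single shared KV head.
assert Cfg(n_heads=16, multi_query_attention=True).effective_n_kv_heads == 1
# Grouped-query attention: an explicit n_kv_heads wins when multi_query_attention is unset.
assert Cfg(n_heads=16, n_kv_heads=4).effective_n_kv_heads == 4
```

The only conflicting case is setting both `multi_query_attention` and an `n_kv_heads` value that contradicts it, which raises the configuration error shown in the diff.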