Commit: add missing function

AkshitaB committed Sep 25, 2024
1 parent d2b655a commit 59360be
Showing 1 changed file with 21 additions and 0 deletions.
hf_olmo/configuration_olmo.py (21 additions, 0 deletions)
@@ -36,6 +36,27 @@ def num_hidden_layers(self):
    def hidden_size(self):
        return self.d_model

    @property
    def effective_n_kv_heads(self) -> int:
        if self.n_kv_heads is None:
            if self.multi_query_attention is True:
                return 1
            else:
                return self.n_heads
        else:
            if self.multi_query_attention is None:
                return self.n_kv_heads
            if self.multi_query_attention:
                n_kv_heads_should_be = 1
            else:
                n_kv_heads_should_be = self.n_heads
            if self.n_kv_heads == n_kv_heads_should_be:
                return n_kv_heads_should_be
            else:
                raise OLMoConfigurationError(
                    "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
                )


# Register the config class so that it is available for transformer pipelines, auto-loading etc.
# OLMo is integrated directly in transformers from v4.40.0 onwards, but the version in transformers
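
For reference, a minimal, self-contained sketch of how the new `effective_n_kv_heads` property resolves the key/value head count. `TinyConfig` below is a hypothetical stand-in for the real config class (it is not part of the commit), and `ValueError` replaces `OLMoConfigurationError` so the snippet runs on its own; the attribute names (`n_heads`, `n_kv_heads`, `multi_query_attention`) follow the diff above.

from dataclasses import dataclass
from typing import Optional


@dataclass
class TinyConfig:
    # Stand-in config with the same attribute names as in the diff above.
    n_heads: int = 16
    n_kv_heads: Optional[int] = None
    multi_query_attention: Optional[bool] = None

    @property
    def effective_n_kv_heads(self) -> int:
        # Mirrors the logic added in the commit; ValueError stands in for
        # OLMoConfigurationError.
        if self.n_kv_heads is None:
            return 1 if self.multi_query_attention is True else self.n_heads
        if self.multi_query_attention is None:
            return self.n_kv_heads
        expected = 1 if self.multi_query_attention else self.n_heads
        if self.n_kv_heads != expected:
            raise ValueError(
                "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
            )
        return expected


print(TinyConfig().effective_n_kv_heads)                            # 16 (standard multi-head attention)
print(TinyConfig(multi_query_attention=True).effective_n_kv_heads)  # 1  (multi-query attention)
print(TinyConfig(n_kv_heads=4).effective_n_kv_heads)                # 4  (explicit KV head count)

In short: when `n_kv_heads` is unset, the property falls back to one KV head per attention head, or to a single KV head under multi-query attention; an explicit `n_kv_heads` is only accepted if it agrees with the `multi_query_attention` flag.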
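The trailing context lines of the diff mention registering the config class so it is available to transformer pipelines and auto-loading. As a rough, generic sketch of that registration pattern (assumed, not the file's actual code), using a hypothetical `MyOLMoConfig` with model type "my_olmo" to avoid clashing with the built-in `olmo` integration in transformers >= 4.40:

from transformers import AutoConfig, PretrainedConfig


class MyOLMoConfig(PretrainedConfig):
    # The model_type key is what the auto classes use to look this config up.
    model_type = "my_olmo"  # hypothetical name; the built-in integration uses "olmo"


# Register the custom config under its model_type string.
AutoConfig.register("my_olmo", MyOLMoConfig)

# After registration, the config can be resolved by model_type alone.
cfg = AutoConfig.for_model("my_olmo")
print(type(cfg).__name__)  # MyOLMoConfig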
