diff --git a/lit_gpt/config.py b/lit_gpt/config.py
index 26a875da2f..7778d708fd 100644
--- a/lit_gpt/config.py
+++ b/lit_gpt/config.py
@@ -866,6 +866,7 @@ def norm_class(self) -> Type:
         bias=False,
         _norm_class="RMSNorm",
         _mlp_class="GemmaMLP",
+        gelu_approximate="tanh",
         intermediate_size=16384,
     ),
     # https://huggingface.co/google/gemma-7b/blob/main/config.json
@@ -884,6 +885,7 @@ def norm_class(self) -> Type:
         bias=False,
         _norm_class="RMSNorm",
         _mlp_class="GemmaMLP",
+        gelu_approximate="tanh",
         intermediate_size=24576,
     ),
 ]
diff --git a/lit_gpt/model.py b/lit_gpt/model.py
index ec7103bcd3..29f6f387f4 100644
--- a/lit_gpt/model.py
+++ b/lit_gpt/model.py
@@ -287,6 +287,8 @@ def __init__(self, config: Config) -> None:
         self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
         self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

+        self.config = config
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_fc_1 = self.fc_1(x)
         x_fc_2 = self.fc_2(x)
@@ -298,7 +300,7 @@ class GemmaMLP(LLaMAMLP):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_fc_1 = self.fc_1(x)
         x_fc_2 = self.fc_2(x)
-        x = torch.nn.functional.gelu(x_fc_1) * x_fc_2
+        x = torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2
         return self.proj(x)


diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py
index 7e0fca8c63..ce21f176e4 100644
--- a/tests/test_convert_lit_checkpoint.py
+++ b/tests/test_convert_lit_checkpoint.py
@@ -414,6 +414,7 @@ def test_against_original_gemma(model_name, device, dtype):
         rope_theta=ours_config.rope_base,
         attention_bias=ours_config.bias,
         tie_word_embeddings=True,
+        hidden_act="gelu_pytorch_tanh",
     )

     assert ours_config.intermediate_size == theirs_config.intermediate_size
diff --git a/tests/test_model.py b/tests/test_model.py
index cfb5d8838d..3a7d5bcea3 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -596,6 +596,7 @@ def test_against_original_gemma(model_name, device, dtype):
         rope_theta=ours_config.rope_base,
         attention_bias=ours_config.bias,
         tie_word_embeddings=True,
+        hidden_act="gelu_pytorch_tanh",
     )

     assert ours_config.intermediate_size == theirs_config.intermediate_size
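
For reference, a minimal sketch (not part of the diff; tensor shapes and variable names are illustrative) of what the gelu_approximate="tanh" setting changes: PyTorch's tanh-approximate GELU follows the usual 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) formula, which is what Hugging Face's hidden_act="gelu_pytorch_tanh" selects, whereas the default is the exact erf-based GELU.

    # Sketch only: compares exact GELU, the tanh approximation, and the manual formula.
    import math
    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8)
    exact = F.gelu(x)                            # default: exact erf-based GELU
    tanh_approx = F.gelu(x, approximate="tanh")  # what GemmaMLP now uses via config.gelu_approximate
    manual = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x.pow(3))))
    assert torch.allclose(tanh_approx, manual, atol=1e-6)
    # exact and tanh_approx differ slightly; using the tanh variant keeps the
    # activation consistent with the reference Gemma checkpoints, which the
    # updated tests check by setting hidden_act="gelu_pytorch_tanh".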