Commit a72f72d

Fix Gemma3 and Gemma2 flops computation.
1 parent 2adc3ba commit a72f72d

File tree: 1 file changed, +52 -2 lines


MaxText/maxtext_utils.py

Lines changed: 52 additions & 2 deletions
@@ -171,13 +171,59 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo
       * config.head_dim
   )
   attention_tflops = attention_flops * config.num_decoder_layers * 3 / 10**12
+  causal_attention_tflops = attention_tflops / 2
 
   # multiply num_decoder_layers by 2 because we combine [local_attention, global_attention] into one decoder layer
   learnable_weight_tflops = (
       ((total_ffn_flops + qkv_flops + projection_flops) * config.num_decoder_layers * 2 + embedding_flops) * 3 / 10**12
   )
 
-  return attention_tflops, learnable_weight_tflops
+  return causal_attention_tflops, learnable_weight_tflops
+
+
+def calculate_gemma3_tflops_training_per_device(config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops):
+  """
+  Calculate training TFLOPs for Gemma3, which has an alternating pattern of
+  5 local attention layers and 1 global attention layer.
+  """
+  num_layers = config.num_decoder_layers
+
+  num_global_layers = num_layers // 6
+  num_local_layers = num_layers - num_global_layers
+
+  # FLOPs for a single global attention layer (full attention)
+  # Formula: 4 * batch_size * seq_len^2 * num_heads * head_dim
+  global_attention_flops_per_layer = (
+      4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
+  )
+
+  # FLOPs for a single local attention layer (sliding window)
+  # Formula: 4 * batch_size * seq_len * window_size * num_heads * head_dim
+  local_attention_flops_per_layer = (
+      4
+      * config.per_device_batch_size
+      * config.max_target_length
+      * min(config.sliding_window_size, config.max_target_length)
+      * config.num_query_heads
+      * config.head_dim
+  )
+
+  # Total attention FLOPs = (num_global_layers * FLOPs_per_global) + (num_local_layers * FLOPs_per_local)
+  total_attention_flops = (
+      num_global_layers * global_attention_flops_per_layer +
+      num_local_layers * local_attention_flops_per_layer
+  )
+
+  # Convert to TFLOPs and multiply by 3 for fwd/bwd pass
+  attention_tflops = total_attention_flops * 3 / 10**12
+  causal_attention_tflops = attention_tflops / 2
+
+  # Learnable weights (FFN, QKV, Projections) are present in every layer.
+  learnable_weight_tflops = (
+      ((total_ffn_flops + qkv_flops + projection_flops) * num_layers + embedding_flops) * 3 / 10**12
+  )
+
+  return causal_attention_tflops, learnable_weight_tflops
 
 
 def calculate_mla_tflops_per_device(config):
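
A note on the new causal_attention_tflops = attention_tflops / 2 lines above: with a causal mask, each query position attends only to itself and earlier positions, so roughly half of the seq_len x seq_len score matrix is actually computed. A minimal standalone sketch of that accounting (the function names and example numbers are illustrative, not taken from MaxText):

# Sketch only: why causal attention is counted as half of full attention.
# All numbers below are hypothetical.
def full_attention_flops(batch, seq_len, num_heads, head_dim):
  # Two matmuls per head (Q @ K^T and scores @ V), each 2 * seq_len^2 * head_dim FLOPs.
  return 4 * batch * seq_len**2 * num_heads * head_dim

def causal_attention_flops(batch, seq_len, num_heads, head_dim):
  # A causal mask keeps only seq_len * (seq_len + 1) / 2 of the seq_len^2 score
  # entries, which approaches half of the full count for long sequences.
  visible = seq_len * (seq_len + 1) // 2
  return 4 * batch * visible * num_heads * head_dim

if __name__ == "__main__":
  full = full_attention_flops(1, 8192, 16, 256)
  causal = causal_attention_flops(1, 8192, 16, 256)
  print(causal / full)  # ~0.5, matching the division by 2 in this commit
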
@@ -304,6 +350,10 @@ def calculate_tflops_training_per_device(config, log=True):
     attention_tflops, learnable_weight_tflops = calculate_gemma2_tflops_training_per_device(
         config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops
     )
+  elif config.decoder_block == DecoderBlockType.GEMMA3:
+    attention_tflops, learnable_weight_tflops = calculate_gemma3_tflops_training_per_device(
+        config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops
+    )
   elif config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4):
     learnable_weight_tflops = (
         (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
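
To sanity-check the new GEMMA3 branch, here is a standalone restatement of the attention formula added above, applied to a hypothetical config (the attribute names mirror the MaxText config fields used in the diff; the concrete values are made up):

# Standalone sketch of the Gemma3 attention-TFLOPs formula from this commit,
# applied to a hypothetical config. The values are illustrative only.
from types import SimpleNamespace

config = SimpleNamespace(
    num_decoder_layers=48,       # hypothetical; 48 // 6 = 8 global layers
    per_device_batch_size=1,
    max_target_length=8192,
    sliding_window_size=1024,
    num_query_heads=16,
    head_dim=256,
)

num_global_layers = config.num_decoder_layers // 6   # one global layer per six
num_local_layers = config.num_decoder_layers - num_global_layers

global_flops_per_layer = (
    4 * config.per_device_batch_size * config.max_target_length**2
    * config.num_query_heads * config.head_dim
)
local_flops_per_layer = (
    4 * config.per_device_batch_size * config.max_target_length
    * min(config.sliding_window_size, config.max_target_length)
    * config.num_query_heads * config.head_dim
)

total_attention_flops = (
    num_global_layers * global_flops_per_layer
    + num_local_layers * local_flops_per_layer
)
# x3 for forward + backward, / 10**12 for TFLOPs, / 2 for the causal mask.
causal_attention_tflops = total_attention_flops * 3 / 10**12 / 2

print(num_global_layers, num_local_layers)            # 8 40
print(f"{causal_attention_tflops:.2f} attention TFLOPs per device per step")

With these made-up numbers, the 40 sliding-window layers contribute less attention compute than the 8 global layers, which is exactly why the formula accounts for the two layer types separately.
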
@@ -1080,7 +1130,7 @@ def get_formatted_sharding_annotations(params, mesh=None):
         spec_parts = []
         for item in p_leaf.sharding.spec:
           # Represent None as "Replicated" to make it explicit.
-          spec_parts.append(str(item) if item is not None else "Relicated")
+          spec_parts.append(str(item) if item is not None else "Replicated")
         sharding_desc = f"PartitionSpec({', '.join(spec_parts)})"
       # Case 2: The parameter is explicitly marked as fully replicated.
       elif hasattr(p_leaf.sharding, "spec") and p_leaf.sharding.spec is None:
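
The last hunk only fixes a typo in the sharding annotation helper ("Relicated" -> "Replicated"). For context, a minimal sketch of the formatting pattern that line implements, built directly on jax.sharding.PartitionSpec (the spec below is a made-up example, not one pulled from MaxText):

# Sketch of the "None -> Replicated" formatting fixed in the last hunk.
# The PartitionSpec is a made-up example.
from jax.sharding import PartitionSpec

spec = PartitionSpec("data", None, "model")  # hypothetical sharding spec

spec_parts = [str(item) if item is not None else "Replicated" for item in spec]
print(f"PartitionSpec({', '.join(spec_parts)})")
# -> PartitionSpec(data, Replicated, model)
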
