@@ -171,13 +171,59 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops):
       * config.head_dim
   )
   attention_tflops = attention_flops * config.num_decoder_layers * 3 / 10**12
+  causal_attention_tflops = attention_tflops / 2

   # multiply num_decoder_layers by 2 because we combine [local_attention, global_attention] into one decoder layer
   learnable_weight_tflops = (
       ((total_ffn_flops + qkv_flops + projection_flops) * config.num_decoder_layers * 2 + embedding_flops) * 3 / 10**12
   )

-  return attention_tflops, learnable_weight_tflops
+  return causal_attention_tflops, learnable_weight_tflops
+
+
+def calculate_gemma3_tflops_training_per_device(config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops):
+  """
+  Calculate training TFLOPs for Gemma3, which has an alternating pattern of
+  5 local attention layers and 1 global attention layer.
+  """
+  num_layers = config.num_decoder_layers
+
+  num_global_layers = num_layers // 6
+  num_local_layers = num_layers - num_global_layers
+
+  # FLOPs for a single global attention layer (full attention)
+  # Formula: 4 * batch_size * seq_len^2 * num_heads * head_dim
+  global_attention_flops_per_layer = (
+      4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
+  )
+
+  # FLOPs for a single local attention layer (sliding window)
+  # Formula: 4 * batch_size * seq_len * window_size * num_heads * head_dim
+  local_attention_flops_per_layer = (
+      4
+      * config.per_device_batch_size
+      * config.max_target_length
+      * min(config.sliding_window_size, config.max_target_length)
+      * config.num_query_heads
+      * config.head_dim
+  )
+
+  # Total attention FLOPs = (num_global_layers * FLOPs_per_global) + (num_local_layers * FLOPs_per_local)
+  total_attention_flops = (
+      num_global_layers * global_attention_flops_per_layer
+      + num_local_layers * local_attention_flops_per_layer
+  )
+
+  # Convert to TFLOPs and multiply by 3 for fwd/bwd pass
+  attention_tflops = total_attention_flops * 3 / 10**12
+  causal_attention_tflops = attention_tflops / 2
+
+  # Learnable weights (FFN, QKV, Projections) are present in every layer.
+  learnable_weight_tflops = (
+      ((total_ffn_flops + qkv_flops + projection_flops) * num_layers + embedding_flops) * 3 / 10**12
+  )
+
+  return causal_attention_tflops, learnable_weight_tflops


 def calculate_mla_tflops_per_device(config):
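As a sanity check on the 5:1 local/global split and the two per-layer formulas in the hunk above, here is a minimal standalone sketch. Every number in it is illustrative only (not taken from any released Gemma3 configuration), and SimpleNamespace merely stands in for the real config object.

from types import SimpleNamespace

# Hypothetical config values, chosen only to make the arithmetic easy to follow.
cfg = SimpleNamespace(
    num_decoder_layers=48,
    per_device_batch_size=4,
    max_target_length=8192,
    sliding_window_size=1024,
    num_query_heads=16,
    head_dim=256,
)

num_global_layers = cfg.num_decoder_layers // 6                # 48 // 6 = 8
num_local_layers = cfg.num_decoder_layers - num_global_layers  # 48 - 8  = 40

# Same formulas as in the hunk above.
global_flops = 4 * cfg.per_device_batch_size * cfg.max_target_length**2 * cfg.num_query_heads * cfg.head_dim
local_flops = (
    4
    * cfg.per_device_batch_size
    * cfg.max_target_length
    * min(cfg.sliding_window_size, cfg.max_target_length)
    * cfg.num_query_heads
    * cfg.head_dim
)

# With these numbers a global layer costs seq_len / window = 8192 / 1024 = 8x a local layer.
print(num_global_layers, num_local_layers, global_flops / local_flops)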
@@ -304,6 +350,10 @@ def calculate_tflops_training_per_device(config, log=True):
     attention_tflops, learnable_weight_tflops = calculate_gemma2_tflops_training_per_device(
         config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops
     )
+  elif config.decoder_block == DecoderBlockType.GEMMA3:
+    attention_tflops, learnable_weight_tflops = calculate_gemma3_tflops_training_per_device(
+        config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops
+    )
   elif config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4):
     learnable_weight_tflops = (
         (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
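Both Gemma branches above return the pair (causal_attention_tflops, learnable_weight_tflops). Two shared scale factors are worth spelling out: the multiplier 3 approximates forward plus backward FLOPs (backward is roughly 2x forward), and the division by 2 presumably reflects causal masking computing only about half of the attention score matrix. A tiny arithmetic sketch using the same illustrative numbers as the earlier sketch:

# Illustrative numbers only; they match the hypothetical config in the earlier sketch.
batch, seq, window, heads, head_dim = 4, 8192, 1024, 16, 256
global_flops = 4 * batch * seq**2 * heads * head_dim
local_flops = 4 * batch * seq * window * heads * head_dim

total_attention_flops = 8 * global_flops + 40 * local_flops  # 8 global + 40 local layers
attention_tflops = total_attention_flops * 3 / 10**12        # ~3x forward FLOPs for fwd + bwd
causal_attention_tflops = attention_tflops / 2               # causal mask: ~half the scores
print(f"{causal_attention_tflops:.1f} attention TFLOPs per device per step")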
@@ -1080,7 +1130,7 @@ def get_formatted_sharding_annotations(params, mesh=None):
       spec_parts = []
       for item in p_leaf.sharding.spec:
         # Represent None as "Replicated" to make it explicit.
-        spec_parts.append(str(item) if item is not None else "Relicated")
+        spec_parts.append(str(item) if item is not None else "Replicated")
       sharding_desc = f"PartitionSpec({', '.join(spec_parts)})"
     # Case 2: The parameter is explicitly marked as fully replicated.
     elif hasattr(p_leaf.sharding, "spec") and p_leaf.sharding.spec is None:
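A small standalone illustration of the string this hunk corrects, using only jax.sharding.PartitionSpec; the rest of get_formatted_sharding_annotations is not shown in this diff, so this reproduces just the spec-formatting step.

from jax.sharding import PartitionSpec

# A spec whose second axis is unsharded: None entries render as "Replicated".
spec = PartitionSpec("data", None)
spec_parts = [str(item) if item is not None else "Replicated" for item in spec]
print(f"PartitionSpec({', '.join(spec_parts)})")  # PartitionSpec(data, Replicated)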