@@ -230,15 +230,17 @@ def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
     ) -> None:
-        total_num_attention_heads = self.hf_text_config.num_attention_heads
+        total_num_attention_heads = getattr(self.hf_text_config,
+                                            "num_attention_heads", 0)
         tensor_parallel_size = parallel_config.tensor_parallel_size
         if total_num_attention_heads % tensor_parallel_size != 0:
             raise ValueError(
                 f"Total number of attention heads ({total_num_attention_heads})"
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")
 
-        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
+        total_num_hidden_layers = getattr(self.hf_text_config,
+                                          "num_hidden_layers", 0)
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         if total_num_hidden_layers % pipeline_parallel_size != 0:
             raise ValueError(
@@ -341,8 +343,8 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
 
     def get_num_attention_heads(self,
                                 parallel_config: "ParallelConfig") -> int:
-        return self.hf_text_config.num_attention_heads // \
-            parallel_config.tensor_parallel_size
+        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
+        return num_heads // parallel_config.tensor_parallel_size
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_text_config.num_hidden_layers
@@ -818,7 +820,8 @@ def maybe_create_spec_config(
             speculative_model (Optional[str]): The name of the speculative
                 model, if provided.
             num_speculative_tokens (Optional[int]): The number of speculative
-                tokens, if provided.
+                tokens, if provided. Will default to the number in the draft
+                model config if present, otherwise is required.
             speculative_max_model_len (Optional[int]): The maximum model len of
                 the speculative model. Used when testing the ability to skip
                 speculation for some sequences.
@@ -841,24 +844,18 @@ def maybe_create_spec_config(
             the necessary conditions are met, else None.
         """
 
-        if speculative_model is None and num_speculative_tokens is None:
+        if speculative_model is None:
+            if num_speculative_tokens is not None:
+                raise ValueError("num_speculative_tokens was provided without "
+                                 "speculative_model.")
             return None
 
-        if speculative_model is not None and num_speculative_tokens is None:
-            raise ValueError(
-                "Expected both speculative_model and "
-                "num_speculative_tokens to be provided, but found "
-                f"{speculative_model=} and {num_speculative_tokens=}.")
-
         if (speculative_disable_by_batch_size is not None
                 and speculative_disable_by_batch_size < 2):
             raise ValueError("Expect the batch size threshold of disabling "
                              "speculative decoding is > 1, but got "
                              f"{speculative_disable_by_batch_size=}")
 
-        assert (speculative_model is not None
-                and num_speculative_tokens is not None)
-
         if enable_chunked_prefill:
             raise ValueError(
                 "Speculative decoding and chunked prefill are "
@@ -912,6 +909,27 @@ def maybe_create_spec_config(
             max_logprobs=target_model_config.max_logprobs,
         )
 
+        if (draft_model_config.hf_config.model_type == "mlp_speculator"
+                and target_parallel_config.world_size != 1):
+            # MLPSpeculator TP support will be added very soon
+            raise ValueError(
+                "Speculative decoding with mlp_speculator models does not "
+                "yet support distributed inferencing (TP > 1).")
+
+        n_predict = getattr(draft_model_config.hf_config, "n_predict",
+                            None)
+        if n_predict is not None:
+            if num_speculative_tokens is None:
+                # Default to max value defined in draft model config.
+                num_speculative_tokens = n_predict
+            elif num_speculative_tokens > n_predict:
+                # Verify provided value doesn't exceed the maximum
+                # supported by the draft model.
+                raise ValueError(
+                    "This speculative model supports a maximum of "
+                    f"num_speculative_tokens={n_predict}, but "
+                    f"{num_speculative_tokens=} was provided.")
+
         draft_model_config.max_model_len = (
             SpeculativeConfig._maybe_override_draft_max_model_len(
                 speculative_max_model_len,
@@ -923,6 +941,12 @@ def maybe_create_spec_config(
             SpeculativeConfig.create_draft_parallel_config(
                 target_parallel_config))
 
+        if num_speculative_tokens is None:
+            raise ValueError(
+                "num_speculative_tokens must be provided with "
+                "speculative_model unless the draft model config contains an "
+                "n_predict parameter.")
+
         return SpeculativeConfig(
             draft_model_config,
             draft_parallel_config,
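
Note: as a minimal illustration of the defaulting behavior documented above, the sketch below mirrors the new n_predict logic as a standalone function. The name resolve_num_speculative_tokens and the SimpleNamespace stand-in for the draft model's hf_config are illustrative assumptions, not part of the change itself.

# Illustrative sketch (not part of the diff): how num_speculative_tokens
# falls back to the draft model's n_predict, under the assumptions above.
from types import SimpleNamespace
from typing import Optional


def resolve_num_speculative_tokens(
        draft_hf_config: SimpleNamespace,
        num_speculative_tokens: Optional[int]) -> int:
    # Draft models such as MLPSpeculator advertise the maximum number of
    # tokens they can propose per step via an `n_predict` attribute.
    n_predict = getattr(draft_hf_config, "n_predict", None)
    if n_predict is not None:
        if num_speculative_tokens is None:
            # Default to the maximum defined in the draft model config.
            num_speculative_tokens = n_predict
        elif num_speculative_tokens > n_predict:
            # Reject requests that exceed what the draft model supports.
            raise ValueError(
                "This speculative model supports a maximum of "
                f"num_speculative_tokens={n_predict}, but "
                f"{num_speculative_tokens=} was provided.")
    if num_speculative_tokens is None:
        # No explicit value and no n_predict to fall back on.
        raise ValueError(
            "num_speculative_tokens must be provided unless the draft "
            "model config contains an n_predict parameter.")
    return num_speculative_tokens


# Usage: an MLPSpeculator-style config with n_predict=3 supplies the default.
assert resolve_num_speculative_tokens(SimpleNamespace(n_predict=3), None) == 3
assert resolve_num_speculative_tokens(SimpleNamespace(n_predict=3), 2) == 2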