diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index d4a3fae73824..6d0b3285162c 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -873,11 +873,6 @@ def create_engine_config(self, ) -> EngineConfig:
                 raise ValueError(
                     f"Invalid module {m} in collect_detailed_traces. "
                     f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
-            if (m == "model"
-                    or m == "all") and self.pipeline_parallel_size > 1:
-                raise ValueError(
-                    "Collection of detailed traces for the 'model' module is "
-                    "not yet supported with pipeline parallelism.")
         observability_config = ObservabilityConfig(
             otlp_traces_endpoint=self.otlp_traces_endpoint,
             collect_model_forward_time="model" in detailed_trace_modules
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index cfbbb6698cd8..a541831ab460 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1551,6 +1551,21 @@ def execute_model(
 
         # Compute the logits in the last pipeline stage.
         if not get_pp_group().is_last_rank:
+            if (self.is_driver_worker
+                    and hidden_or_intermediate_states is not None
+                    and isinstance(hidden_or_intermediate_states,
+                                   IntermediateTensors)
+                    and self.observability_config is not None
+                    and self.observability_config.collect_model_forward_time):
+                model_forward_end.synchronize()
+                model_forward_time = model_forward_start.elapsed_time(
+                    model_forward_end)
+                orig_model_forward_time = 0.0
+                if intermediate_tensors is not None:
+                    orig_model_forward_time = intermediate_tensors.tensors.get(
+                        "model_forward_time", torch.tensor(0.0)).item()
+                hidden_or_intermediate_states.tensors["model_forward_time"] = (
+                    torch.tensor(model_forward_time + orig_model_forward_time))
             return hidden_or_intermediate_states
 
         logits = self.model.compute_logits(hidden_or_intermediate_states,
@@ -1570,11 +1585,16 @@ def execute_model(
             model_forward_end.synchronize()
             model_forward_time = model_forward_start.elapsed_time(
                 model_forward_end)
+            orig_model_forward_time = 0.0
+            if intermediate_tensors is not None:
+                orig_model_forward_time = intermediate_tensors.tensors.get(
+                    "model_forward_time", torch.tensor(0.0)).item()
             # If there are multiple workers, we are still tracking the latency
             # from the start time of the driver worker to the end time of the
             # driver worker. The model forward time will then end up covering
             # the communication time as well.
-            output.model_forward_time = model_forward_time
+            output.model_forward_time = (orig_model_forward_time +
+                                         model_forward_time)
 
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
@@ -1732,7 +1752,7 @@ def forward(
                                                       **kwargs)
         if intermediate_tensors is not None:
             for key in intermediate_tensors.tensors:
-                if key != "model_execute_time":
+                if key != "model_execute_time" and key != "model_forward_time":
                     self.input_buffers[key].copy_(intermediate_tensors[key],
                                                   non_blocking=True)
         # Run the graph.
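
Note on the mechanism above: each non-last pipeline rank adds its measured forward
time to a "model_forward_time" entry carried in the IntermediateTensors it sends
downstream, the last rank adds its own time before reporting
output.model_forward_time, and the CUDA graph runner skips copying that entry into
its input buffers, as it already does for "model_execute_time". Below is a minimal
standalone sketch of the accumulation pattern; the helper name
add_stage_forward_time and the per-stage timings are made up for illustration and
are not vLLM's actual classes or APIs.

import torch


def add_stage_forward_time(tensors: dict, stage_ms: float) -> None:
    # Add this stage's measured forward time (ms) to the running total
    # carried over from earlier pipeline stages.
    prev_ms = tensors.get("model_forward_time", torch.tensor(0.0)).item()
    tensors["model_forward_time"] = torch.tensor(prev_ms + stage_ms)


# Example: three pipeline stages each contribute their measured time;
# the last rank reads the accumulated value and reports it on the output.
carried: dict = {}
for stage_ms in (12.5, 11.8, 13.1):
    add_stage_forward_time(carried, stage_ms)
print(f"model_forward_time: {carried['model_forward_time'].item():.1f} ms")

Carrying the running total inside the intermediate tensors avoids any extra
communication just for the metric; as the existing comment in execute_model notes,
the measured time still ends up covering inter-worker communication as well.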