From 8d8d4e27cb918b47e83343547190dccd20dcfde4 Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Mon, 6 May 2024 18:05:17 +0300 Subject: [PATCH 01/23] Add OpenTelemetry support --- examples/production_monitoring/Otel.md | 69 +++++++++++++++++++ .../production_monitoring/dummy_client.py | 35 ++++++++++ requirements-common.txt | 3 + vllm/engine/async_llm_engine.py | 13 ++++ vllm/engine/llm_engine.py | 66 ++++++++++++++++-- vllm/entrypoints/openai/serving_chat.py | 5 ++ vllm/entrypoints/openai/serving_completion.py | 5 ++ vllm/sequence.py | 4 ++ vllm/tracing.py | 18 +++++ 9 files changed, 213 insertions(+), 5 deletions(-) create mode 100644 examples/production_monitoring/Otel.md create mode 100644 examples/production_monitoring/dummy_client.py create mode 100644 vllm/tracing.py diff --git a/examples/production_monitoring/Otel.md b/examples/production_monitoring/Otel.md new file mode 100644 index 000000000000..9838d1fc5444 --- /dev/null +++ b/examples/production_monitoring/Otel.md @@ -0,0 +1,69 @@ +# Setup OpenTelemetry POC + +1. Start Jaeger in a docker container: + ``` + # From: https://www.jaegertracing.io/docs/1.57/getting-started/ + docker run --rm --name jaeger \ + -e COLLECTOR_ZIPKIN_HOST_PORT=:9411 \ + -p 6831:6831/udp \ + -p 6832:6832/udp \ + -p 5778:5778 \ + -p 16686:16686 \ + -p 4317:4317 \ + -p 4318:4318 \ + -p 14250:14250 \ + -p 14268:14268 \ + -p 14269:14269 \ + -p 9411:9411 \ + jaegertracing/all-in-one:1.57 + ``` + +1. In a new shell, export Jaeger IP: + ``` + export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger) + export OTEL_EXPORTER_OTLP_ENDPOINT=http://$JAEGER_IP:4318 + ``` + Then set vLLM's service name for OpenTelemetry and run vLLM: + ``` + export OTEL_SERVICE_NAME="vllm-server" + python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" + ``` + +1. In a new shell, send requests with trace context from a dummy client + ``` + export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger) + export OTEL_EXPORTER_OTLP_ENDPOINT=http://$JAEGER_IP:4318 + export OTEL_SERVICE_NAME="client-service" + python dummy_client.py + ``` + +1. Open Jaeger webui: http://localhost:16686/ + + In the search pane, select `vllm-server` service and hit `Find Traces`. You should get a list of traces, one for each request. + ![Traces](https://i.imgur.com/GYHhFjo.png) + +1. Clicking on a trace will show its spans and their tags. In this demo, each trace has 2 spans. One from the dummy client containing the prompt text and one from vLLM containing metadata about the request. +![Spans details](https://i.imgur.com/OPf6CBL.png) + + +## Disabling tracing +OpenTelemetry tracing can be disabled by setting the environment variable: +``` +export OTEL_SDK_DISABLED=true +``` + +## Instrumentation of FastAPI +OpenTelemetry allows automatic instrumentation of FastAPI. +1. Install the instrumentation library + ``` + pip install opentelemetry-instrumentation-fastapi + ``` + +1. Run vLLM with `opentelemetry-instrument` + ``` + opentelemetry-instrument python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" + ``` + +1. Send a request to vLLM and find its trace in Jaeger. It should contain spans from FastAPI. 
+ +![FastAPI Spans](https://i.imgur.com/hywvoOJ.png) \ No newline at end of file diff --git a/examples/production_monitoring/dummy_client.py b/examples/production_monitoring/dummy_client.py new file mode 100644 index 000000000000..2c1c703d1504 --- /dev/null +++ b/examples/production_monitoring/dummy_client.py @@ -0,0 +1,35 @@ +import requests +from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( + OTLPSpanExporter) +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import (BatchSpanProcessor, + ConsoleSpanExporter) +from opentelemetry.trace import SpanKind, set_tracer_provider +from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator) + +trace_provider = TracerProvider() +set_tracer_provider(trace_provider) + +trace_provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter())) +trace_provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter())) + +tracer = trace_provider.get_tracer("dummy-client") + +url = "http://localhost:8000/v1/completions" +with tracer.start_as_current_span("client-span", kind=SpanKind.CLIENT) as span: + prompt = "San Francisco is a" + span.set_attribute("prompt", prompt) + headers = {} + TraceContextTextMapPropagator().inject(headers) + payload = { + "model": "facebook/opt-125m", + "prompt": prompt, + "max_tokens": 10, + "best_of": 20, + "n": 3, + "use_beam_search": "true", + "temperature": 0.0, + # "stream": True, + } + response = requests.post(url, headers=headers, json=payload) diff --git a/requirements-common.txt b/requirements-common.txt index 32e2ebe8c615..8612dffdd3e5 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -20,3 +20,6 @@ lm-format-enforcer == 0.10.1 outlines >= 0.0.43 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 +opentelemetry-sdk +opentelemetry-api +opentelemetry-exporter-otlp diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 03b6d03a9fde..61e7db6ecfc5 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -4,6 +4,7 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union) +from opentelemetry.context.context import Context from transformers import PreTrainedTokenizer import vllm.envs as envs @@ -285,6 +286,7 @@ async def add_request_async( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + trace_context: Optional[Context] = None, ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " @@ -301,6 +303,7 @@ async def add_request_async( params=params, arrival_time=arrival_time, lora_request=lora_request, + trace_context=trace_context, ) async def check_health_async(self) -> None: @@ -545,6 +548,7 @@ async def add_request( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + trace_context: Optional[Context] = None, ) -> AsyncStream: if self.log_requests: if isinstance(inputs, str): @@ -586,6 +590,7 @@ async def add_request( params=params, arrival_time=arrival_time, lora_request=lora_request, + trace_context=trace_context, ) return stream @@ -596,6 +601,7 @@ async def generate( sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, + trace_context: Optional[Context] = None, ) -> 
AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -610,6 +616,7 @@ async def generate( sampling_params: The sampling parameters of the request. request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. + trace_context: OpenTelemetry trace context. Yields: The output `RequestOutput` objects from the LLMEngine @@ -663,6 +670,7 @@ async def generate( inputs, sampling_params, lora_request=lora_request, + trace_context=trace_context, ): yield LLMEngine.validate_output(output, RequestOutput) @@ -672,6 +680,7 @@ async def encode( pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, + trace_context: Optional[Context] = None, ) -> AsyncIterator[EmbeddingRequestOutput]: """Generate outputs for a request from an embedding model. @@ -686,6 +695,7 @@ async def encode( pooling_params: The pooling parameters of the request. request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. + trace_context: OpenTelemetry trace context. Yields: The output `EmbeddingRequestOutput` objects from the LLMEngine @@ -737,6 +747,7 @@ async def encode( inputs, pooling_params, lora_request=lora_request, + trace_context=trace_context, ): yield LLMEngine.validate_output(output, EmbeddingRequestOutput) @@ -747,6 +758,7 @@ async def _process_request( params: Union[SamplingParams, PoolingParams], *, lora_request: Optional[LoRARequest] = None, + trace_context: Optional[Context] = None, ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: """Common logic to process requests with SamplingParams or PoolingParams.""" @@ -758,6 +770,7 @@ async def _process_request( params, arrival_time=arrival_time, lora_request=lora_request, + trace_context=trace_context, ) try: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fd64337d4384..bf8f7cb544f8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -4,6 +4,8 @@ from typing import Sequence as GenericSequence from typing import Set, Type, TypeVar, Union +from opentelemetry.context.context import Context +from opentelemetry.trace import SpanKind from transformers import GenerationConfig, PreTrainedTokenizer from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, @@ -31,6 +33,7 @@ PoolerOutput, SamplerOutput, Sequence, SequenceGroup, SequenceGroupMetadata, SequenceStatus) +from vllm.tracing import init_tracer from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -42,6 +45,8 @@ logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 +tracer = init_tracer("vllm.llm_engine") + def _load_generation_config_dict(model_config: ModelConfig): try: @@ -436,6 +441,7 @@ def _add_processed_request( params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], + trace_context: Optional[Context] = None, ) -> None: # Create the sequences. 
block_size = self.cache_config.block_size @@ -453,6 +459,7 @@ def _add_processed_request( params, arrival_time=arrival_time, lora_request=lora_request, + trace_context=trace_context, ) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( @@ -499,6 +506,7 @@ def add_request( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + trace_context: Optional[Context] = None, ) -> None: """Add a request to the engine's request pool. @@ -516,6 +524,7 @@ def add_request( :class:`~vllm.PoolingParams` for pooling. arrival_time: The arrival time of the request. If None, we use the current monotonic time. + trace_context: OpenTelemetry trace context. Details: - Set arrival_time to the current time if it is None. @@ -557,6 +566,7 @@ def add_request( params=params, arrival_time=arrival_time, lora_request=lora_request, + trace_context=trace_context, ) def _create_sequence_group_with_sampling( @@ -566,6 +576,7 @@ def _create_sequence_group_with_sampling( sampling_params: SamplingParams, arrival_time: float, lora_request: Optional[LoRARequest], + trace_context: Optional[Context] = None, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -587,11 +598,14 @@ def _create_sequence_group_with_sampling( self.generation_config_fields) # Create the sequence group. - seq_group = SequenceGroup(request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - sampling_params=sampling_params, - lora_request=lora_request) + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + sampling_params=sampling_params, + lora_request=lora_request, + trace_context=trace_context, + ) return seq_group @@ -919,6 +933,8 @@ def _get_stats( for seq in seq_group.get_finished_seqs() ]) + self.create_trace_span(seq_group, now) + # Number of generation tokens. # num_batched_tokens equals the number of prompt_tokens plus the # number of decode_tokens in a single iteration. 
So, @@ -978,3 +994,43 @@ def list_loras(self) -> Set[int]: def check_health(self) -> None: self.model_executor.check_health() + + def create_trace_span(self, seq_group: SequenceGroup, now: float) -> None: + arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) + with tracer.start_as_current_span( + "llm_request", + kind=SpanKind.SERVER, + context=seq_group.trace_context, + start_time=arrival_time_nano_seconds) as seq_span: + metrics = seq_group.metrics + ttft = metrics.first_token_time - metrics.arrival_time + e2e_time = now - seq_group.metrics.arrival_time + # attribute names are based on + # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md + seq_span.set_attribute("gen_ai.response.model", + self.model_config.model) + seq_span.set_attribute("gen_ai.request.id", seq_group.request_id) + seq_span.set_attribute("gen_ai.request.temperature", + seq_group.sampling_params.temperature) + seq_span.set_attribute("gen_ai.request.top_p", + seq_group.sampling_params.top_p) + seq_span.set_attribute("gen_ai.request.max_tokens", + seq_group.sampling_params.max_tokens) + seq_span.set_attribute("gen_ai.request.best_of", + seq_group.sampling_params.best_of) + seq_span.set_attribute("gen_ai.request.n", + seq_group.sampling_params.n) + seq_span.set_attribute("gen_ai.usage.num_sequences", + seq_group.num_seqs()) + seq_span.set_attribute("gen_ai.usage.prompt_tokens", + len(seq_group.prompt_token_ids)) + seq_span.set_attribute( + "gen_ai.usage.completion_tokens", + sum([ + seq.get_output_len() + for seq in seq_group.get_finished_seqs() + ])) + seq_span.set_attribute("gen_ai.latency.time_in_queue", + seq_group.metrics.time_in_queue) + seq_span.set_attribute("gen_ai.latency.time_to_first_token", ttft) + seq_span.set_attribute("gen_ai.latency.e2e", e2e_time) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 76940612496a..eaaafdf163d0 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -9,6 +9,8 @@ from fastapi import Request from openai.types.chat import (ChatCompletionContentPartImageParam, ChatCompletionContentPartTextParam) +from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator) from vllm.config import ModelConfig, VisionLanguageConfig from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -267,11 +269,14 @@ async def create_chat_completion( if image_data is not None: inputs["multi_modal_data"] = image_data + trace_context = TraceContextTextMapPropagator().extract( + raw_request.headers if raw_request else {}) result_generator = self.engine.generate( inputs, sampling_params, request_id, lora_request, + trace_context=trace_context, ) # Streaming response if request.stream: diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 64671e21a724..46d03e4dd722 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -5,6 +5,8 @@ from typing import Tuple from fastapi import Request +from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator) from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -125,6 +127,8 @@ async def create_completion(self, request: CompletionRequest, truncate_prompt_tokens) prompt_ids, prompt_text = prompt_formats + trace_context = TraceContextTextMapPropagator().extract( + raw_request.headers) generator = 
self.engine.generate( { "prompt": prompt_text, @@ -133,6 +137,7 @@ async def create_completion(self, request: CompletionRequest, sampling_params, f"{request_id}-{i}", lora_request=lora_request, + trace_context=trace_context, ) generators.append(generator) diff --git a/vllm/sequence.py b/vllm/sequence.py index 54243bfb1e91..07862c873d8c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch +from opentelemetry.context.context import Context from vllm.block import LogicalTokenBlock from vllm.inputs import LLMInputs @@ -414,6 +415,7 @@ class SequenceGroup: for an embedding model. encoder_seq: Optional, the single encoder sequence. Should be None unless you are working with an encoder/decoder model. + trace_context: OpenTelemetry trace context. """ def __init__( @@ -426,6 +428,7 @@ def __init__( embeddings: Optional[List[float]] = None, pooling_params: Optional[PoolingParams] = None, encoder_seq: Optional[Sequence] = None, + trace_context: Optional[Context] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -441,6 +444,7 @@ def __init__( self.embeddings = embeddings self.pooling_params = pooling_params self.encoder_seq = encoder_seq + self.trace_context = trace_context @property def prompt(self) -> Optional[str]: diff --git a/vllm/tracing.py b/vllm/tracing.py new file mode 100644 index 000000000000..13b3b69e8d3e --- /dev/null +++ b/vllm/tracing.py @@ -0,0 +1,18 @@ +from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( + OTLPSpanExporter) +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.trace import set_tracer_provider + + +def init_tracer(instrumenting_module_name): + trace_provider = TracerProvider() + + # The endpoint of OTLPSpanExporter is set from envvars: + # OTEL_EXPORTER_OTLP_ENDPOINT + # OTEL_EXPORTER_OTLP_TRACES_ENDPOINT + trace_provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter())) + set_tracer_provider(trace_provider) + + tracer = trace_provider.get_tracer(instrumenting_module_name) + return tracer From 2aa077def8684d7a62a7da2c8d8e7a364ab98b11 Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Tue, 28 May 2024 14:43:51 +0300 Subject: [PATCH 02/23] Create a trace only when trace_context is not None --- vllm/engine/llm_engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bf8f7cb544f8..511cc544f86f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -996,6 +996,8 @@ def check_health(self) -> None: self.model_executor.check_health() def create_trace_span(self, seq_group: SequenceGroup, now: float) -> None: + if seq_group.trace_context is None: + return arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) with tracer.start_as_current_span( "llm_request", From dd9ffcc9717e59de74190214fba0bf0b6162dd5e Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Wed, 29 May 2024 13:06:31 +0300 Subject: [PATCH 03/23] Use constants for span attribute names --- requirements-common.txt | 1 + vllm/engine/llm_engine.py | 31 +++++++++++++++++-------------- vllm/tracing.py | 13 +++++++++++++ 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index 8612dffdd3e5..c1691d8a869c 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -23,3 +23,4 @@ filelock >= 3.10.4 # 
filelock starts to support `mode` argument from 3.10.4 opentelemetry-sdk opentelemetry-api opentelemetry-exporter-otlp +opentelemetry-semantic-conventions-ai diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 511cc544f86f..63edfe55d58c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -33,7 +33,7 @@ PoolerOutput, SamplerOutput, Sequence, SequenceGroup, SequenceGroupMetadata, SequenceStatus) -from vllm.tracing import init_tracer +from vllm.tracing import SpanAttributes, init_tracer from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -999,6 +999,7 @@ def create_trace_span(self, seq_group: SequenceGroup, now: float) -> None: if seq_group.trace_context is None: return arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) + with tracer.start_as_current_span( "llm_request", kind=SpanKind.SERVER, @@ -1009,30 +1010,32 @@ def create_trace_span(self, seq_group: SequenceGroup, now: float) -> None: e2e_time = now - seq_group.metrics.arrival_time # attribute names are based on # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md - seq_span.set_attribute("gen_ai.response.model", + seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL, self.model_config.model) - seq_span.set_attribute("gen_ai.request.id", seq_group.request_id) - seq_span.set_attribute("gen_ai.request.temperature", + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_ID, + seq_group.request_id) + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TEMPERATURE, seq_group.sampling_params.temperature) - seq_span.set_attribute("gen_ai.request.top_p", + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TOP_P, seq_group.sampling_params.top_p) - seq_span.set_attribute("gen_ai.request.max_tokens", + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS, seq_group.sampling_params.max_tokens) - seq_span.set_attribute("gen_ai.request.best_of", + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF, seq_group.sampling_params.best_of) - seq_span.set_attribute("gen_ai.request.n", + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N, seq_group.sampling_params.n) - seq_span.set_attribute("gen_ai.usage.num_sequences", + seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES, seq_group.num_seqs()) - seq_span.set_attribute("gen_ai.usage.prompt_tokens", + seq_span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, len(seq_group.prompt_token_ids)) seq_span.set_attribute( - "gen_ai.usage.completion_tokens", + SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, sum([ seq.get_output_len() for seq in seq_group.get_finished_seqs() ])) - seq_span.set_attribute("gen_ai.latency.time_in_queue", + seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE, seq_group.metrics.time_in_queue) - seq_span.set_attribute("gen_ai.latency.time_to_first_token", ttft) - seq_span.set_attribute("gen_ai.latency.e2e", e2e_time) + seq_span.set_attribute( + SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft) + seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time) diff --git a/vllm/tracing.py b/vllm/tracing.py index 13b3b69e8d3e..68fed611a79a 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -1,3 +1,4 @@ +import opentelemetry.semconv.ai from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( OTLPSpanExporter) from opentelemetry.sdk.trace import TracerProvider @@ -16,3 +17,15 @@ def init_tracer(instrumenting_module_name): tracer = 
trace_provider.get_tracer(instrumenting_module_name) return tracer + + +class SpanAttributes(opentelemetry.semconv.ai.SpanAttributes): + # The following span attribute names are added here because they are missing + # from the Semantic Conventions for LLM. + LLM_REQUEST_ID = "gen_ai.request.id" + LLM_REQUEST_BEST_OF = "gen_ai.request.best_of" + LLM_REQUEST_N = "gen_ai.request.n" + LLM_USAGE_NUM_SEQUENCES = "gen_ai.usage.num_sequences" + LLM_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue" + LLM_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token" + LLM_LATENCY_E2E = "gen_ai.latency.e2e" From 2ce871d44d74a59b45b449822ebfe2940536bf17 Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Mon, 3 Jun 2024 16:54:31 +0300 Subject: [PATCH 04/23] Remove required dependency on opentelemetry --- requirements-common.txt | 4 -- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 6 +-- vllm/entrypoints/openai/serving_chat.py | 5 +- vllm/entrypoints/openai/serving_completion.py | 6 +-- vllm/sequence.py | 2 +- vllm/tracing.py | 47 +++++++++++++++---- 7 files changed, 48 insertions(+), 24 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index c1691d8a869c..32e2ebe8c615 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -20,7 +20,3 @@ lm-format-enforcer == 0.10.1 outlines >= 0.0.43 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 -opentelemetry-sdk -opentelemetry-api -opentelemetry-exporter-otlp -opentelemetry-semantic-conventions-ai diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 61e7db6ecfc5..034403d1506c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -4,7 +4,6 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union) -from opentelemetry.context.context import Context from transformers import PreTrainedTokenizer import vllm.envs as envs @@ -20,6 +19,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.tracing import Context from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 63edfe55d58c..7b75f5a03fd1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -4,8 +4,6 @@ from typing import Sequence as GenericSequence from typing import Set, Type, TypeVar, Union -from opentelemetry.context.context import Context -from opentelemetry.trace import SpanKind from transformers import GenerationConfig, PreTrainedTokenizer from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, @@ -33,7 +31,7 @@ PoolerOutput, SamplerOutput, Sequence, SequenceGroup, SequenceGroupMetadata, SequenceStatus) -from vllm.tracing import SpanAttributes, init_tracer +from vllm.tracing import Context, SpanAttributes, SpanKind, init_tracer from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -996,6 +994,8 @@ def check_health(self) -> None: self.model_executor.check_health() def create_trace_span(self, seq_group: SequenceGroup, now: float) -> None: + if tracer is None: + return if seq_group.trace_context is None: return arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) diff --git 
a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index eaaafdf163d0..dd274ac9721b 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -9,8 +9,6 @@ from fastapi import Request from openai.types.chat import (ChatCompletionContentPartImageParam, ChatCompletionContentPartTextParam) -from opentelemetry.trace.propagation.tracecontext import ( - TraceContextTextMapPropagator) from vllm.config import ModelConfig, VisionLanguageConfig from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -33,6 +31,7 @@ get_full_image_text_prompt) from vllm.outputs import RequestOutput from vllm.sequence import Logprob +from vllm.tracing import extract_trace_context from vllm.utils import random_uuid logger = init_logger(__name__) @@ -269,7 +268,7 @@ async def create_chat_completion( if image_data is not None: inputs["multi_modal_data"] = image_data - trace_context = TraceContextTextMapPropagator().extract( + trace_context = extract_trace_context( raw_request.headers if raw_request else {}) result_generator = self.engine.generate( inputs, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 46d03e4dd722..c10377091658 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -5,8 +5,6 @@ from typing import Tuple from fastapi import Request -from opentelemetry.trace.propagation.tracecontext import ( - TraceContextTextMapPropagator) from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -26,6 +24,7 @@ get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput from vllm.sequence import Logprob +from vllm.tracing import extract_trace_context from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) @@ -127,8 +126,7 @@ async def create_completion(self, request: CompletionRequest, truncate_prompt_tokens) prompt_ids, prompt_text = prompt_formats - trace_context = TraceContextTextMapPropagator().extract( - raw_request.headers) + trace_context = extract_trace_context(raw_request.headers) generator = self.engine.generate( { "prompt": prompt_text, diff --git a/vllm/sequence.py b/vllm/sequence.py index 07862c873d8c..a781c16a08c6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -6,13 +6,13 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch -from opentelemetry.context.context import Context from vllm.block import LogicalTokenBlock from vllm.inputs import LLMInputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams +from vllm.tracing import Context if TYPE_CHECKING: from vllm.multimodal import MultiModalData diff --git a/vllm/tracing.py b/vllm/tracing.py index 68fed611a79a..47b634d91511 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -1,12 +1,36 @@ -import opentelemetry.semconv.ai -from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( - OTLPSpanExporter) -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.trace import set_tracer_provider +from typing import Mapping, Optional +otel_installed = False +try: + from opentelemetry.context.context import Context + from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( + OTLPSpanExporter) + from opentelemetry.sdk.trace import TracerProvider + from 
opentelemetry.sdk.trace.export import BatchSpanProcessor + from opentelemetry.semconv.ai import SpanAttributes as BaseSpanAttributes + from opentelemetry.trace import SpanKind, Tracer, set_tracer_provider + from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator) + otel_installed = True +except ImportError: + + class Context: # type: ignore + pass + + class BaseSpanAttributes: # type: ignore + pass + + class SpanKind: # type: ignore + pass + + class Tracer: # type: ignore + pass + + +def init_tracer(instrumenting_module_name: str) -> Optional[Tracer]: + if not otel_installed: + return None -def init_tracer(instrumenting_module_name): trace_provider = TracerProvider() # The endpoint of OTLPSpanExporter is set from envvars: @@ -19,7 +43,14 @@ def init_tracer(instrumenting_module_name): return tracer -class SpanAttributes(opentelemetry.semconv.ai.SpanAttributes): +def extract_trace_context(headers: Mapping[str, str]) -> Optional[Context]: + if otel_installed: + return TraceContextTextMapPropagator().extract(headers) + else: + return None + + +class SpanAttributes(BaseSpanAttributes): # The following span attribute names are added here because they are missing # from the Semantic Conventions for LLM. LLM_REQUEST_ID = "gen_ai.request.id" From 5ebb1f3ece30439de7f1bde0a307dc2146aa5c2e Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Mon, 3 Jun 2024 18:55:20 +0300 Subject: [PATCH 05/23] Add --otlp_endpoint flag --- vllm/config.py | 13 +++++++++++++ vllm/engine/arg_utils.py | 40 +++++++++++++++++++++++++++------------ vllm/engine/llm_engine.py | 23 ++++++++++++++-------- vllm/tracing.py | 17 ++++++++++------- 4 files changed, 66 insertions(+), 27 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index c0d294ce942e..2435295bbfe7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -10,6 +10,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry +from vllm.tracing import is_otel_installed from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu, is_hip, is_neuron, is_tpu) @@ -1362,6 +1363,17 @@ def __post_init__(self): f"must be one of {valid_guided_backends}") +@dataclass +class ObservabilityConfig: + """Configuration for observability.""" + otlp_endpoint: Optional[str] = None + + def __post_init__(self): + if not is_otel_installed() and self.otlp_endpoint is not None: + raise ValueError("OpenTelemetry packages must be installed before " + "configuring 'otlp_endpoint'") + + @dataclass(frozen=True) class EngineConfig: """Dataclass which contains all engine-related configuration. This @@ -1378,6 +1390,7 @@ class EngineConfig: vision_language_config: Optional[VisionLanguageConfig] speculative_config: Optional[SpeculativeConfig] decoding_config: Optional[DecodingConfig] + observability_config: Optional[ObservabilityConfig] def __post_init__(self): """Verify configs are valid & consistent with each other. 
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ba53b5c86fa7..6617224dab43 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,8 +7,9 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TokenizerPoolConfig, VisionLanguageConfig) + ObservabilityConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, TokenizerPoolConfig, + VisionLanguageConfig) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import str_to_int_tuple @@ -101,6 +102,8 @@ class EngineArgs: qlora_adapter_name_or_path: Optional[str] = None + otlp_endpoint: Optional[str] = None + def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model @@ -598,6 +601,13 @@ def add_cli_args( type=str, default=None, help='Name or path of the QLoRA adapter.') + + parser.add_argument( + '--otlp-endpoint', + type=str, + default=None, + help='Target URL to which OpenTelemetry traces will be sent.') + return parser @classmethod @@ -756,6 +766,9 @@ def create_engine_config(self, ) -> EngineConfig: decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) + observability_config = ObservabilityConfig( + otlp_endpoint=self.otlp_endpoint) + if (model_config.get_sliding_window() is not None and scheduler_config.chunked_prefill_enabled and not scheduler_config.use_v2_block_manager): @@ -763,16 +776,19 @@ def create_engine_config(self, ) -> EngineConfig: "Chunked prefill is not supported with sliding window. " "Set --disable-sliding-window to disable sliding window.") - return EngineConfig(model_config=model_config, - cache_config=cache_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - lora_config=lora_config, - vision_language_config=vision_language_config, - speculative_config=speculative_config, - load_config=load_config, - decoding_config=decoding_config) + return EngineConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + vision_language_config=vision_language_config, + speculative_config=speculative_config, + load_config=load_config, + decoding_config=decoding_config, + observability_config=observability_config, + ) @dataclass diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7b75f5a03fd1..984d3840b8f2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -7,8 +7,8 @@ from transformers import GenerationConfig, PreTrainedTokenizer from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, - LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, + LoRAConfig, ModelConfig, ObservabilityConfig, + ParallelConfig, SchedulerConfig, SpeculativeConfig, VisionLanguageConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) @@ -43,8 +43,6 @@ logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 -tracer = init_tracer("vllm.llm_engine") - def _load_generation_config_dict(model_config: ModelConfig): try: @@ -157,6 +155,7 @@ def __init__( vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], decoding_config: Optional[DecodingConfig], + observability_config: Optional[ObservabilityConfig], executor_class: 
Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -171,9 +170,9 @@ def __init__( "disable_custom_all_reduce=%s, quantization=%s, " "enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, seed=%d, served_model_name=%s)", + "decoding_config=%r, observability_config=%r, " + "seed=%d, served_model_name=%s)", VLLM_VERSION, - model_config.model, speculative_config, model_config.tokenizer, model_config.skip_tokenizer_init, @@ -195,6 +194,7 @@ def __init__( model_config.quantization_param_path, device_config.device, decoding_config, + observability_config, model_config.seed, model_config.served_model_name, ) @@ -210,6 +210,8 @@ def __init__( self.speculative_config = speculative_config self.load_config = load_config self.decoding_config = decoding_config or DecodingConfig() + self.observability_config = observability_config or ObservabilityConfig( + ) self.log_stats = log_stats if not self.model_config.skip_tokenizer_init: @@ -291,6 +293,11 @@ def __init__( max_model_len=self.model_config.max_model_len) self.stat_logger.info("cache_config", self.cache_config) + self.tracer = None + if self.observability_config.otlp_endpoint: + self.tracer = init_tracer("vllm.llm_engine", + self.observability_config.otlp_endpoint) + # Create sequence output processor, e.g. for beam search or # speculative decoding. self.output_processor = ( @@ -994,13 +1001,13 @@ def check_health(self) -> None: self.model_executor.check_health() def create_trace_span(self, seq_group: SequenceGroup, now: float) -> None: - if tracer is None: + if self.tracer is None: return if seq_group.trace_context is None: return arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) - with tracer.start_as_current_span( + with self.tracer.start_as_current_span( "llm_request", kind=SpanKind.SERVER, context=seq_group.trace_context, diff --git a/vllm/tracing.py b/vllm/tracing.py index 47b634d91511..613ac220dd8a 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -1,6 +1,6 @@ from typing import Mapping, Optional -otel_installed = False +_is_otel_installed = False try: from opentelemetry.context.context import Context from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( @@ -11,7 +11,7 @@ from opentelemetry.trace import SpanKind, Tracer, set_tracer_provider from opentelemetry.trace.propagation.tracecontext import ( TraceContextTextMapPropagator) - otel_installed = True + _is_otel_installed = True except ImportError: class Context: # type: ignore @@ -27,16 +27,19 @@ class Tracer: # type: ignore pass -def init_tracer(instrumenting_module_name: str) -> Optional[Tracer]: - if not otel_installed: - return None +def is_otel_installed() -> bool: + return _is_otel_installed + +def init_tracer(instrumenting_module_name: str, + otlp_endpoint: str) -> Optional[Tracer]: trace_provider = TracerProvider() # The endpoint of OTLPSpanExporter is set from envvars: # OTEL_EXPORTER_OTLP_ENDPOINT # OTEL_EXPORTER_OTLP_TRACES_ENDPOINT - trace_provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter())) + trace_provider.add_span_processor( + BatchSpanProcessor(OTLPSpanExporter(endpoint=otlp_endpoint))) set_tracer_provider(trace_provider) tracer = trace_provider.get_tracer(instrumenting_module_name) @@ -44,7 +47,7 @@ def init_tracer(instrumenting_module_name: str) -> Optional[Tracer]: def extract_trace_context(headers: Mapping[str, str]) -> Optional[Context]: - if otel_installed: + if is_otel_installed(): return 
TraceContextTextMapPropagator().extract(headers) else: return None From dcdac44bc10fb2cb8d8fb9b4607013f8aebde3f5 Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Tue, 4 Jun 2024 14:27:59 +0300 Subject: [PATCH 06/23] Write a one-time warning log for requests with a tracing header when tracing is disabled --- vllm/entrypoints/openai/serving_chat.py | 6 +++++- vllm/entrypoints/openai/serving_completion.py | 7 ++++++- vllm/tracing.py | 15 +++++++++++++++ vllm/utils.py | 12 ++++++++++++ 4 files changed, 38 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index dd274ac9721b..210ffb22998d 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -31,7 +31,8 @@ get_full_image_text_prompt) from vllm.outputs import RequestOutput from vllm.sequence import Logprob -from vllm.tracing import extract_trace_context +from vllm.tracing import (contains_trace_context, extract_trace_context, + log_tracing_disabled_warning) from vllm.utils import random_uuid logger = init_logger(__name__) @@ -270,6 +271,9 @@ async def create_chat_completion( trace_context = extract_trace_context( raw_request.headers if raw_request else {}) + if self.engine.engine.tracer is None and contains_trace_context( + raw_request.headers if raw_request else {}): + log_tracing_disabled_warning() result_generator = self.engine.generate( inputs, sampling_params, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index c10377091658..d754b3a28eea 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -24,7 +24,8 @@ get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput from vllm.sequence import Logprob -from vllm.tracing import extract_trace_context +from vllm.tracing import (contains_trace_context, extract_trace_context, + log_tracing_disabled_warning) from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) @@ -127,6 +128,10 @@ async def create_completion(self, request: CompletionRequest, prompt_ids, prompt_text = prompt_formats trace_context = extract_trace_context(raw_request.headers) + if self.engine.engine.tracer is None and contains_trace_context( + raw_request.headers): + log_tracing_disabled_warning() + generator = self.engine.generate( { "prompt": prompt_text, diff --git a/vllm/tracing.py b/vllm/tracing.py index 613ac220dd8a..8a2012502429 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -1,5 +1,10 @@ from typing import Mapping, Optional +from vllm.logger import init_logger +from vllm.utils import run_once + +logger = init_logger(__name__) + _is_otel_installed = False try: from opentelemetry.context.context import Context @@ -63,3 +68,13 @@ class SpanAttributes(BaseSpanAttributes): LLM_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue" LLM_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token" LLM_LATENCY_E2E = "gen_ai.latency.e2e" + + +def contains_trace_context(headers: Mapping[str, str]) -> bool: + return "traceparent" in headers or "tracestate" in headers + + +@run_once +def log_tracing_disabled_warning() -> None: + logger.warning( + "Received a request with trace context but tracing is disabled") diff --git a/vllm/utils.py b/vllm/utils.py index 9b39ca77a980..bcc503427478 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -736,3 +736,15 @@ def cuda_device_count_stateless() -> int: # after 
https://github.com/pytorch/pytorch/pull/122815 is released. return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) + + +#From: https://stackoverflow.com/a/4104188/2749989 +def run_once(f): + + def wrapper(*args, **kwargs) -> Any: + if not wrapper.has_run: # type: ignore[attr-defined] + wrapper.has_run = True # type: ignore[attr-defined] + return f(*args, **kwargs) + + wrapper.has_run = False # type: ignore[attr-defined] + return wrapper From aa972bbdf7244392c8407cd49584d6286730f082 Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Tue, 4 Jun 2024 16:31:46 +0300 Subject: [PATCH 07/23] Relax condition This should allow tracing in offline inference mode. --- vllm/engine/llm_engine.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 984d3840b8f2..03a7b514d94e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1001,9 +1001,7 @@ def check_health(self) -> None: self.model_executor.check_health() def create_trace_span(self, seq_group: SequenceGroup, now: float) -> None: - if self.tracer is None: - return - if seq_group.trace_context is None: + if self.tracer is None or seq_group.sampling_params is None: return arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) From 5b57958684a568aee00b8cbc0fb25c0e422fa03f Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Wed, 5 Jun 2024 13:04:12 +0300 Subject: [PATCH 08/23] Decouple tracing for log stats --- vllm/engine/async_llm_engine.py | 3 +++ vllm/engine/llm_engine.py | 20 +++++++++++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 034403d1506c..7b005aa93898 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -245,6 +245,9 @@ async def step_async( # Log stats. self.do_log_stats(scheduler_outputs, output) + # Tracing + self.do_tracing(scheduler_outputs) + if not request_outputs: # Stop the execute model loop in parallel workers until there are # more requests to process. This avoids waiting indefinitely in diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 03a7b514d94e..e0456b183eb2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -804,6 +804,9 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Log stats. self.do_log_stats(scheduler_outputs, output) + # Tracing + self.do_tracing(scheduler_outputs) + if not request_outputs: # Stop the execute model loop in parallel workers until there are # more requests to process. This avoids waiting indefinitely in @@ -938,8 +941,6 @@ def _get_stats( for seq in seq_group.get_finished_seqs() ]) - self.create_trace_span(seq_group, now) - # Number of generation tokens. # num_batched_tokens equals the number of prompt_tokens plus the # number of decode_tokens in a single iteration. 
So, @@ -1000,7 +1001,16 @@ def list_loras(self) -> Set[int]: def check_health(self) -> None: self.model_executor.check_health() - def create_trace_span(self, seq_group: SequenceGroup, now: float) -> None: + def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None: + if self.tracer is None: + return + + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + seq_group = scheduled_seq_group.seq_group + if seq_group.is_finished(): + self.create_trace_span(seq_group) + + def create_trace_span(self, seq_group: SequenceGroup) -> None: if self.tracer is None or seq_group.sampling_params is None: return arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) @@ -1012,7 +1022,7 @@ def create_trace_span(self, seq_group: SequenceGroup, now: float) -> None: start_time=arrival_time_nano_seconds) as seq_span: metrics = seq_group.metrics ttft = metrics.first_token_time - metrics.arrival_time - e2e_time = now - seq_group.metrics.arrival_time + e2e_time = metrics.finished_time - metrics.arrival_time # attribute names are based on # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL, @@ -1040,7 +1050,7 @@ def create_trace_span(self, seq_group: SequenceGroup, now: float) -> None: for seq in seq_group.get_finished_seqs() ])) seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE, - seq_group.metrics.time_in_queue) + metrics.time_in_queue) seq_span.set_attribute( SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft) seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time) From c7f10057cec1341a93b82ebabe10941bad439c76 Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Wed, 5 Jun 2024 14:53:21 +0300 Subject: [PATCH 09/23] Rename --otlp-endpoint -> --otlp-traces-endpoint --- vllm/config.py | 6 +++--- vllm/engine/arg_utils.py | 6 +++--- vllm/engine/llm_engine.py | 7 ++++--- vllm/tracing.py | 4 ++-- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 2435295bbfe7..3b656dd48bdd 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1366,12 +1366,12 @@ def __post_init__(self): @dataclass class ObservabilityConfig: """Configuration for observability.""" - otlp_endpoint: Optional[str] = None + otlp_traces_endpoint: Optional[str] = None def __post_init__(self): - if not is_otel_installed() and self.otlp_endpoint is not None: + if not is_otel_installed() and self.otlp_traces_endpoint is not None: raise ValueError("OpenTelemetry packages must be installed before " - "configuring 'otlp_endpoint'") + "configuring 'otlp_traces_endpoint'") @dataclass(frozen=True) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6617224dab43..b6548cc70887 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -102,7 +102,7 @@ class EngineArgs: qlora_adapter_name_or_path: Optional[str] = None - otlp_endpoint: Optional[str] = None + otlp_traces_endpoint: Optional[str] = None def __post_init__(self): if self.tokenizer is None: @@ -603,7 +603,7 @@ def add_cli_args( help='Name or path of the QLoRA adapter.') parser.add_argument( - '--otlp-endpoint', + '--otlp-traces-endpoint', type=str, default=None, help='Target URL to which OpenTelemetry traces will be sent.') @@ -767,7 +767,7 @@ def create_engine_config(self, ) -> EngineConfig: guided_decoding_backend=self.guided_decoding_backend) observability_config = ObservabilityConfig( - otlp_endpoint=self.otlp_endpoint) + otlp_traces_endpoint=self.otlp_traces_endpoint) if 
(model_config.get_sliding_window() is not None and scheduler_config.chunked_prefill_enabled diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e0456b183eb2..085d26507bd7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -294,9 +294,10 @@ def __init__( self.stat_logger.info("cache_config", self.cache_config) self.tracer = None - if self.observability_config.otlp_endpoint: - self.tracer = init_tracer("vllm.llm_engine", - self.observability_config.otlp_endpoint) + if self.observability_config.otlp_traces_endpoint: + self.tracer = init_tracer( + "vllm.llm_engine", + self.observability_config.otlp_traces_endpoint) # Create sequence output processor, e.g. for beam search or # speculative decoding. diff --git a/vllm/tracing.py b/vllm/tracing.py index 8a2012502429..62b563b4c445 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -37,14 +37,14 @@ def is_otel_installed() -> bool: def init_tracer(instrumenting_module_name: str, - otlp_endpoint: str) -> Optional[Tracer]: + otlp_traces_endpoint: str) -> Optional[Tracer]: trace_provider = TracerProvider() # The endpoint of OTLPSpanExporter is set from envvars: # OTEL_EXPORTER_OTLP_ENDPOINT # OTEL_EXPORTER_OTLP_TRACES_ENDPOINT trace_provider.add_span_processor( - BatchSpanProcessor(OTLPSpanExporter(endpoint=otlp_endpoint))) + BatchSpanProcessor(OTLPSpanExporter(endpoint=otlp_traces_endpoint))) set_tracer_provider(trace_provider) tracer = trace_provider.get_tracer(instrumenting_module_name) From ca49ef324a304a17b59a6f0c28ae6018fbf0a290 Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Wed, 5 Jun 2024 15:31:17 +0300 Subject: [PATCH 10/23] Update Otel.md --- examples/production_monitoring/Otel.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/production_monitoring/Otel.md b/examples/production_monitoring/Otel.md index 9838d1fc5444..70439d35c5ed 100644 --- a/examples/production_monitoring/Otel.md +++ b/examples/production_monitoring/Otel.md @@ -1,5 +1,14 @@ # Setup OpenTelemetry POC +1. Install OpenTelemetry packages: + ``` + pip install \ + opentelemetry-sdk \ + opentelemetry-api \ + opentelemetry-exporter-otlp \ + opentelemetry-semantic-conventions-ai + ``` + 1. Start Jaeger in a docker container: ``` # From: https://www.jaegertracing.io/docs/1.57/getting-started/ @@ -26,7 +35,7 @@ Then set vLLM's service name for OpenTelemetry and run vLLM: ``` export OTEL_SERVICE_NAME="vllm-server" - python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" + python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_ENDPOINT/v1/traces" ``` 1. 
In a new shell, send requests with trace context from a dummy client From f69936e27ecb05db3288357e5945729232eba2cc Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Wed, 5 Jun 2024 17:45:25 +0300 Subject: [PATCH 11/23] Support grpc exporter and make it the default --- vllm/tracing.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/vllm/tracing.py b/vllm/tracing.py index 62b563b4c445..d63f5e38a544 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -1,3 +1,4 @@ +import os from typing import Mapping, Optional from vllm.logger import init_logger @@ -8,8 +9,8 @@ _is_otel_installed = False try: from opentelemetry.context.context import Context - from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( - OTLPSpanExporter) + from opentelemetry.sdk.environment_variables import ( + OTEL_EXPORTER_OTLP_TRACES_PROTOCOL) from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.ai import SpanAttributes as BaseSpanAttributes @@ -40,17 +41,29 @@ def init_tracer(instrumenting_module_name: str, otlp_traces_endpoint: str) -> Optional[Tracer]: trace_provider = TracerProvider() - # The endpoint of OTLPSpanExporter is set from envvars: - # OTEL_EXPORTER_OTLP_ENDPOINT - # OTEL_EXPORTER_OTLP_TRACES_ENDPOINT - trace_provider.add_span_processor( - BatchSpanProcessor(OTLPSpanExporter(endpoint=otlp_traces_endpoint))) + span_exporter = get_span_exporter(otlp_traces_endpoint) + trace_provider.add_span_processor(BatchSpanProcessor(span_exporter)) set_tracer_provider(trace_provider) tracer = trace_provider.get_tracer(instrumenting_module_name) return tracer +def get_span_exporter(endpoint): + protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc") + if protocol == "grpc": + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( + OTLPSpanExporter) + elif protocol == "http/protobuf": + from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( + OTLPSpanExporter) + else: + raise ValueError( + f"Unsupported OTLP protocol '{protocol}' is configured") + + return OTLPSpanExporter(endpoint=endpoint) + + def extract_trace_context(headers: Mapping[str, str]) -> Optional[Context]: if is_otel_installed(): return TraceContextTextMapPropagator().extract(headers) From a144cbf080bbaba9776f234c641aae1fb1580ab2 Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Wed, 5 Jun 2024 18:47:12 +0300 Subject: [PATCH 12/23] Update Otel.md --- examples/production_monitoring/Otel.md | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/examples/production_monitoring/Otel.md b/examples/production_monitoring/Otel.md index 70439d35c5ed..b2bd96a81740 100644 --- a/examples/production_monitoring/Otel.md +++ b/examples/production_monitoring/Otel.md @@ -30,18 +30,20 @@ 1. 
In a new shell, export Jaeger IP: ``` export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger) - export OTEL_EXPORTER_OTLP_ENDPOINT=http://$JAEGER_IP:4318 + export OTEL_EXPORTER_OTLP_ENDPOINT=grpc://$JAEGER_IP:4317 ``` - Then set vLLM's service name for OpenTelemetry and run vLLM: + Then set vLLM's service name for OpenTelemetry, enable insecure connections to Jaeger and run vLLM: ``` export OTEL_SERVICE_NAME="vllm-server" - python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_ENDPOINT/v1/traces" + export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true + python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_ENDPOINT" ``` 1. In a new shell, send requests with trace context from a dummy client ``` export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger) - export OTEL_EXPORTER_OTLP_ENDPOINT=http://$JAEGER_IP:4318 + export OTEL_EXPORTER_OTLP_ENDPOINT=grpc://$JAEGER_IP:4317 + export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true export OTEL_SERVICE_NAME="client-service" python dummy_client.py ``` @@ -54,12 +56,15 @@ 1. Clicking on a trace will show its spans and their tags. In this demo, each trace has 2 spans. One from the dummy client containing the prompt text and one from vLLM containing metadata about the request. ![Spans details](https://i.imgur.com/OPf6CBL.png) +## Exporter Protocol +OpenTelemetry supports either `grpc` or `http/protobuf` as the transport protocol for trace data in the exporter. +By default, `grpc` is used. To set `http/protobuf` as the protocol, configure the `OTEL_EXPORTER_OTLP_TRACES_PROTOCOL` environment variable as follows: -## Disabling tracing -OpenTelemetry tracing can be disabled by setting the environment variable: -``` -export OTEL_SDK_DISABLED=true -``` + ``` + export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf + export OTEL_EXPORTER_OTLP_ENDPOINT=http://$JAEGER_IP:4318/v1/traces + python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_ENDPOINT" + ``` ## Instrumentation of FastAPI OpenTelemetry allows automatic instrumentation of FastAPI. From 306c80520112914c676da5215bf93a0e6d09899b Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Wed, 5 Jun 2024 18:59:12 +0300 Subject: [PATCH 13/23] Rename OTEL_EXPORTER_OTLP_ENDPOINT -> OTEL_EXPORTER_OTLP_TRACES_ENDPOINT --- examples/production_monitoring/Otel.md | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/production_monitoring/Otel.md b/examples/production_monitoring/Otel.md index b2bd96a81740..1449442273c7 100644 --- a/examples/production_monitoring/Otel.md +++ b/examples/production_monitoring/Otel.md @@ -30,19 +30,19 @@ 1. In a new shell, export Jaeger IP: ``` export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger) - export OTEL_EXPORTER_OTLP_ENDPOINT=grpc://$JAEGER_IP:4317 + export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=grpc://$JAEGER_IP:4317 ``` Then set vLLM's service name for OpenTelemetry, enable insecure connections to Jaeger and run vLLM: ``` export OTEL_SERVICE_NAME="vllm-server" export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true - python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_ENDPOINT" + python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" ``` 1. 
In a new shell, send requests with trace context from a dummy client ``` export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger) - export OTEL_EXPORTER_OTLP_ENDPOINT=grpc://$JAEGER_IP:4317 + export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=grpc://$JAEGER_IP:4317 export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true export OTEL_SERVICE_NAME="client-service" python dummy_client.py @@ -59,12 +59,11 @@ ## Exporter Protocol OpenTelemetry supports either `grpc` or `http/protobuf` as the transport protocol for trace data in the exporter. By default, `grpc` is used. To set `http/protobuf` as the protocol, configure the `OTEL_EXPORTER_OTLP_TRACES_PROTOCOL` environment variable as follows: - - ``` - export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf - export OTEL_EXPORTER_OTLP_ENDPOINT=http://$JAEGER_IP:4318/v1/traces - python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_ENDPOINT" - ``` +``` +export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf +export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=http://$JAEGER_IP:4318/v1/traces +python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" +``` ## Instrumentation of FastAPI OpenTelemetry allows automatic instrumentation of FastAPI. From 70e58ad1f8bbb4dc7522b2e1feb452a93372921f Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Wed, 5 Jun 2024 19:00:04 +0300 Subject: [PATCH 14/23] Change exporter in dummy_client to grpc --- examples/production_monitoring/dummy_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/production_monitoring/dummy_client.py b/examples/production_monitoring/dummy_client.py index 2c1c703d1504..b1a2b3c3c4aa 100644 --- a/examples/production_monitoring/dummy_client.py +++ b/examples/production_monitoring/dummy_client.py @@ -1,5 +1,5 @@ import requests -from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( OTLPSpanExporter) from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import (BatchSpanProcessor, From c97d72f7f646d57a21cc4d27965c0f08f48f6cc6 Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Mon, 10 Jun 2024 12:05:53 +0300 Subject: [PATCH 15/23] Add --otlp-traces-endpoint flag to benchmark_latency.py --- benchmarks/benchmark_latency.py | 48 +++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 9937f8333fb7..e3bd791ebdf3 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -20,26 +20,29 @@ def main(args: argparse.Namespace): # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. 
- llm = LLM(model=args.model, - speculative_model=args.speculative_model, - num_speculative_tokens=args.num_speculative_tokens, - tokenizer=args.tokenizer, - quantization=args.quantization, - tensor_parallel_size=args.tensor_parallel_size, - trust_remote_code=args.trust_remote_code, - dtype=args.dtype, - enforce_eager=args.enforce_eager, - kv_cache_dtype=args.kv_cache_dtype, - quantization_param_path=args.quantization_param_path, - device=args.device, - ray_workers_use_nsight=args.ray_workers_use_nsight, - use_v2_block_manager=args.use_v2_block_manager, - enable_chunked_prefill=args.enable_chunked_prefill, - download_dir=args.download_dir, - block_size=args.block_size, - gpu_memory_utilization=args.gpu_memory_utilization, - load_format=args.load_format, - distributed_executor_backend=args.distributed_executor_backend) + llm = LLM( + model=args.model, + speculative_model=args.speculative_model, + num_speculative_tokens=args.num_speculative_tokens, + tokenizer=args.tokenizer, + quantization=args.quantization, + tensor_parallel_size=args.tensor_parallel_size, + trust_remote_code=args.trust_remote_code, + dtype=args.dtype, + enforce_eager=args.enforce_eager, + kv_cache_dtype=args.kv_cache_dtype, + quantization_param_path=args.quantization_param_path, + device=args.device, + ray_workers_use_nsight=args.ray_workers_use_nsight, + use_v2_block_manager=args.use_v2_block_manager, + enable_chunked_prefill=args.enable_chunked_prefill, + download_dir=args.download_dir, + block_size=args.block_size, + gpu_memory_utilization=args.gpu_memory_utilization, + load_format=args.load_format, + distributed_executor_backend=args.distributed_executor_backend, + otlp_traces_endpoint=args.otlp_traces_endpoint, + ) sampling_params = SamplingParams( n=args.n, @@ -254,5 +257,10 @@ def run_to_completion(profile_dir: Optional[str] = None): help='Backend to use for distributed serving. 
When more than 1 GPU ' 'is used, will be automatically set to "ray" if installed ' 'or "mp" (multiprocessing) otherwise.') + parser.add_argument( + '--otlp-traces-endpoint', + type=str, + default=None, + help='Target URL to which OpenTelemetry traces will be sent.') args = parser.parse_args() main(args) From 964e613efa0c518dbf8f006537539171acd9518b Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Tue, 11 Jun 2024 11:09:25 +0300 Subject: [PATCH 16/23] Add e2e test for tracing --- tests/tracing/__init__.py | 0 tests/tracing/test_tracing.py | 113 ++++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+) create mode 100644 tests/tracing/__init__.py create mode 100644 tests/tracing/test_tracing.py diff --git a/tests/tracing/__init__.py b/tests/tracing/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/tracing/test_tracing.py b/tests/tracing/test_tracing.py new file mode 100644 index 000000000000..ef46084444aa --- /dev/null +++ b/tests/tracing/test_tracing.py @@ -0,0 +1,113 @@ +import os +import threading +from concurrent import futures +from typing import Iterable + +import grpc +import pytest +from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( + ExportTraceServiceResponse) +from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import ( + TraceServiceServicer, add_TraceServiceServicer_to_server) +from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue +from opentelemetry.sdk.environment_variables import ( + OTEL_EXPORTER_OTLP_TRACES_INSECURE) + +from vllm import LLM, SamplingParams +from vllm.tracing import SpanAttributes + +FAKE_TRACE_SERVER_ADDRESS = "localhost:4317" + + +def decode_value(value: AnyValue): + field_decoders = { + "bool_value": (lambda v: v.bool_value), + "string_value": (lambda v: v.string_value), + "int_value": (lambda v: v.int_value), + "double_value": (lambda v: v.double_value), + "array_value": + (lambda v: [decode_value(item) for item in v.array_value.values]), + } + for field, decoder in field_decoders.items(): + if value.HasField(field): + return decoder(value) + raise ValueError(f"Couldn't decode value: {value}") + + +def decode_attributes(attributes: Iterable[KeyValue]): + return {kv.key: decode_value(kv.value) for kv in attributes} + + +class FakeTraceService(TraceServiceServicer): + + def __init__(self): + self.request = None + self.evt = threading.Event() + + def Export(self, request, context): + self.request = request + self.evt.set() + return ExportTraceServiceResponse() + + +@pytest.fixture +def trace_service(): + """Fixture to set up a fake gRPC trace service""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) + service = FakeTraceService() + add_TraceServiceServicer_to_server(service, server) + server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS) + server.start() + + yield service + + server.stop(None) + + +def test_traces(trace_service): + os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = "true" + + sampling_params = SamplingParams(temperature=0.01, + top_p=0.1, + max_tokens=256) + model = "facebook/opt-125m" + llm = LLM( + model=model, + otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS, + ) + prompts = ["This is a short prompt"] + outputs = llm.generate(prompts, sampling_params=sampling_params) + + timeout = 5 + if not trace_service.evt.wait(timeout): + raise TimeoutError( + f"The fake trace service didn't receive a trace within " + f"the {timeout} seconds timeout") + + attributes = decode_attributes(trace_service.request.resource_spans[0]. 
+ scope_spans[0].spans[0].attributes) + assert attributes.get(SpanAttributes.LLM_RESPONSE_MODEL) == model + assert attributes.get( + SpanAttributes.LLM_REQUEST_ID) == outputs[0].request_id + assert attributes.get( + SpanAttributes.LLM_REQUEST_TEMPERATURE) == sampling_params.temperature + assert attributes.get( + SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p + assert attributes.get( + SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens + assert attributes.get( + SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of + assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n + assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len( + outputs[0].prompt_token_ids) + completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) + assert attributes.get( + SpanAttributes.LLM_USAGE_COMPLETION_TOKENS) == completion_tokens + metrics = outputs[0].metrics + assert attributes.get( + SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue + ttft = metrics.first_token_time - metrics.arrival_time + assert attributes.get( + SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN) == ttft + e2e_time = metrics.finished_time - metrics.arrival_time + assert attributes.get(SpanAttributes.LLM_LATENCY_E2E) == e2e_time From edfdf55e6e797eced83a8e655eba0054b3e7d97e Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Tue, 11 Jun 2024 11:27:44 +0300 Subject: [PATCH 17/23] Add tracing to CI --- .buildkite/test-pipeline.yaml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6439a315e327..ecca3056ac55 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -159,6 +159,15 @@ steps: #mirror_hardwares: [amd] command: pytest -v -s quantization +- label: Tracing Test + commands: + - "pip install \ + opentelemetry-sdk \ + opentelemetry-api \ + opentelemetry-exporter-otlp \ + opentelemetry-semantic-conventions-ai" + - pytest -v -s tracing + - label: Benchmarks working_dir: "/vllm-workspace/.buildkite" mirror_hardwares: [amd] From b7d6f7e76862424e26545cc1652bb1f8a19fdc01 Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Thu, 13 Jun 2024 13:47:05 +0300 Subject: [PATCH 18/23] Add is_tracing_enabled() When --engine-use-ray is set, the server cannot access engine.tracer. Instead, we expose whether tracing is enabled through is_tracing_enabled(). 
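A minimal usage sketch of the API this commit introduces, assuming `engine` is an `AsyncLLMEngine` and `request_headers` is the incoming HTTP header mapping (both names are illustrative, and the helpers shown are the ones that exist at this point in the series; a later patch renames them to the `*_headers` variants):

```python
# Hedged sketch, not part of the patch: a front-end handler queries the
# engine once per request instead of reaching into engine.engine.tracer,
# which is not reachable when the engine runs as a Ray actor.
from vllm.tracing import (contains_trace_context, extract_trace_context,
                          log_tracing_disabled_warning)


async def resolve_trace_context(engine, request_headers):
    if await engine.is_tracing_enabled():
        return extract_trace_context(request_headers)
    if contains_trace_context(request_headers):
        # The client sent W3C trace headers, but no tracer is configured.
        log_tracing_disabled_warning()
    return None
```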
---
 vllm/engine/async_llm_engine.py               | 7 +++++++
 vllm/engine/llm_engine.py                     | 3 +++
 vllm/entrypoints/openai/serving_chat.py       | 6 ++++--
 vllm/entrypoints/openai/serving_completion.py | 5 +++--
 4 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 7b005aa93898..b81474fda56b 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -853,3 +853,10 @@ async def check_health(self) -> None:
         else:
             await self.engine.check_health_async()
         logger.debug("Health check took %fs", time.perf_counter() - t)
+
+    async def is_tracing_enabled(self) -> bool:
+        if self.engine_use_ray:
+            return await self.engine.is_tracing_enabled.remote(  # type: ignore
+            )
+        else:
+            return self.engine.is_tracing_enabled()
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 085d26507bd7..1a383c5ba713 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1002,6 +1002,9 @@ def list_loras(self) -> Set[int]:
     def check_health(self) -> None:
         self.model_executor.check_health()
 
+    def is_tracing_enabled(self) -> bool:
+        return self.tracer is not None
+
     def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None:
         if self.tracer is None:
             return
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 210ffb22998d..0ab20432d938 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -271,15 +271,17 @@ async def create_chat_completion(
 
         trace_context = extract_trace_context(
             raw_request.headers if raw_request else {})
-        if self.engine.engine.tracer is None and contains_trace_context(
+        is_tracing_enabled = await self.engine.is_tracing_enabled()
+        if not is_tracing_enabled and contains_trace_context(
                 raw_request.headers if raw_request else {}):
             log_tracing_disabled_warning()
+
         result_generator = self.engine.generate(
             inputs,
             sampling_params,
             request_id,
             lora_request,
-            trace_context=trace_context,
+            trace_headers=trace_context,
         )
         # Streaming response
         if request.stream:
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index d754b3a28eea..a9f35fb31707 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -24,7 +24,7 @@
     get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
-from vllm.tracing import (contains_trace_context, extract_trace_context,
+from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.utils import merge_async_iterators, random_uuid
 
@@ -128,7 +128,8 @@ async def create_completion(self, request: CompletionRequest,
                 prompt_ids, prompt_text = prompt_formats
 
                 trace_context = extract_trace_context(raw_request.headers)
-                if self.engine.engine.tracer is None and contains_trace_context(
+                is_tracing_enabled = await self.engine.is_tracing_enabled()
+                if not is_tracing_enabled and contains_trace_context(
                         raw_request.headers):
                     log_tracing_disabled_warning()
 

From b5d2735ee369e750ccea20bee9d877aa16e1317b Mon Sep 17 00:00:00 2001
From: Ronen Schaffer
Date: Thu, 13 Jun 2024 13:39:28 +0300
Subject: [PATCH 19/23] Pass trace headers instead of trace context

When --engine-use-ray is set, the server cannot pass otel Context as it's immutable.
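The mechanism behind this change, sketched under the assumption that OpenTelemetry is installed (`carrier` and `parent_context` are illustrative names): only a plain dict of W3C headers travels with the request, and an OpenTelemetry `Context` is rebuilt from it on the engine side.

```python
# Hedged illustration: round-tripping trace state through plain headers
# instead of an opentelemetry Context object, which cannot be handed to a
# Ray actor.
from opentelemetry.trace.propagation.tracecontext import (
    TraceContextTextMapPropagator)

propagator = TraceContextTextMapPropagator()

# Caller side: with a span active, serialize its context into a plain dict,
# e.g. {"traceparent": "00-<trace-id>-<span-id>-01"}.
carrier: dict = {}
propagator.inject(carrier)

# The dict is picklable, so it can be stored on the request (for example as
# SequenceGroup.trace_headers) and shipped across the Ray boundary.

# Engine side: rebuild a Context right before starting the request span.
parent_context = propagator.extract(carrier)
```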
--- vllm/engine/async_llm_engine.py | 25 +++++++++---------- vllm/engine/llm_engine.py | 23 +++++++++-------- vllm/entrypoints/openai/serving_chat.py | 13 +++++----- vllm/entrypoints/openai/serving_completion.py | 8 +++--- vllm/sequence.py | 7 +++--- vllm/tracing.py | 11 ++++++-- 6 files changed, 49 insertions(+), 38 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index b81474fda56b..4cb753cadb23 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -19,7 +19,6 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.tracing import Context from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) @@ -289,7 +288,7 @@ async def add_request_async( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - trace_context: Optional[Context] = None, + trace_headers: Optional[Dict[str, str]] = None, ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " @@ -306,7 +305,7 @@ async def add_request_async( params=params, arrival_time=arrival_time, lora_request=lora_request, - trace_context=trace_context, + trace_headers=trace_headers, ) async def check_health_async(self) -> None: @@ -551,7 +550,7 @@ async def add_request( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - trace_context: Optional[Context] = None, + trace_headers: Optional[Dict[str, str]] = None, ) -> AsyncStream: if self.log_requests: if isinstance(inputs, str): @@ -593,7 +592,7 @@ async def add_request( params=params, arrival_time=arrival_time, lora_request=lora_request, - trace_context=trace_context, + trace_headers=trace_headers, ) return stream @@ -604,7 +603,7 @@ async def generate( sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, - trace_context: Optional[Context] = None, + trace_headers: Optional[Dict[str, str]] = None, ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -619,7 +618,7 @@ async def generate( sampling_params: The sampling parameters of the request. request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. - trace_context: OpenTelemetry trace context. + trace_headers: OpenTelemetry trace headers. Yields: The output `RequestOutput` objects from the LLMEngine @@ -673,7 +672,7 @@ async def generate( inputs, sampling_params, lora_request=lora_request, - trace_context=trace_context, + trace_headers=trace_headers, ): yield LLMEngine.validate_output(output, RequestOutput) @@ -683,7 +682,7 @@ async def encode( pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, - trace_context: Optional[Context] = None, + trace_headers: Optional[Dict[str, str]] = None, ) -> AsyncIterator[EmbeddingRequestOutput]: """Generate outputs for a request from an embedding model. @@ -698,7 +697,7 @@ async def encode( pooling_params: The pooling parameters of the request. request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. - trace_context: OpenTelemetry trace context. + trace_headers: OpenTelemetry trace headers. 
Yields: The output `EmbeddingRequestOutput` objects from the LLMEngine @@ -750,7 +749,7 @@ async def encode( inputs, pooling_params, lora_request=lora_request, - trace_context=trace_context, + trace_headers=trace_headers, ): yield LLMEngine.validate_output(output, EmbeddingRequestOutput) @@ -761,7 +760,7 @@ async def _process_request( params: Union[SamplingParams, PoolingParams], *, lora_request: Optional[LoRARequest] = None, - trace_context: Optional[Context] = None, + trace_headers: Optional[Dict[str, str]] = None, ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: """Common logic to process requests with SamplingParams or PoolingParams.""" @@ -773,7 +772,7 @@ async def _process_request( params, arrival_time=arrival_time, lora_request=lora_request, - trace_context=trace_context, + trace_headers=trace_headers, ) try: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1a383c5ba713..028a2a6227cf 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,6 +1,6 @@ import time from contextlib import contextmanager -from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional +from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional from typing import Sequence as GenericSequence from typing import Set, Type, TypeVar, Union @@ -31,7 +31,8 @@ PoolerOutput, SamplerOutput, Sequence, SequenceGroup, SequenceGroupMetadata, SequenceStatus) -from vllm.tracing import Context, SpanAttributes, SpanKind, init_tracer +from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, + init_tracer) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -447,7 +448,7 @@ def _add_processed_request( params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], - trace_context: Optional[Context] = None, + trace_headers: Optional[Dict[str, str]] = None, ) -> None: # Create the sequences. block_size = self.cache_config.block_size @@ -465,7 +466,7 @@ def _add_processed_request( params, arrival_time=arrival_time, lora_request=lora_request, - trace_context=trace_context, + trace_headers=trace_headers, ) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( @@ -512,7 +513,7 @@ def add_request( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - trace_context: Optional[Context] = None, + trace_headers: Optional[Dict[str, str]] = None, ) -> None: """Add a request to the engine's request pool. @@ -530,7 +531,7 @@ def add_request( :class:`~vllm.PoolingParams` for pooling. arrival_time: The arrival time of the request. If None, we use the current monotonic time. - trace_context: OpenTelemetry trace context. + trace_headers: OpenTelemetry trace headers. Details: - Set arrival_time to the current time if it is None. 
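A hedged usage sketch of the parameter documented above, assuming `engine` is an already-constructed `LLMEngine` (the `traceparent` value is the W3C specification example, not real data):

```python
# Sketch only: attaching pre-extracted W3C trace headers when queueing a
# request directly on the engine.
from vllm import SamplingParams

engine.add_request(
    request_id="example-request-0",
    inputs="San Francisco is a",
    params=SamplingParams(temperature=0.0, max_tokens=16),
    trace_headers={
        "traceparent":
        "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
    },
)
```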
@@ -572,7 +573,7 @@ def add_request( params=params, arrival_time=arrival_time, lora_request=lora_request, - trace_context=trace_context, + trace_headers=trace_headers, ) def _create_sequence_group_with_sampling( @@ -582,7 +583,7 @@ def _create_sequence_group_with_sampling( sampling_params: SamplingParams, arrival_time: float, lora_request: Optional[LoRARequest], - trace_context: Optional[Context] = None, + trace_headers: Optional[Dict[str, str]] = None, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -610,7 +611,7 @@ def _create_sequence_group_with_sampling( arrival_time=arrival_time, sampling_params=sampling_params, lora_request=lora_request, - trace_context=trace_context, + trace_headers=trace_headers, ) return seq_group @@ -1019,10 +1020,12 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: return arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) + trace_context = extract_trace_context(seq_group.trace_headers) + with self.tracer.start_as_current_span( "llm_request", kind=SpanKind.SERVER, - context=seq_group.trace_context, + context=trace_context, start_time=arrival_time_nano_seconds) as seq_span: metrics = seq_group.metrics ttft = metrics.first_token_time - metrics.arrival_time diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0ab20432d938..744e1d94511b 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -31,7 +31,7 @@ get_full_image_text_prompt) from vllm.outputs import RequestOutput from vllm.sequence import Logprob -from vllm.tracing import (contains_trace_context, extract_trace_context, +from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) from vllm.utils import random_uuid @@ -269,11 +269,12 @@ async def create_chat_completion( if image_data is not None: inputs["multi_modal_data"] = image_data - trace_context = extract_trace_context( - raw_request.headers if raw_request else {}) is_tracing_enabled = await self.engine.is_tracing_enabled() - if not is_tracing_enabled and contains_trace_context( - raw_request.headers if raw_request else {}): + trace_headers = None + if is_tracing_enabled and raw_request: + trace_headers = extract_trace_headers(raw_request.headers) + if not is_tracing_enabled and raw_request and contains_trace_headers( + raw_request.headers): log_tracing_disabled_warning() result_generator = self.engine.generate( @@ -281,7 +282,7 @@ async def create_chat_completion( sampling_params, request_id, lora_request, - trace_headers=trace_context, + trace_headers=trace_headers, ) # Streaming response if request.stream: diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index a9f35fb31707..c775fa6daa73 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -127,9 +127,11 @@ async def create_completion(self, request: CompletionRequest, truncate_prompt_tokens) prompt_ids, prompt_text = prompt_formats - trace_context = extract_trace_context(raw_request.headers) is_tracing_enabled = await self.engine.is_tracing_enabled() - if not is_tracing_enabled and contains_trace_context( + trace_headers = None + if is_tracing_enabled: + trace_headers = extract_trace_headers(raw_request.headers) + if not is_tracing_enabled and contains_trace_headers( raw_request.headers): log_tracing_disabled_warning() @@ -141,7 +143,7 @@ 
async def create_completion(self, request: CompletionRequest, sampling_params, f"{request_id}-{i}", lora_request=lora_request, - trace_context=trace_context, + trace_headers=trace_headers, ) generators.append(generator) diff --git a/vllm/sequence.py b/vllm/sequence.py index a781c16a08c6..38d3349f2ab4 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -12,7 +12,6 @@ from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.tracing import Context if TYPE_CHECKING: from vllm.multimodal import MultiModalData @@ -415,7 +414,7 @@ class SequenceGroup: for an embedding model. encoder_seq: Optional, the single encoder sequence. Should be None unless you are working with an encoder/decoder model. - trace_context: OpenTelemetry trace context. + trace_headers: OpenTelemetry trace headers. """ def __init__( @@ -428,7 +427,7 @@ def __init__( embeddings: Optional[List[float]] = None, pooling_params: Optional[PoolingParams] = None, encoder_seq: Optional[Sequence] = None, - trace_context: Optional[Context] = None, + trace_headers: Optional[Dict[str, str]] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -444,7 +443,7 @@ def __init__( self.embeddings = embeddings self.pooling_params = pooling_params self.encoder_seq = encoder_seq - self.trace_context = trace_context + self.trace_headers = trace_headers @property def prompt(self) -> Optional[str]: diff --git a/vllm/tracing.py b/vllm/tracing.py index d63f5e38a544..bf8d38fac303 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -4,6 +4,8 @@ from vllm.logger import init_logger from vllm.utils import run_once +TRACE_HEADERS = ["traceparent", "tracestate"] + logger = init_logger(__name__) _is_otel_installed = False @@ -71,6 +73,11 @@ def extract_trace_context(headers: Mapping[str, str]) -> Optional[Context]: return None +def extract_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]: + + return {h: headers[h] for h in TRACE_HEADERS if h in headers} + + class SpanAttributes(BaseSpanAttributes): # The following span attribute names are added here because they are missing # from the Semantic Conventions for LLM. 
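For illustration, with header values borrowed from the W3C trace-context examples, `extract_trace_headers` keeps only the `traceparent` and `tracestate` keys from an incoming header mapping, and `contains_trace_headers` reports whether either is present:

```python
# Hypothetical demonstration of the header-filtering helpers; the values
# are the W3C specification examples, not real request data.
from vllm.tracing import contains_trace_headers, extract_trace_headers

headers = {
    "content-type": "application/json",
    "traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
    "tracestate": "congo=t61rcWkgMzE",
}

assert extract_trace_headers(headers) == {
    "traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
    "tracestate": "congo=t61rcWkgMzE",
}
assert contains_trace_headers(headers)
```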
@@ -83,8 +90,8 @@ class SpanAttributes(BaseSpanAttributes):
     LLM_LATENCY_E2E = "gen_ai.latency.e2e"
 
 
-def contains_trace_context(headers: Mapping[str, str]) -> bool:
-    return "traceparent" in headers or "tracestate" in headers
+def contains_trace_headers(headers: Mapping[str, str]) -> bool:
+    return any(h in headers for h in TRACE_HEADERS)
 
 
 @run_once

From 2f8b0ab491637c271a74218c4fa9d893ea3bf0ca Mon Sep 17 00:00:00 2001
From: Ronen Schaffer
Date: Thu, 13 Jun 2024 16:11:44 +0300
Subject: [PATCH 20/23] Handle passing None to extract_trace_context

---
 vllm/tracing.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm/tracing.py b/vllm/tracing.py
index bf8d38fac303..07e50aca58f0 100644
--- a/vllm/tracing.py
+++ b/vllm/tracing.py
@@ -66,8 +66,10 @@ def get_span_exporter(endpoint):
     return OTLPSpanExporter(endpoint=endpoint)
 
 
-def extract_trace_context(headers: Mapping[str, str]) -> Optional[Context]:
+def extract_trace_context(
+        headers: Optional[Mapping[str, str]]) -> Optional[Context]:
     if is_otel_installed():
+        headers = headers or {}
         return TraceContextTextMapPropagator().extract(headers)
     else:
         return None

From 9fa949e268805dbb411992b79f25c9a6b7fdcb59 Mon Sep 17 00:00:00 2001
From: Ronen Schaffer
Date: Mon, 17 Jun 2024 13:42:07 +0300
Subject: [PATCH 21/23] Assert otel is installed in init_tracer()

---
 vllm/tracing.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/tracing.py b/vllm/tracing.py
index 07e50aca58f0..ba6732cab68f 100644
--- a/vllm/tracing.py
+++ b/vllm/tracing.py
@@ -41,6 +41,8 @@ def is_otel_installed() -> bool:
 
 
 def init_tracer(instrumenting_module_name: str,
                 otlp_traces_endpoint: str) -> Optional[Tracer]:
+    assert is_otel_installed(), ("OpenTelemetry packages must be installed "
+                                 "prior to initializing a tracer")
     trace_provider = TracerProvider()
     span_exporter = get_span_exporter(otlp_traces_endpoint)

From a6d17f39288d2730860de2c161c1be0819388843 Mon Sep 17 00:00:00 2001
From: Ronen Schaffer
Date: Mon, 17 Jun 2024 15:43:53 +0300
Subject: [PATCH 22/23] Fix bug introduced by git rebase

---
 vllm/engine/llm_engine.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 028a2a6227cf..bba578593ddb 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -174,6 +174,7 @@ def __init__(
             "decoding_config=%r, observability_config=%r, "
             "seed=%d, served_model_name=%s)",
             VLLM_VERSION,
+            model_config.model,
             speculative_config,
             model_config.tokenizer,
             model_config.skip_tokenizer_init,

From 21dba065b025c96a99944b5d3f489a91e797ba4c Mon Sep 17 00:00:00 2001
From: Ronen Schaffer
Date: Mon, 17 Jun 2024 15:58:33 +0300
Subject: [PATCH 23/23] Fix lint error

---
 tests/tracing/test_tracing.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/tests/tracing/test_tracing.py b/tests/tracing/test_tracing.py
index ef46084444aa..2f8f62cf2d1e 100644
--- a/tests/tracing/test_tracing.py
+++ b/tests/tracing/test_tracing.py
@@ -1,7 +1,7 @@
 import os
 import threading
 from concurrent import futures
-from typing import Iterable
+from typing import Callable, Dict, Iterable, Literal
 
 import grpc
 import pytest
@@ -18,9 +18,12 @@
 
 FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"
 
 
+FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
+                    'array_value']
+
 def decode_value(value: AnyValue):
-    field_decoders = {
+    field_decoders: Dict[FieldName, Callable] = {
         "bool_value": (lambda v: v.bool_value),
         "string_value": (lambda v: v.string_value),
         "int_value": (lambda v: v.int_value),