diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index adbd08b13f75..b0f5fe418dcf 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -231,7 +231,7 @@ steps:
   source_file_dependencies:
   - vllm/
   - tests/metrics
-  - tests/tracing
+  - tests/v1/tracing
   commands:
   - pytest -v -s metrics
   - "pip install \
diff --git a/tests/v1/tracing/test_tracing.py b/tests/v1/tracing/test_tracing.py
new file mode 100644
index 000000000000..da8655f95e19
--- /dev/null
+++ b/tests/v1/tracing/test_tracing.py
@@ -0,0 +1,137 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# ruff: noqa
+# type: ignore
+from __future__ import annotations
+
+import threading
+from collections.abc import Iterable
+from concurrent import futures
+from typing import Callable, Generator, Literal
+
+import grpc
+import pytest
+from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
+    ExportTraceServiceResponse)
+from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
+    TraceServiceServicer, add_TraceServiceServicer_to_server)
+from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
+from opentelemetry.sdk.environment_variables import (
+    OTEL_EXPORTER_OTLP_TRACES_INSECURE)
+
+from vllm import LLM, SamplingParams
+from vllm.tracing import SpanAttributes
+
+FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"
+
+FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
+                    'array_value']
+
+
+def decode_value(value: AnyValue):
+    field_decoders: dict[FieldName, Callable] = {
+        "bool_value": (lambda v: v.bool_value),
+        "string_value": (lambda v: v.string_value),
+        "int_value": (lambda v: v.int_value),
+        "double_value": (lambda v: v.double_value),
+        "array_value":
+        (lambda v: [decode_value(item) for item in v.array_value.values]),
+    }
+    for field, decoder in field_decoders.items():
+        if value.HasField(field):
+            return decoder(value)
+    raise ValueError(f"Couldn't decode value: {value}")
+
+
+def decode_attributes(attributes: Iterable[KeyValue]):
+    return {kv.key: decode_value(kv.value) for kv in attributes}
+
+
+class FakeTraceService(TraceServiceServicer):
+
+    def __init__(self):
+        self.request = None
+        self.evt = threading.Event()
+
+    def Export(self, request, context):
+        self.request = request
+        self.evt.set()
+        return ExportTraceServiceResponse()
+
+
+@pytest.fixture
+def trace_service() -> Generator[FakeTraceService, None, None]:
+    """Fixture to set up a fake gRPC trace service"""
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
+    service = FakeTraceService()
+    add_TraceServiceServicer_to_server(service, server)
+    server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
+    server.start()
+
+    yield service
+
+    server.stop(None)
+
+
+def test_traces(
+    monkeypatch: pytest.MonkeyPatch,
+    trace_service: FakeTraceService,
+):
+    with monkeypatch.context() as m:
+        m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
+        m.setenv("VLLM_USE_V1", "1")
+        sampling_params = SamplingParams(
+            temperature=0.01,
+            top_p=0.1,
+            max_tokens=256,
+        )
+        model = "facebook/opt-125m"
+        llm = LLM(model=model,
+                  otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
+                  gpu_memory_utilization=0.3,
+                  disable_log_stats=False)
+        prompts = ["This is a short prompt"]
+        outputs = llm.generate(prompts, sampling_params=sampling_params)
+        print(f"test_traces outputs is : {outputs}")
+
+        timeout = 10
+        if not trace_service.evt.wait(timeout):
+            raise TimeoutError(
+                f"The fake trace service didn't receive a trace within "
+                f"the {timeout} seconds timeout")
+
+        request = trace_service.request
+        assert len(request.resource_spans) == 1, (
+            f"Expected 1 resource span, "
+            f"but got {len(request.resource_spans)}")
+        assert len(request.resource_spans[0].scope_spans) == 1, (
+            f"Expected 1 scope span, "
+            f"but got {len(request.resource_spans[0].scope_spans)}")
+        assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
+            f"Expected 1 span, "
+            f"but got {len(request.resource_spans[0].scope_spans[0].spans)}")
+
+        attributes = decode_attributes(
+            request.resource_spans[0].scope_spans[0].spans[0].attributes)
+        # assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
+        assert attributes.get(
+            SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
+        assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
+                              ) == sampling_params.temperature
+        assert attributes.get(
+            SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
+        assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS
+                              ) == sampling_params.max_tokens
+        assert attributes.get(
+            SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
+        assert attributes.get(
+            SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
+                outputs[0].prompt_token_ids)
+        completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
+        assert attributes.get(
+            SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
+
+        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) > 0
+        assert attributes.get(
+            SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) > 0
+        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) > 0
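# Aside (not part of the patch): the FakeTraceService fixture above is a plain
# OTLP/gRPC collector, so it can be sanity-checked without starting vLLM at
# all. A minimal sketch using the stock OpenTelemetry SDK and OTLP exporter;
# names and the endpoint below are illustrative, not vLLM APIs.

from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
    OTLPSpanExporter)
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor


def emit_probe_span(endpoint: str = "localhost:4317") -> None:
    """Send one span to an OTLP collector such as the fixture above."""
    provider = TracerProvider()
    provider.add_span_processor(
        SimpleSpanProcessor(OTLPSpanExporter(endpoint=endpoint,
                                             insecure=True)))
    with provider.get_tracer("probe").start_as_current_span("probe_span"):
        pass
    provider.shutdown()  # flush so the collector's Export() fires promptly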
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index ba1543a8d5a3..9841876bc8e8 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1491,12 +1491,6 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
-        # No OTLP observability so far.
-        if (self.otlp_traces_endpoint or self.collect_detailed_traces):
-            _raise_or_fallback(feature_name="--otlp-traces-endpoint",
-                               recommend_to_remove=False)
-            return False
-
         # V1 supports N-gram, Medusa, and Eagle speculative decoding.
         if (self.speculative_config is not None
                 and self.speculative_config.get("method") == "draft_model"):
diff --git a/vllm/tracing.py b/vllm/tracing.py
index 6a287d82be5f..7537e9901a04 100644
--- a/vllm/tracing.py
+++ b/vllm/tracing.py
@@ -119,6 +119,11 @@ class SpanAttributes:
     # forward, block/sync across workers, cpu-gpu sync time and sampling time.
     GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = (
         "gen_ai.latency.time_in_model_execute")
+    GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = \
+        "gen_ai.latency.time_in_model_prefill"
+    GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode"
+    GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = \
+        "gen_ai.latency.time_in_model_inference"
 
 
 def contains_trace_headers(headers: Mapping[str, str]) -> bool:
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index d1a6dd73e85c..aa45f6669207 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -969,9 +969,9 @@ class Scheduler(SchedulerInterface):
                         stop_reason=request.stop_reason,
                         events=request.take_events(),
                         kv_transfer_params=kv_transfer_params,
+                        trace_headers=request.trace_headers,
                         num_cached_tokens=request.num_cached_tokens,
                     ))
-
             else:
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 5d8959a3cd3f..dec4abec519b 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -3,6 +3,7 @@
 
 import enum
 import time
+from collections.abc import Mapping
 from typing import Any, Optional, Union
 
 import msgspec
@@ -66,6 +67,8 @@ class EngineCoreRequest(
     current_wave: int = 0
     priority: int = 0
 
+    trace_headers: Optional[Mapping[str, str]] = None
+
 
 class EngineCoreEventType(enum.IntEnum):
     """The type of engine core request event."""
@@ -111,6 +114,7 @@ class EngineCoreOutput(
     events: Optional[list[EngineCoreEvent]] = None
     kv_transfer_params: Optional[dict[str, Any]] = None
 
+    trace_headers: Optional[Mapping[str, str]] = None
     # The number of tokens with prefix cache hits.
     num_cached_tokens: int = 0
 
@@ -144,7 +148,7 @@ class EngineCoreOutputs(
         omit_defaults=True,  # type: ignore[call-arg]
         gc=False):  # type: ignore[call-arg]
 
-    #NOTE(Nick): We could consider ways to make this more compact,
+    # NOTE(Nick): We could consider ways to make this more compact,
     # e.g. columnwise layout
 
     engine_index: int = 0
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index f57075c6fa82..e467f40f6eab 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -26,6 +26,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.tasks import SupportedTask
+from vllm.tracing import init_tracer
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -97,6 +98,7 @@ class AsyncLLM(EngineClient):
 
         self.model_config = vllm_config.model_config
         self.vllm_config = vllm_config
+        self.observability_config = vllm_config.observability_config
         self.log_requests = log_requests
         self.log_stats = log_stats or (stat_loggers is not None)
 
@@ -124,6 +126,11 @@ class AsyncLLM(EngineClient):
         # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
         self.output_processor = OutputProcessor(self.tokenizer,
                                                 log_stats=self.log_stats)
+        if self.observability_config.otlp_traces_endpoint is not None:
+            tracer = init_tracer(
+                "vllm.llm_engine",
+                self.observability_config.otlp_traces_endpoint)
+            self.output_processor.tracer = tracer
 
         # EngineCore (starts the engine in background process).
         self.engine_core = EngineCoreClient.make_async_mp_client(
@@ -603,7 +610,7 @@ class AsyncLLM(EngineClient):
         return self.tokenizer.get_lora_tokenizer(lora_request)
 
     async def is_tracing_enabled(self) -> bool:
-        return False
+        return self.observability_config.otlp_traces_endpoint is not None
 
     async def do_log_stats(
         self,
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 7130f666ef19..fca5a783bc3b 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -19,6 +19,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.tasks import SupportedTask
+from vllm.tracing import init_tracer
 from vllm.transformers_utils.tokenizer_group import (
     TokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import UsageContext
@@ -65,6 +66,7 @@ class LLMEngine:
                 "Set VLLM_USE_V1=0 and file and issue on Github.")
 
         self.vllm_config = vllm_config
+        self.observability_config = vllm_config.observability_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
 
@@ -99,6 +101,11 @@ class LLMEngine:
         # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
         self.output_processor = OutputProcessor(self.tokenizer,
                                                 log_stats=self.log_stats)
+        if self.observability_config.otlp_traces_endpoint is not None:
+            tracer = init_tracer(
+                "vllm.llm_engine",
+                self.observability_config.otlp_traces_endpoint)
+            self.output_processor.tracer = tracer
 
         # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
         self.engine_core = EngineCoreClient.make_client(
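# Aside (not part of the patch): with the AsyncLLM and LLMEngine hunks above,
# a tracer is attached to the OutputProcessor whenever
# observability_config.otlp_traces_endpoint is set. A minimal sketch of the
# resulting behaviour, assuming an OTLP collector listens on localhost:4317;
# the model name and endpoint are placeholders.

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM


async def check_tracing_enabled() -> None:
    engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                                  otlp_traces_endpoint="localhost:4317")
    engine = AsyncLLM.from_engine_args(engine_args)
    try:
        # Reflects the is_tracing_enabled() change above: True once an
        # OTLP endpoint is configured, False otherwise.
        assert await engine.is_tracing_enabled()
    finally:
        engine.shutdown()


if __name__ == "__main__":
    asyncio.run(check_tracing_enabled())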
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 2ee55b585da6..02c8c61cb909 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -11,6 +11,8 @@ import torch
 from vllm.outputs import (CompletionOutput, PoolingOutput,
                           PoolingRequestOutput, RequestOutput)
 from vllm.sampling_params import RequestOutputKind
+from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
+                          extract_trace_context)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
@@ -71,7 +73,6 @@ class RequestOutputCollector:
 
 @dataclass
 class OutputProcessorOutput:
-
     request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
     reqs_to_abort: list[str]
 
@@ -93,6 +94,9 @@ class RequestState:
         arrival_time: float,
         queue: Optional[RequestOutputCollector],
         log_stats: bool,
+        top_p: Optional[float] = None,
+        n: Optional[int] = None,
+        temperature: Optional[float] = None,
     ):
         self.request_id = request_id
         self.parent_req = parent_req
@@ -105,6 +109,9 @@ class RequestState:
         self.logprobs_processor = logprobs_processor
         self.detokenizer = detokenizer
         self.max_tokens_param = max_tokens_param
+        self.top_p = top_p
+        self.n = n
+        self.temperature = temperature
         self.is_prefilling = True
         self.queue = queue
         self.num_cached_tokens = 0
@@ -137,10 +144,16 @@ class RequestState:
                 request=request,
             )
             max_tokens_param = sampling_params.max_tokens
+            top_p = sampling_params.top_p
+            n = sampling_params.n
+            temperature = sampling_params.temperature
         else:
             logprobs_processor = None
             detokenizer = None
             max_tokens_param = None
+            top_p = None
+            n = None
+            temperature = None
             assert request.pooling_params is not None
             output_kind = request.pooling_params.output_kind
 
@@ -156,6 +169,9 @@ class RequestState:
             logprobs_processor=logprobs_processor,
             detokenizer=detokenizer,
             max_tokens_param=max_tokens_param,
+            top_p=top_p,
+            n=n,
+            temperature=temperature,
             arrival_time=request.arrival_time,
             queue=queue,
             log_stats=log_stats,
@@ -274,16 +290,13 @@ class RequestState:
 class OutputProcessor:
     """Process EngineCoreOutputs into RequestOutputs."""
 
-    def __init__(
-        self,
-        tokenizer: TokenizerGroup,
-        log_stats: bool,
-    ):
+    def __init__(self, tokenizer: TokenizerGroup, log_stats: bool):
         self.log_stats = log_stats
         self.tokenizer = tokenizer
         self.request_states: dict[str, RequestState] = {}
         self.parent_requests: dict[str, ParentRequest] = {}
         self.lora_states = LoRARequestStates()
+        self.tracer: Optional[Tracer] = None
 
     def get_num_unfinished_requests(self):
         return len(self.request_states)
@@ -441,7 +454,9 @@ class OutputProcessor:
                 # Track per-request stats
                 self._update_stats_from_finished(req_state, finish_reason,
                                                  iteration_stats)
-
+                if self.tracer:
+                    self.do_tracing(engine_core_output, req_state,
+                                    iteration_stats)
         self.lora_states.update_iteration_stats(iteration_stats)
 
         return OutputProcessorOutput(
@@ -449,6 +464,63 @@ class OutputProcessor:
             reqs_to_abort=reqs_to_abort,
         )
 
+    def do_tracing(self, engine_core_output: EngineCoreOutput,
+                   req_state: RequestState,
+                   iteration_stats: Optional[IterationStats]) -> None:
+        assert req_state.stats is not None
+        assert iteration_stats is not None
+        assert self.tracer is not None
+
+        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
+        trace_context = extract_trace_context(engine_core_output.trace_headers)
+        with (self.tracer.start_as_current_span(
+                "llm_request",
+                kind=SpanKind.SERVER,
+                context=trace_context,
+                start_time=arrival_time_nano_seconds) as span):
+            metrics = req_state.stats
+            e2e_time = iteration_stats.iteration_timestamp - \
+                metrics.arrival_time
+            queued_time = metrics.scheduled_ts - metrics.queued_ts
+            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
+            decode_time = metrics.last_token_ts - metrics.first_token_ts
+            inference_time = metrics.last_token_ts - metrics.scheduled_ts
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
+                metrics.first_token_latency)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
+                               queued_time)
+            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
+                               len(req_state.prompt_token_ids))
+            span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
+                               metrics.num_generation_tokens)
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL,
+                prefill_time)
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE,
+                decode_time)
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE,
+                inference_time)
+
+            # meta
+            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
+                               req_state.request_id)
+            if req_state.top_p:
+                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
+                                   req_state.top_p)
+            if req_state.max_tokens_param:
+                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
+                                   req_state.max_tokens_param)
+            if req_state.temperature:
+                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
+                                   req_state.temperature)
+            if req_state.n:
+                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
+                                   req_state.n)
+
     def _update_stats_from_output(self, req_state: RequestState,
                                   engine_core_output: EngineCoreOutput,
                                   engine_core_timestamp: Optional[float],
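# Aside (not part of the patch): a worked example of how do_tracing() above
# maps RequestStateStats timestamps onto the new latency attributes. The
# numbers are illustrative only.

arrival_ts, queued_ts, scheduled_ts = 100.00, 100.00, 100.05
first_token_ts, last_token_ts = 100.35, 101.15
iteration_timestamp = 101.16  # when the finishing output is processed

queued_time = scheduled_ts - queued_ts         # 0.05 -> time_in_queue
prefill_time = first_token_ts - scheduled_ts   # 0.30 -> time_in_model_prefill
decode_time = last_token_ts - first_token_ts   # 0.80 -> time_in_model_decode
inference_time = last_token_ts - scheduled_ts  # 1.10 -> time_in_model_inference
e2e_time = iteration_timestamp - arrival_ts    # 1.16 -> e2e

# Prefill plus decode covers the same interval as inference.
assert abs((prefill_time + decode_time) - inference_time) < 1e-9
# Time to first token is recorded separately as first_token_latency
# (see the stats.py hunk below), measured from arrival_time.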
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 96e89eeac556..a54a98adf56b 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -327,8 +327,6 @@ class Processor:
         # TODO(woosuk): Support pooling models.
         self._validate_lora(lora_request)
         self._validate_params(params, lora_request)
-        if trace_headers is not None:
-            raise ValueError("V1 does not support tracing yet.")
 
         data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
         if data_parallel_rank is not None and not (0 <= data_parallel_rank <
@@ -435,6 +433,7 @@ class Processor:
             cache_salt=decoder_inputs.get("cache_salt"),
             priority=priority,
             data_parallel_rank=data_parallel_rank,
+            trace_headers=trace_headers,
         )
 
     def _validate_model_inputs(self,
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 45c32aaaaf6c..e6c344d193df 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -68,6 +68,9 @@ class RequestStateStats:
     first_token_ts: float = 0.0
     last_token_ts: float = 0.0
 
+    # first token latency
+    first_token_latency: float = 0.0
+
 
 @dataclass
 class FinishedRequestStats:
@@ -116,6 +119,7 @@ class IterationStats:
 
             first_token_latency = self._time_since(req_stats.arrival_time)
             self.time_to_first_tokens_iter.append(first_token_latency)
+            req_stats.first_token_latency = first_token_latency
 
         req_stats.num_generation_tokens += num_new_generation_tokens
 
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index ad7477241ebb..64cce3e9efc5 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -3,6 +3,7 @@
 
 import enum
 import time
+from collections.abc import Mapping
 from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
@@ -35,6 +36,7 @@ class Request:
         structured_output_request: Optional["StructuredOutputRequest"] = None,
         cache_salt: Optional[str] = None,
         priority: int = 0,
+        trace_headers: Optional[Mapping[str, str]] = None,
         block_hasher: Optional[Callable[["Request"],
                                         list["BlockHash"]]] = None,
     ) -> None:
@@ -100,7 +102,8 @@ class Request:
         # they should also be updated simultaneously.
         self.output_token_ids = ConstantList(self._output_token_ids)
         self.all_token_ids = ConstantList(self._all_token_ids)
-
+        # trace_headers
+        self.trace_headers = trace_headers
         # State
         # The number of tokens with prefix cache hits.
         self.num_cached_tokens = -1
@@ -136,6 +139,7 @@ class Request:
             if request.sampling_params else None,
             cache_salt=request.cache_salt,
             priority=request.priority,
+            trace_headers=request.trace_headers,
             block_hasher=block_hasher,
         )
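# Aside (not part of the patch): with trace_headers now carried from Processor
# through Request, the scheduler, and EngineCoreOutput, a caller can make
# vLLM's "llm_request" span a child of its own span by sending W3C
# trace-context headers. Sketch only: the server URL and model name are
# placeholders, and how the headers reach the engine depends on the frontend
# in use (for example the OpenAI-compatible server).

import requests
from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import (
    TraceContextTextMapPropagator)

tracer = trace.get_tracer("my_client")

with tracer.start_as_current_span("client_request"):
    headers: dict[str, str] = {}
    # Injects "traceparent" (and "tracestate") for the active span context.
    TraceContextTextMapPropagator().inject(headers)
    response = requests.post(
        "http://localhost:8000/v1/completions",
        headers=headers,
        json={"model": "facebook/opt-125m", "prompt": "Hello", "max_tokens": 8},
    )
    response.raise_for_status()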