[V1] feat:add engine v1 tracing (#20372)

Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
Signed-off-by: Ye Zhang <zhysishu@gmail.com>
Signed-off-by: RichardoMu <44485717+RichardoMrMu@users.noreply.github.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Mu Huai <tianbowen.tbw@antgroup.com>
Co-authored-by: Ye Zhang <zhysishu@gmail.com>
Co-authored-by: Benjamin Bartels <benjamin@bartels.dev>
Co-authored-by: simon-mo <simon.mo@hey.com>
Co-authored-by: 瑜琮 <ly186375@antfin.com>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
RichardoMu 2025-09-12 08:10:39 +08:00 committed by GitHub
parent 2e6bc46821
commit 40b6c9122b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 253 additions and 20 deletions

View File

@ -231,7 +231,7 @@ steps:
source_file_dependencies:
- vllm/
- tests/metrics
- tests/tracing
- tests/v1/tracing
commands:
- pytest -v -s metrics
- "pip install \

View File

@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
# type: ignore
from __future__ import annotations
import threading
from collections.abc import Iterable
from concurrent import futures
from typing import Callable, Generator, Literal
import grpc
import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
ExportTraceServiceResponse)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
TraceServiceServicer, add_TraceServiceServicer_to_server)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
from opentelemetry.sdk.environment_variables import (
OTEL_EXPORTER_OTLP_TRACES_INSECURE)
from vllm import LLM, SamplingParams
from vllm.tracing import SpanAttributes
FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"
FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
'array_value']
def decode_value(value: AnyValue):
field_decoders: dict[FieldName, Callable] = {
"bool_value": (lambda v: v.bool_value),
"string_value": (lambda v: v.string_value),
"int_value": (lambda v: v.int_value),
"double_value": (lambda v: v.double_value),
"array_value":
(lambda v: [decode_value(item) for item in v.array_value.values]),
}
for field, decoder in field_decoders.items():
if value.HasField(field):
return decoder(value)
raise ValueError(f"Couldn't decode value: {value}")
def decode_attributes(attributes: Iterable[KeyValue]):
return {kv.key: decode_value(kv.value) for kv in attributes}
class FakeTraceService(TraceServiceServicer):
def __init__(self):
self.request = None
self.evt = threading.Event()
def Export(self, request, context):
self.request = request
self.evt.set()
return ExportTraceServiceResponse()
@pytest.fixture
def trace_service() -> Generator[FakeTraceService, None, None]:
"""Fixture to set up a fake gRPC trace service"""
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
service = FakeTraceService()
add_TraceServiceServicer_to_server(service, server)
server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
server.start()
yield service
server.stop(None)
def test_traces(
monkeypatch: pytest.MonkeyPatch,
trace_service: FakeTraceService,
):
with monkeypatch.context() as m:
m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
m.setenv("VLLM_USE_V1", "1")
sampling_params = SamplingParams(
temperature=0.01,
top_p=0.1,
max_tokens=256,
)
model = "facebook/opt-125m"
llm = LLM(model=model,
otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
gpu_memory_utilization=0.3,
disable_log_stats=False)
prompts = ["This is a short prompt"]
outputs = llm.generate(prompts, sampling_params=sampling_params)
print(f"test_traces outputs is : {outputs}")
timeout = 10
if not trace_service.evt.wait(timeout):
raise TimeoutError(
f"The fake trace service didn't receive a trace within "
f"the {timeout} seconds timeout")
request = trace_service.request
assert len(request.resource_spans) == 1, (
f"Expected 1 resource span, "
f"but got {len(request.resource_spans)}")
assert len(request.resource_spans[0].scope_spans) == 1, (
f"Expected 1 scope span, "
f"but got {len(request.resource_spans[0].scope_spans)}")
assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
f"Expected 1 span, "
f"but got {len(request.resource_spans[0].scope_spans[0].spans)}")
attributes = decode_attributes(
request.resource_spans[0].scope_spans[0].spans[0].attributes)
# assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
assert attributes.get(
SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
) == sampling_params.temperature
assert attributes.get(
SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS
) == sampling_params.max_tokens
assert attributes.get(
SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
assert attributes.get(
SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
outputs[0].prompt_token_ids)
completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
assert attributes.get(
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) > 0
assert attributes.get(
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) > 0
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) > 0

View File

@ -1491,12 +1491,6 @@ class EngineArgs:
recommend_to_remove=False)
return False
# No OTLP observability so far.
if (self.otlp_traces_endpoint or self.collect_detailed_traces):
_raise_or_fallback(feature_name="--otlp-traces-endpoint",
recommend_to_remove=False)
return False
# V1 supports N-gram, Medusa, and Eagle speculative decoding.
if (self.speculative_config is not None
and self.speculative_config.get("method") == "draft_model"):

View File

@ -119,6 +119,11 @@ class SpanAttributes:
# forward, block/sync across workers, cpu-gpu sync time and sampling time.
GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = (
"gen_ai.latency.time_in_model_execute")
GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = \
"gen_ai.latency.time_in_model_prefill"
GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode"
GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = \
"gen_ai.latency.time_in_model_inference"
def contains_trace_headers(headers: Mapping[str, str]) -> bool:

View File

@ -969,9 +969,9 @@ class Scheduler(SchedulerInterface):
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
))
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors

View File

@ -3,6 +3,7 @@
import enum
import time
from collections.abc import Mapping
from typing import Any, Optional, Union
import msgspec
@ -66,6 +67,8 @@ class EngineCoreRequest(
current_wave: int = 0
priority: int = 0
trace_headers: Optional[Mapping[str, str]] = None
class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event."""
@ -111,6 +114,7 @@ class EngineCoreOutput(
events: Optional[list[EngineCoreEvent]] = None
kv_transfer_params: Optional[dict[str, Any]] = None
trace_headers: Optional[Mapping[str, str]] = None
# The number of tokens with prefix cache hits.
num_cached_tokens: int = 0
@ -144,7 +148,7 @@ class EngineCoreOutputs(
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
#NOTE(Nick): We could consider ways to make this more compact,
# NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout
engine_index: int = 0

View File

@ -26,6 +26,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import AnyTokenizer
@ -97,6 +98,7 @@ class AsyncLLM(EngineClient):
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.observability_config = vllm_config.observability_config
self.log_requests = log_requests
self.log_stats = log_stats or (stat_loggers is not None)
@ -124,6 +126,11 @@ class AsyncLLM(EngineClient):
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
self.output_processor = OutputProcessor(self.tokenizer,
log_stats=self.log_stats)
if self.observability_config.otlp_traces_endpoint is not None:
tracer = init_tracer(
"vllm.llm_engine",
self.observability_config.otlp_traces_endpoint)
self.output_processor.tracer = tracer
# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_async_mp_client(
@ -603,7 +610,7 @@ class AsyncLLM(EngineClient):
return self.tokenizer.get_lora_tokenizer(lora_request)
async def is_tracing_enabled(self) -> bool:
return False
return self.observability_config.otlp_traces_endpoint is not None
async def do_log_stats(
self,

View File

@ -19,6 +19,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer_group import (
TokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
@ -65,6 +66,7 @@ class LLMEngine:
"Set VLLM_USE_V1=0 and file and issue on Github.")
self.vllm_config = vllm_config
self.observability_config = vllm_config.observability_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
@ -99,6 +101,11 @@ class LLMEngine:
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
self.output_processor = OutputProcessor(self.tokenizer,
log_stats=self.log_stats)
if self.observability_config.otlp_traces_endpoint is not None:
tracer = init_tracer(
"vllm.llm_engine",
self.observability_config.otlp_traces_endpoint)
self.output_processor.tracer = tracer
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
self.engine_core = EngineCoreClient.make_client(

View File

@ -11,6 +11,8 @@ import torch
from vllm.outputs import (CompletionOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput)
from vllm.sampling_params import RequestOutputKind
from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
extract_trace_context)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
@ -71,7 +73,6 @@ class RequestOutputCollector:
@dataclass
class OutputProcessorOutput:
request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
reqs_to_abort: list[str]
@ -93,6 +94,9 @@ class RequestState:
arrival_time: float,
queue: Optional[RequestOutputCollector],
log_stats: bool,
top_p: Optional[float] = None,
n: Optional[int] = None,
temperature: Optional[float] = None,
):
self.request_id = request_id
self.parent_req = parent_req
@ -105,6 +109,9 @@ class RequestState:
self.logprobs_processor = logprobs_processor
self.detokenizer = detokenizer
self.max_tokens_param = max_tokens_param
self.top_p = top_p
self.n = n
self.temperature = temperature
self.is_prefilling = True
self.queue = queue
self.num_cached_tokens = 0
@ -137,10 +144,16 @@ class RequestState:
request=request,
)
max_tokens_param = sampling_params.max_tokens
top_p = sampling_params.top_p
n = sampling_params.n
temperature = sampling_params.temperature
else:
logprobs_processor = None
detokenizer = None
max_tokens_param = None
top_p = None
n = None
temperature = None
assert request.pooling_params is not None
output_kind = request.pooling_params.output_kind
@ -156,6 +169,9 @@ class RequestState:
logprobs_processor=logprobs_processor,
detokenizer=detokenizer,
max_tokens_param=max_tokens_param,
top_p=top_p,
n=n,
temperature=temperature,
arrival_time=request.arrival_time,
queue=queue,
log_stats=log_stats,
@ -274,16 +290,13 @@ class RequestState:
class OutputProcessor:
"""Process EngineCoreOutputs into RequestOutputs."""
def __init__(
self,
tokenizer: TokenizerGroup,
log_stats: bool,
):
def __init__(self, tokenizer: TokenizerGroup, log_stats: bool):
self.log_stats = log_stats
self.tokenizer = tokenizer
self.request_states: dict[str, RequestState] = {}
self.parent_requests: dict[str, ParentRequest] = {}
self.lora_states = LoRARequestStates()
self.tracer: Optional[Tracer] = None
def get_num_unfinished_requests(self):
return len(self.request_states)
@ -441,7 +454,9 @@ class OutputProcessor:
# Track per-request stats
self._update_stats_from_finished(req_state, finish_reason,
iteration_stats)
if self.tracer:
self.do_tracing(engine_core_output, req_state,
iteration_stats)
self.lora_states.update_iteration_stats(iteration_stats)
return OutputProcessorOutput(
@ -449,6 +464,63 @@ class OutputProcessor:
reqs_to_abort=reqs_to_abort,
)
def do_tracing(self, engine_core_output: EngineCoreOutput,
req_state: RequestState,
iteration_stats: Optional[IterationStats]) -> None:
assert req_state.stats is not None
assert iteration_stats is not None
assert self.tracer is not None
arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
trace_context = extract_trace_context(engine_core_output.trace_headers)
with (self.tracer.start_as_current_span(
"llm_request",
kind=SpanKind.SERVER,
context=trace_context,
start_time=arrival_time_nano_seconds) as span):
metrics = req_state.stats
e2e_time = iteration_stats.iteration_timestamp - \
metrics.arrival_time
queued_time = metrics.scheduled_ts - metrics.queued_ts
prefill_time = metrics.first_token_ts - metrics.scheduled_ts
decode_time = metrics.last_token_ts - metrics.first_token_ts
inference_time = metrics.last_token_ts - metrics.scheduled_ts
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
metrics.first_token_latency)
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
queued_time)
span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
len(req_state.prompt_token_ids))
span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
metrics.num_generation_tokens)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL,
prefill_time)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE,
decode_time)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE,
inference_time)
# meta
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
req_state.request_id)
if req_state.top_p:
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
req_state.top_p)
if req_state.max_tokens_param:
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
req_state.max_tokens_param)
if req_state.temperature:
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
req_state.temperature)
if req_state.n:
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
req_state.n)
def _update_stats_from_output(self, req_state: RequestState,
engine_core_output: EngineCoreOutput,
engine_core_timestamp: Optional[float],

View File

@ -327,8 +327,6 @@ class Processor:
# TODO(woosuk): Support pooling models.
self._validate_lora(lora_request)
self._validate_params(params, lora_request)
if trace_headers is not None:
raise ValueError("V1 does not support tracing yet.")
data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
if data_parallel_rank is not None and not (0 <= data_parallel_rank <
@ -435,6 +433,7 @@ class Processor:
cache_salt=decoder_inputs.get("cache_salt"),
priority=priority,
data_parallel_rank=data_parallel_rank,
trace_headers=trace_headers,
)
def _validate_model_inputs(self,

View File

@ -68,6 +68,9 @@ class RequestStateStats:
first_token_ts: float = 0.0
last_token_ts: float = 0.0
# first token latency
first_token_latency: float = 0.0
@dataclass
class FinishedRequestStats:
@ -116,6 +119,7 @@ class IterationStats:
first_token_latency = self._time_since(req_stats.arrival_time)
self.time_to_first_tokens_iter.append(first_token_latency)
req_stats.first_token_latency = first_token_latency
req_stats.num_generation_tokens += num_new_generation_tokens

View File

@ -3,6 +3,7 @@
import enum
import time
from collections.abc import Mapping
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@ -35,6 +36,7 @@ class Request:
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None,
priority: int = 0,
trace_headers: Optional[Mapping[str, str]] = None,
block_hasher: Optional[Callable[["Request"],
list["BlockHash"]]] = None,
) -> None:
@ -100,7 +102,8 @@ class Request:
# they should also be updated simultaneously.
self.output_token_ids = ConstantList(self._output_token_ids)
self.all_token_ids = ConstantList(self._all_token_ids)
# trace_headers
self.trace_headers = trace_headers
# State
# The number of tokens with prefix cache hits.
self.num_cached_tokens = -1
@ -136,6 +139,7 @@ class Request:
if request.sampling_params else None,
cache_salt=request.cache_salt,
priority=request.priority,
trace_headers=request.trace_headers,
block_hasher=block_hasher,
)