vllm/vllm/v1/engine/output_processor.py


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, cast

import torch

from vllm.outputs import (
CompletionOutput,
PoolingOutput,
PoolingRequestOutput,
RequestOutput,
)
from vllm.sampling_params import RequestOutputKind
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import (
IterationStats,
LoRARequestStates,
RequestStateStats,
SchedulerStats,
)


class RequestOutputCollector:
"""
Collects streamed RequestOutputs per individual request,
for hand-off to the consuming asyncio generate task.
When streaming deltas, RequestOutputs are merged if the
producer gets ahead of the consumer.
"""
def __init__(self, output_kind: RequestOutputKind):
self.aggregate = output_kind == RequestOutputKind.DELTA
self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
self.ready = asyncio.Event()
def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
"""Non-blocking put operation."""
if self.output is None or isinstance(output, Exception):
self.output = output
self.ready.set()
elif isinstance(self.output, RequestOutput) and isinstance(
output, RequestOutput
):
            # Merging ensures that request outputs with different request
            # indexes (when n > 1) do not overwrite each other.
self.output.add(output, aggregate=self.aggregate)
elif isinstance(self.output, PoolingRequestOutput) and isinstance(
output, PoolingRequestOutput
):
self.output = output
async def get(self) -> RequestOutput | PoolingRequestOutput:
"""Get operation blocks on put event."""
while (output := self.output) is None:
await self.ready.wait()
self.output = None
self.ready.clear()
if isinstance(output, Exception):
raise output
return output
def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
"""Non-blocking get operation."""
output = self.output
if output is not None:
self.output = None
self.ready.clear()
if isinstance(output, Exception):
raise output
return output


@dataclass
class OutputProcessorOutput:
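    """Return value of OutputProcessor.process_outputs()."""
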
request_outputs: list[RequestOutput | PoolingRequestOutput]
reqs_to_abort: list[str]


class RequestState:
def __init__(
self,
request_id: str,
parent_req: ParentRequest | None,
request_index: int,
lora_name: str | None,
output_kind: RequestOutputKind,
prompt: str | None,
prompt_token_ids: list[int] | None,
prompt_embeds: torch.Tensor | None,
logprobs_processor: LogprobsProcessor | None,
detokenizer: IncrementalDetokenizer | None,
max_tokens_param: int | None,
arrival_time: float,
queue: RequestOutputCollector | None,
log_stats: bool,
top_p: float | None = None,
n: int | None = None,
temperature: float | None = None,
):
self.request_id = request_id
self.parent_req = parent_req
self.request_index = request_index
self.lora_name = lora_name
self.output_kind = output_kind
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.prompt_embeds = prompt_embeds
self.prompt_len = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
)
self.logprobs_processor = logprobs_processor
self.detokenizer = detokenizer
self.max_tokens_param = max_tokens_param
self.top_p = top_p
self.n = n
self.temperature = temperature
self.is_prefilling = True
self.queue = queue
self.num_cached_tokens = 0
self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None
@classmethod
def from_new_request(
cls,
        tokenizer: AnyTokenizer | None,
request: EngineCoreRequest,
prompt: str | None,
parent_req: ParentRequest | None,
request_index: int,
queue: RequestOutputCollector | None,
log_stats: bool,
) -> "RequestState":
if sampling_params := request.sampling_params:
if not sampling_params.detokenize:
tokenizer = None
output_kind = sampling_params.output_kind
logprobs_processor = LogprobsProcessor.from_new_request(
tokenizer=tokenizer,
request=request,
)
detokenizer = IncrementalDetokenizer.from_new_request(
tokenizer=tokenizer,
request=request,
)
max_tokens_param = sampling_params.max_tokens
top_p = sampling_params.top_p
n = sampling_params.n
temperature = sampling_params.temperature
else:
logprobs_processor = None
detokenizer = None
max_tokens_param = None
top_p = None
n = None
temperature = None
assert request.pooling_params is not None
output_kind = request.pooling_params.output_kind
return cls(
request_id=request.request_id,
parent_req=parent_req,
request_index=request_index,
lora_name=(
request.lora_request.name if request.lora_request is not None else None
),
output_kind=output_kind,
prompt=prompt,
prompt_token_ids=request.prompt_token_ids,
prompt_embeds=request.prompt_embeds,
logprobs_processor=logprobs_processor,
detokenizer=detokenizer,
max_tokens_param=max_tokens_param,
top_p=top_p,
n=n,
temperature=temperature,
arrival_time=request.arrival_time,
queue=queue,
log_stats=log_stats,
)
def make_request_output(
self,
new_token_ids: list[int],
pooling_output: torch.Tensor | None,
finish_reason: FinishReason | None,
stop_reason: int | str | None,
kv_transfer_params: dict[str, Any] | None = None,
) -> RequestOutput | PoolingRequestOutput | None:
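        """Build a RequestOutput (or PoolingRequestOutput) for the latest
        engine step of this request.

        Returns None when there is nothing to emit yet, e.g. in FINAL_ONLY
        mode before the request finishes, or while a parent request
        (n > 1) is still waiting on sibling outputs.
        """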
finished = finish_reason is not None
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
if not finished and final_only:
# Only the final output is required in FINAL_ONLY mode.
return None
request_id = self.request_id
if pooling_output is not None:
return self._new_request_output(
request_id, [self._new_pooling_output(pooling_output)], finished
)
output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)
if self.parent_req is None:
outputs = [output]
else:
request_id, outputs, finished = self.parent_req.get_outputs(
request_id, output
)
if not outputs:
return None
return self._new_request_output(
request_id, outputs, finished, kv_transfer_params
)
def _new_request_output(
self,
request_id: str,
outputs: list[CompletionOutput] | list[PoolingOutput],
finished: bool,
kv_transfer_params: dict[str, Any] | None = None,
) -> RequestOutput | PoolingRequestOutput:
first_output = outputs[0]
if isinstance(first_output, PoolingOutput):
assert len(outputs) == 1
# Prompt embeddings are currently not supported by pooling requests.
assert self.prompt_token_ids is not None
return PoolingRequestOutput(
request_id=request_id,
outputs=first_output,
num_cached_tokens=self.num_cached_tokens,
prompt_token_ids=self.prompt_token_ids,
finished=finished,
)
assert self.logprobs_processor is not None
if self.output_kind == RequestOutputKind.DELTA:
# Side effect: logprobs processor forgets prompt logprobs
prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
else:
prompt_logprobs = self.logprobs_processor.prompt_logprobs
# If prompt embeds were used, put placeholder prompt token ids
prompt_token_ids = self.prompt_token_ids
if prompt_token_ids is None and self.prompt_embeds is not None:
prompt_token_ids = [0] * len(self.prompt_embeds)
return RequestOutput(
request_id=request_id,
prompt=self.prompt,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs,
outputs=cast(list[CompletionOutput], outputs),
finished=finished,
kv_transfer_params=kv_transfer_params,
num_cached_tokens=self.num_cached_tokens,
metrics=self.stats,
)
def _new_completion_output(
self,
token_ids: list[int],
finish_reason: FinishReason | None,
stop_reason: int | str | None,
) -> CompletionOutput:
assert self.detokenizer is not None
assert self.logprobs_processor is not None
finished = finish_reason is not None
delta = self.output_kind == RequestOutputKind.DELTA
# Prepare text and token_ids, based on delta mode
text = self.detokenizer.get_next_output_text(finished, delta)
if not delta:
token_ids = self.detokenizer.output_token_ids
# Prepare logprobs, based on delta mode
logprobs = self.logprobs_processor.logprobs
if delta and logprobs:
logprobs = logprobs[-len(token_ids) :]
return CompletionOutput(
index=self.request_index,
text=text,
token_ids=token_ids,
logprobs=logprobs,
cumulative_logprob=self.logprobs_processor.cumulative_logprob,
finish_reason=str(finish_reason) if finished else None,
stop_reason=stop_reason if finished else None,
)
def _new_pooling_output(
self,
pooling_output: torch.Tensor,
) -> PoolingOutput:
return PoolingOutput(data=pooling_output)


class OutputProcessor:
"""Process EngineCoreOutputs into RequestOutputs."""
def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
self.log_stats = log_stats
self.tokenizer = tokenizer
self.request_states: dict[str, RequestState] = {}
self.parent_requests: dict[str, ParentRequest] = {}
self.lora_states = LoRARequestStates(log_stats)
self.tracer: Tracer | None = None
def get_num_unfinished_requests(self):
return len(self.request_states)
def has_unfinished_requests(self) -> bool:
return len(self.request_states) > 0
def propagate_error(self, e: Exception):
"""Propagate error to all generate() tasks."""
for _, state in self.request_states.items():
assert state.queue is not None
state.queue.put(e)
def abort_requests(
self,
request_ids: Iterable[str],
) -> list[str]:
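        """Stop tracking the given requests.

        Child requests of an aborted parent request (parallel sampling) are
        aborted recursively. Returns the ids, including child request ids,
        that were being tracked and should also be aborted in EngineCore.
        """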
request_ids_to_abort = []
for request_id in request_ids:
req_state = self.request_states.pop(request_id, None)
if req_state is not None:
self.lora_states.request_finished(request_id, req_state.lora_name)
request_ids_to_abort.append(request_id)
# Produce final abort output.
if req_state.queue is not None and (
request_output := req_state.make_request_output(
new_token_ids=[],
                        # Pass a non-None pooling_output so that pooling
                        # requests take the pooling branch when building
                        # the final abort output.
pooling_output=torch.randn(0, device="cpu")
if req_state.detokenizer is None
else None,
finish_reason=FinishReason.ABORT,
stop_reason=None,
kv_transfer_params=None,
)
):
req_state.queue.put(request_output)
elif parent := self.parent_requests.get(request_id):
# Abort children prior to removing the parent.
if parent.child_requests:
child_reqs = list(parent.child_requests)
child_reqs = self.abort_requests(child_reqs)
request_ids_to_abort.extend(child_reqs)
self.parent_requests.pop(request_id, None)
return request_ids_to_abort
def add_request(
self,
request: EngineCoreRequest,
prompt: str | None,
parent_req: ParentRequest | None = None,
request_index: int = 0,
queue: RequestOutputCollector | None = None,
) -> None:
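        """Start tracking a new request so its outputs can be processed.

        Raises ValueError if the request id is already being tracked.
        """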
request_id = request.request_id
if request_id in self.request_states:
raise ValueError(f"Request id {request_id} already running.")
req_state = RequestState.from_new_request(
tokenizer=self.tokenizer,
request=request,
prompt=prompt,
parent_req=parent_req,
request_index=request_index,
queue=queue,
log_stats=self.log_stats,
)
self.request_states[request_id] = req_state
if parent_req:
self.parent_requests[parent_req.request_id] = parent_req
def process_outputs(
self,
engine_core_outputs: list[EngineCoreOutput],
engine_core_timestamp: float | None = None,
iteration_stats: IterationStats | None = None,
) -> OutputProcessorOutput:
"""
Process the EngineCoreOutputs:
1) Compute stats for logging
2) Detokenize
3) Create and handle RequestOutput objects:
* If there is a queue (for usage with AsyncLLM),
put the RequestOutput objects into the queue for
handling by the per-request generate() tasks.
* If there is no queue (for usage with LLMEngine),
return a list of RequestOutput objects.
        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of Python loops over the full
        batch to keep system overheads low. This is the only function
        that should loop over EngineCoreOutputs. If you need to touch
        every element of the batch, do it from within the loop below.
"""
request_outputs: list[RequestOutput | PoolingRequestOutput] = []
reqs_to_abort: list[str] = []
for engine_core_output in engine_core_outputs:
req_id = engine_core_output.request_id
req_state = self.request_states.get(req_id)
if req_state is None:
# Ignore output for already-aborted request.
continue
# 1) Compute stats for this iteration.
self._update_stats_from_output(
req_state, engine_core_output, engine_core_timestamp, iteration_stats
)
new_token_ids = engine_core_output.new_token_ids
pooling_output = engine_core_output.pooling_output
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params
req_state.num_cached_tokens = engine_core_output.num_cached_tokens
req_state.is_prefilling = False
if pooling_output is None:
assert req_state.detokenizer is not None
assert req_state.logprobs_processor is not None
# 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update(
new_token_ids, finish_reason == FinishReason.STOP
)
if stop_string:
finish_reason = FinishReason.STOP
stop_reason = stop_string
# 3) Compute sample and prompt logprobs for request,
# if required.
req_state.logprobs_processor.update_from_output(engine_core_output)
# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(
new_token_ids,
pooling_output,
finish_reason,
stop_reason,
kv_transfer_params,
):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)
else:
# LLMEngine: return list of RequestOutputs.
request_outputs.append(request_output)
# Free completed requests.
if finish_reason is not None:
self.request_states.pop(req_id)
# Remove parent request if applicable.
parent_req = req_state.parent_req
if parent_req and not parent_req.child_requests:
self.parent_requests.pop(parent_req.request_id, None)
if not engine_core_output.finished:
                    # The request is not finished in EngineCore, but the
                    # detokenizer detected a stop string, so EngineCore
                    # still needs to abort it.
reqs_to_abort.append(req_id)
# Track per-request stats
self._update_stats_from_finished(
req_state, finish_reason, iteration_stats
)
if self.tracer:
self.do_tracing(engine_core_output, req_state, iteration_stats)
return OutputProcessorOutput(
request_outputs=request_outputs,
reqs_to_abort=reqs_to_abort,
)
def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
self.lora_states.update_scheduler_stats(scheduler_stats)
def do_tracing(
self,
engine_core_output: EngineCoreOutput,
req_state: RequestState,
iteration_stats: IterationStats | None,
) -> None:
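        """Emit a tracing span for a finished request, recording queue,
        prefill, decode, inference and end-to-end latencies plus basic
        request metadata."""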
assert req_state.stats is not None
assert iteration_stats is not None
assert self.tracer is not None
arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
trace_context = extract_trace_context(engine_core_output.trace_headers)
prompt_length = length_from_prompt_token_ids_or_embeds(
req_state.prompt_token_ids, req_state.prompt_embeds
)
with self.tracer.start_as_current_span(
"llm_request",
kind=SpanKind.SERVER,
context=trace_context,
start_time=arrival_time_nano_seconds,
) as span:
metrics = req_state.stats
e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
queued_time = metrics.scheduled_ts - metrics.queued_ts
prefill_time = metrics.first_token_ts - metrics.scheduled_ts
decode_time = metrics.last_token_ts - metrics.first_token_ts
inference_time = metrics.last_token_ts - metrics.scheduled_ts
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
metrics.first_token_latency,
)
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length)
span.set_attribute(
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
metrics.num_generation_tokens,
)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
)
            # Request metadata attributes.
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
if req_state.top_p:
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
if req_state.max_tokens_param:
span.set_attribute(
SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
)
if req_state.temperature:
span.set_attribute(
SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
)
if req_state.n:
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)
def _update_stats_from_output(
self,
req_state: RequestState,
engine_core_output: EngineCoreOutput,
engine_core_timestamp: float | None,
iteration_stats: IterationStats | None,
):
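        """Fold this engine output into the per-iteration and per-request
        stats (including LoRA state) when stats logging is enabled."""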
if iteration_stats is None:
return
assert engine_core_timestamp is not None
assert req_state.stats is not None
iteration_stats.update_from_output(
engine_core_output,
engine_core_timestamp,
req_state.is_prefilling,
req_state.prompt_len,
req_state.stats,
self.lora_states,
req_state.lora_name,
)
def _update_stats_from_finished(
self,
req_state: RequestState,
finish_reason: FinishReason | None,
iteration_stats: IterationStats | None,
):
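        """Record finished-request stats, mark the request finished for LoRA
        state tracking, and notify the parent request, if any."""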
if iteration_stats is None:
return
assert finish_reason is not None
assert req_state.stats is not None
iteration_stats.update_from_finished_request(
finish_reason=finish_reason,
num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
req_state.prompt_token_ids, req_state.prompt_embeds
),
max_tokens_param=req_state.max_tokens_param,
req_stats=req_state.stats,
)
self.lora_states.request_finished(req_state.request_id, req_state.lora_name)
ParentRequest.observe_finished_request(
req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
)