# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, cast

import torch

from vllm.outputs import (
    CompletionOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
from vllm.sampling_params import RequestOutputKind
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import (
    IterationStats,
    LoRARequestStates,
    RequestStateStats,
    SchedulerStats,
)


class RequestOutputCollector:
    """
    Collects streamed RequestOutputs per individual request,
    for hand-off to the consuming asyncio generate task.

    When streaming deltas, RequestOutputs are merged if the
    producer gets ahead of the consumer.
    """

    def __init__(self, output_kind: RequestOutputKind):
        self.aggregate = output_kind == RequestOutputKind.DELTA
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        self.ready = asyncio.Event()

    def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
        """Non-blocking put operation."""
        if self.output is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
        elif isinstance(self.output, RequestOutput) and isinstance(
            output, RequestOutput
        ):
            # This ensures that request outputs with different request indexes
            # (if n > 1) do not override each other.
            self.output.add(output, aggregate=self.aggregate)
        elif isinstance(self.output, PoolingRequestOutput) and isinstance(
            output, PoolingRequestOutput
        ):
            self.output = output

    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Get operation blocks on put event."""
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output

    def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
        """Non-blocking get operation."""
        output = self.output
        if output is not None:
            self.output = None
            self.ready.clear()
            if isinstance(output, Exception):
                raise output
        return output


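# Usage sketch (illustrative, not part of vLLM's API): a per-request consumer
# task draining a RequestOutputCollector. put() merges DELTA outputs whenever
# the producer runs ahead, so this loop sees each token at most once, and a
# put(Exception) from the producer surfaces here as a raise from get(). The
# `handle` callback is a hypothetical placeholder.
async def _example_consume(collector: RequestOutputCollector, handle) -> None:
    finished = False
    while not finished:
        # Blocks until the producer's put() sets the ready event.
        output = await collector.get()
        handle(output)
        finished = output.finished

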
@dataclass
class OutputProcessorOutput:
    request_outputs: list[RequestOutput | PoolingRequestOutput]
    reqs_to_abort: list[str]


class RequestState:
    def __init__(
        self,
        request_id: str,
        parent_req: ParentRequest | None,
        request_index: int,
        lora_name: str | None,
        output_kind: RequestOutputKind,
        prompt: str | None,
        prompt_token_ids: list[int] | None,
        prompt_embeds: torch.Tensor | None,
        logprobs_processor: LogprobsProcessor | None,
        detokenizer: IncrementalDetokenizer | None,
        max_tokens_param: int | None,
        arrival_time: float,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        top_p: float | None = None,
        n: int | None = None,
        temperature: float | None = None,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        self.top_p = top_p
        self.n = n
        self.temperature = temperature
        self.is_prefilling = True
        self.queue = queue
        self.num_cached_tokens = 0

        self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        request_index: int,
        queue: RequestOutputCollector | None,
        log_stats: bool,
    ) -> "RequestState":
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(
                request.lora_request.name if request.lora_request is not None else None
            ),
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

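    # Note (added for clarity, inferred from this file): a request carries
    # exactly one of sampling_params (token generation) or pooling_params
    # (embeddings), and only the generation branch above builds an
    # IncrementalDetokenizer and LogprobsProcessor. Elsewhere in this file,
    # `detokenizer is None` is therefore used as the "this is a pooling
    # request" test.
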
    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: torch.Tensor | None,
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput | None:
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        request_id = self.request_id
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)], finished
            )

        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output
            )
            if not outputs:
                return None

        return self._new_request_output(
            request_id, outputs, finished, kv_transfer_params
        )

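    # Worked example (illustrative): how output_kind gates emission per step.
    #
    #   should_emit = finished or output_kind != RequestOutputKind.FINAL_ONLY
    #
    # With FINAL_ONLY, every intermediate step returns None and a single
    # cumulative RequestOutput is produced once finish_reason is set; with
    # DELTA, each step emits only newly sampled tokens and any backlog is
    # merged by RequestOutputCollector.put().
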
    def _new_request_output(
        self,
        request_id: str,
        outputs: list[CompletionOutput] | list[PoolingOutput],
        finished: bool,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput:
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=first_output,
                num_cached_tokens=self.num_cached_tokens,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        # If prompt embeds were used, put placeholder prompt token ids
        prompt_token_ids = self.prompt_token_ids
        if prompt_token_ids is None and self.prompt_embeds is not None:
            prompt_token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
            metrics=self.stats,
        )

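    # Worked example (illustrative): for prompt_embeds of shape
    # (seq_len, hidden_size) with seq_len == 5, len(prompt_embeds) == 5, so
    # the placeholder above becomes [0, 0, 0, 0, 0]. Consumers always see a
    # prompt_token_ids list of the right length, even though an embeddings
    # prompt has no real token ids.
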
    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
    ) -> CompletionOutput:
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            logprobs = logprobs[-len(token_ids) :]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None,
        )

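    # Worked example (illustrative): in DELTA mode the slice above keeps only
    # the logprobs for this step's tokens. With accumulated
    # logprobs == [lp0, lp1, lp2, lp3] and two new token_ids, the slice
    # logprobs[-2:] yields [lp2, lp3].
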
    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        return PoolingOutput(data=pooling_output)


class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        self.request_states: dict[str, RequestState] = {}
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates(log_stats)
        self.tracer: Tracer | None = None

    def get_num_unfinished_requests(self):
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        return len(self.request_states) > 0

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""

        for _, state in self.request_states.items():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.request_finished(request_id, req_state.lora_name)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
                    request_output := req_state.make_request_output(
                        new_token_ids=[],
                        # Pass a non-None pooling_output so that pooling
                        # requests correctly take the pooling branch on abort.
                        pooling_output=torch.randn(0, device="cpu")
                        if req_state.detokenizer is None
                        else None,
                        finish_reason=FinishReason.ABORT,
                        stop_reason=None,
                        kv_transfer_params=None,
                    )
                ):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        return request_ids_to_abort

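    # Usage sketch (illustrative): with parallel sampling (n > 1), the caller's
    # external request id maps to a ParentRequest while request_states tracks
    # the per-sample child ids, so aborting by the parent id recurses into the
    # children first:
    #
    #   ids_for_engine_core = processor.abort_requests([external_request_id])
    #
    # The returned ids (children included) are what must still be aborted
    # inside EngineCore.
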
    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None = None,
        request_index: int = 0,
        queue: RequestOutputCollector | None = None,
    ) -> None:
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer,
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats,
        )
        self.request_states[request_id] = req_state
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

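    # Wiring sketch (illustrative, simplified AsyncLLM-style usage; the
    # `core_request` and `prompt_text` names are hypothetical placeholders):
    #
    #   params = core_request.sampling_params
    #   assert params is not None
    #   queue = RequestOutputCollector(output_kind=params.output_kind)
    #   processor.add_request(core_request, prompt_text, queue=queue)
    #   # ... a generate() task then awaits queue.get() for streamed outputs.
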
    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: float | None = None,
        iteration_stats: IterationStats | None = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(
                req_state, engine_core_output, engine_core_timestamp, iteration_stats
            )

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP
                )
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                new_token_ids,
                pooling_output,
                finish_reason,
                stop_reason,
                kv_transfer_params,
            ):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(
                    req_state, finish_reason, iteration_stats
                )
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state, iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

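    # Step sketch (illustrative, simplified LLMEngine-style driver; the
    # `engine_core` object and its get_output()/abort_requests() calls are
    # hypothetical stand-ins for the real EngineCore client API):
    #
    #   outputs = engine_core.get_output()
    #   processed = output_processor.process_outputs(outputs.outputs)
    #   if processed.reqs_to_abort:
    #       # e.g. the detokenizer hit a stop string EngineCore can't see
    #       engine_core.abort_requests(processed.reqs_to_abort)
    #   for request_output in processed.request_outputs:
    #       ...  # hand results back to callers
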
    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
        self.lora_states.update_scheduler_stats(scheduler_stats)

    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
        req_state: RequestState,
        iteration_stats: IterationStats | None,
    ) -> None:
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        prompt_length = length_from_prompt_token_ids_or_embeds(
            req_state.prompt_token_ids, req_state.prompt_embeds
        )
        with self.tracer.start_as_current_span(
            "llm_request",
            kind=SpanKind.SERVER,
            context=trace_context,
            start_time=arrival_time_nano_seconds,
        ) as span:
            metrics = req_state.stats
            e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency,
            )
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length)
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                metrics.num_generation_tokens,
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
            )

            # meta
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
            if req_state.top_p:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
            if req_state.max_tokens_param:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
                )
            if req_state.temperature:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
                )
            if req_state.n:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)

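    # Worked example (illustrative): with queued_ts=0.0, scheduled_ts=0.2,
    # first_token_ts=0.7 and last_token_ts=2.7 (all seconds), the attributes
    # above record queued_time=0.2, prefill_time=0.5, decode_time=2.0 and
    # inference_time=2.5; e2e_time additionally spans from arrival_time to the
    # iteration timestamp, so it also covers pre-queue and post-decode
    # overheads.
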
    def _update_stats_from_output(
        self,
        req_state: RequestState,
        engine_core_output: EngineCoreOutput,
        engine_core_timestamp: float | None,
        iteration_stats: IterationStats | None,
    ):
        if iteration_stats is None:
            return

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(
            engine_core_output,
            engine_core_timestamp,
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
            self.lora_states,
            req_state.lora_name,
        )

    def _update_stats_from_finished(
        self,
        req_state: RequestState,
        finish_reason: FinishReason | None,
        iteration_stats: IterationStats | None,
    ):
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
                req_state.prompt_token_ids, req_state.prompt_embeds
            ),
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
        )
        self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
        )