diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ab3cdc4ee295d..954f74c3fdaef 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -25,7 +25,7 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import cdiv, kill_process_tree from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async +from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, @@ -145,25 +145,30 @@ class AsyncLLM(EngineClient): """Add new request to the AsyncLLM.""" # 1) Create a new output queue for the request. - if self.output_processor.is_request_active(request_id): - raise ValueError(f"Request id {request_id} already running.") queue: asyncio.Queue[RequestOutput] = asyncio.Queue() - # 2) Convert Input --> Request. - request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - trace_headers, - prompt_adapter_request, - priority) + # 2) Fan out child requests (for n>1) + parent_req = ParentRequest.from_params(request_id, params) + n = params.n if isinstance(params, SamplingParams) else 1 + for idx in range(n): + if parent_req is not None: + request_id, params = parent_req.get_child_info(idx) - # 3) Add the request to OutputProcessor (this process). - self.output_processor.add_request(request, queue) + # 3) Convert Input --> Request. + request = self.processor.process_inputs(request_id, prompt, params, + arrival_time, lora_request, + trace_headers, + prompt_adapter_request, + priority) - # 4) Add the EngineCoreRequest to EngineCore (separate process). - await self.engine_core.add_request_async(request) + # 4) Add the request to OutputProcessor (this process). + self.output_processor.add_request(request, parent_req, idx, queue) - if self.log_requests: - logger.info("Added request %s.", request_id) + # 5) Add the EngineCoreRequest to EngineCore (separate process). + await self.engine_core.add_request_async(request) + + if self.log_requests: + logger.info("Added request %s.", request_id) return queue @@ -172,7 +177,7 @@ class AsyncLLM(EngineClient): # requests we don't need to send multiple messages to core proc, # and so we don't need multiple streams which then get # re-multiplexed in the API server anyhow. 
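The fan-out above replaces the old async wrapper-generator approach: a request with `n > 1` is expanded up front into `n` ordinary child requests, each registered with the OutputProcessor under its own ID. A minimal standalone sketch of the pattern follows (a stub `SamplingParams` stands in for vLLM's class; the real `ParentRequest` lives in vllm/v1/engine/parallel_sampling.py, shown later in this diff):

    # Sketch only: mirrors ParentRequest.from_params / get_child_info with stub types.
    from copy import copy
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class SamplingParams:              # stand-in for vllm.SamplingParams
        n: int = 1
        seed: Optional[int] = None

    class ParentRequest:
        def __init__(self, request_id: str, params: SamplingParams):
            self.request_id = request_id
            self.params = params

        @classmethod
        def from_params(cls, request_id: str, params: SamplingParams):
            # Only sampling requests with n > 1 need a parent wrapper.
            return None if params.n == 1 else cls(request_id, params)

        def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
            # Each child is a plain n=1 request; a seeded parent gives child idx seed + idx.
            child = copy(self.params)
            child.n = 1
            if child.seed is not None:
                child.seed += index
            return f"{index}_{self.request_id}", child

    request_id, params = "req-7", SamplingParams(n=3, seed=11)
    parent_req = ParentRequest.from_params(request_id, params)
    for idx in range(params.n):
        if parent_req is not None:
            request_id, params = parent_req.get_child_info(idx)
        print(request_id, params.n, params.seed)   # 0_req-7 1 11, 1_req-7 1 12, 2_req-7 1 13

Each child then goes through the normal process_inputs / add_request path, so the engine core never sees `n > 1`.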
- async def _generate( + async def generate( self, prompt: PromptType, sampling_params: SamplingParams, @@ -243,30 +248,6 @@ class AsyncLLM(EngineClient): await self.abort(request_id) raise - def generate( - self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> AsyncGenerator[RequestOutput, None]: - kwargs = dict(prompt=prompt, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) - if sampling_params.n is None or sampling_params.n == 1: - return self._generate(**kwargs) - else: - # Special handling for parallel sampling requests - return generate_parallel_sampling_async(generate=self._generate, - **kwargs) - async def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 2e76694a7f512..99b97ac8e6c46 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -22,7 +22,7 @@ from vllm.transformers_utils.tokenizer_group import ( from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager +from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor @@ -50,9 +50,6 @@ class LLMEngine: self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config - # Bookkeeping for parallel sampling requests - self.parallel_manager = SyncParallelSamplingManager() - # important: init dp group before init the engine_core self.parallel_config = vllm_config.parallel_config self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa @@ -120,8 +117,7 @@ class LLMEngine: multiprocess_mode=enable_multiprocessing) def get_num_unfinished_requests(self) -> int: - return self.parallel_manager.get_num_unfinished_requests( - self.output_processor.get_num_unfinished_requests()) + return self.output_processor.get_num_unfinished_requests() def has_unfinished_requests(self) -> bool: has_unfinished = self.output_processor.has_unfinished_requests() @@ -157,48 +153,25 @@ class LLMEngine: prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: - """Add request.""" - kwargs = dict(request_id=request_id, - prompt=prompt, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) - # Handle parallel sampling requests differently. 
- if params is None or isinstance(params, - PoolingParams) or params.n == 1: - self._add_request(**kwargs) - else: - # Special handling for parallel sampling requests - self.parallel_manager.add_request_parallel_sampling( - add_request=self._add_request, **kwargs) + # 1) Fan out child requests (for n>1) + parent_req = ParentRequest.from_params(request_id, params) + n = params.n if isinstance(params, SamplingParams) else 1 + for idx in range(n): + if parent_req is not None: + request_id, params = parent_req.get_child_info(idx) - def _add_request( - self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> None: - """Add request, `n=1`""" - # 1) Process raw inputs into the request. - request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - trace_headers, - prompt_adapter_request, - priority) + # 2) Process raw inputs into the request. + request = self.processor.process_inputs(request_id, prompt, params, + arrival_time, lora_request, + trace_headers, + prompt_adapter_request, + priority) - # 2) Make a new RequestState and queue. - self.output_processor.add_request(request) + # 3) Make a new RequestState and queue. + self.output_processor.add_request(request, parent_req, idx) - # 3) Add the request to EngineCore. - self.engine_core.add_request(request) + # 3) Add the request to EngineCore. + self.engine_core.add_request(request) def step(self) -> list[RequestOutput]: @@ -217,10 +190,7 @@ class LLMEngine: # 3) Abort any reqs that finished due to stop strings. 
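Not part of this diff, but useful context for the LLMEngine changes above: the user-visible contract is unchanged. A caller using `n > 1` still receives one RequestOutput per prompt whose `.outputs` holds `n` CompletionOutputs sorted by index; the per-child bookkeeping now lives in the OutputProcessor rather than a separate SyncParallelSamplingManager. The model name below is an arbitrary placeholder:

    # Usage-level illustration (not from this patch); any small model works here.
    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(n=3, temperature=0.8, seed=42))

    assert len(outputs) == 1                    # one parent RequestOutput per prompt
    assert len(outputs[0].outputs) == 3         # n merged child completions
    for completion in outputs[0].outputs:
        print(completion.index, repr(completion.text))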
self.engine_core.abort_requests(processed_outputs.reqs_to_abort) - request_outputs = processed_outputs.request_outputs - - # 4) Process unfinished parallel sampling requests - return self.parallel_manager.step(request_outputs) + return processed_outputs.request_outputs def get_model_config(self): return self.model_config diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 22bbb8a0f5b47..4e1d1e3bf51bc 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -4,13 +4,14 @@ import asyncio from dataclasses import dataclass from typing import Optional, Union -from vllm.outputs import RequestOutput +from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.logprobs import LogprobsProcessor +from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates, RequestStateStats) @@ -27,6 +28,8 @@ class RequestState: def __init__( self, request_id: str, + parent_req: Optional[ParentRequest], + request_index: int, lora_name: Optional[str], output_kind: RequestOutputKind, prompt: Optional[str], @@ -38,6 +41,8 @@ class RequestState: log_stats: bool, ): self.request_id = request_id + self.parent_req = parent_req + self.request_index = request_index self.lora_name = lora_name self.output_kind = output_kind self.prompt = prompt @@ -56,11 +61,15 @@ class RequestState: cls, tokenizer: AnyTokenizer, request: EngineCoreRequest, + parent_req: Optional[ParentRequest], + request_index: int, queue: Optional[asyncio.Queue[RequestOutput]], log_stats: bool, ) -> "RequestState": return cls( request_id=request.request_id, + parent_req=parent_req, + request_index=request_index, lora_name=(request.lora_request.name if request.lora_request is not None else None), output_kind=request.sampling_params.output_kind, @@ -79,6 +88,88 @@ class RequestState: log_stats=log_stats, ) + def make_request_output( + self, + new_token_ids: list[int], + finish_reason: Optional[FinishReason], + stop_reason: Union[int, str, None], + ) -> Optional[RequestOutput]: + + finished = finish_reason is not None + output_kind = self.output_kind + final_only = output_kind == RequestOutputKind.FINAL_ONLY + + # In follow up, we will switch to invariant where EngineCore + # does not stream partial prefills. + if not finished and (self.is_prefilling or final_only): + # Only the final output is required in FINAL_ONLY mode. 
+ return None + + def new_request_output(request_id: str) -> RequestOutput: + return self._new_request_output(request_id, finished) + + completion_output = self._new_completion_output( + new_token_ids, finish_reason, stop_reason) + + if self.parent_req is not None: + return self.parent_req.make_request_output(final_only, + completion_output, + new_request_output) + + request_output = new_request_output(self.request_id) + request_output.outputs.append(completion_output) + return request_output + + def _new_request_output( + self, + request_id: str, + finished: bool, + ) -> RequestOutput: + + if self.output_kind == RequestOutputKind.DELTA: + # Side effect: logprobs processor forgets prompt logprobs + prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs() + else: + prompt_logprobs = self.logprobs_processor.prompt_logprobs + + return RequestOutput( + request_id=request_id, + prompt=self.prompt, + prompt_token_ids=self.prompt_token_ids, + prompt_logprobs=prompt_logprobs, + outputs=[], + finished=finished, + ) + + def _new_completion_output( + self, + token_ids: list[int], + finish_reason: Optional[FinishReason], + stop_reason: Union[int, str, None], + ) -> CompletionOutput: + + finished = finish_reason is not None + delta = self.output_kind == RequestOutputKind.DELTA + + # Prepare text and token_ids, based on delta mode + text = self.detokenizer.get_next_output_text(finished, delta) + if not delta: + token_ids = self.detokenizer.output_token_ids + + # Prepare logprobs, based on delta mode + logprobs = self.logprobs_processor.logprobs + if delta and logprobs: + logprobs = logprobs[-len(token_ids):] + + return CompletionOutput( + index=self.request_index, + text=text, + token_ids=token_ids, + logprobs=logprobs, + cumulative_logprob=self.logprobs_processor.cumulative_logprob, + finish_reason=str(finish_reason) if finished else None, + stop_reason=stop_reason if finished else None) + class OutputProcessor: """Process EngineCoreOutputs into RequestOutputs.""" @@ -93,9 +184,6 @@ class OutputProcessor: self.request_states: dict[str, RequestState] = {} self.lora_states = LoRARequestStates() - def is_request_active(self, request_id: str) -> bool: - return request_id in self.request_states - def get_num_unfinished_requests(self): return len(self.request_states) @@ -114,6 +202,8 @@ class OutputProcessor: def add_request( self, request: EngineCoreRequest, + parent_req: Optional[ParentRequest] = None, + request_index: int = 0, queue: Optional[asyncio.Queue[RequestOutput]] = None, ) -> None: request_id = request.request_id @@ -123,6 +213,8 @@ class OutputProcessor: req_state = RequestState.from_new_request( tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), request=request, + parent_req=parent_req, + request_index=request_index, queue=queue, log_stats=self.log_stats) self.request_states[request_id] = req_state @@ -202,8 +294,8 @@ class OutputProcessor: req_state.logprobs_processor.update_from_output(engine_core_output) # 4) Create and handle RequestOutput objects. - if request_output := self._make_request_output( - req_state, new_token_ids, finish_reason, stop_reason): + if request_output := req_state.make_request_output( + new_token_ids, finish_reason, stop_reason): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put_nowait(request_output) @@ -211,18 +303,17 @@ class OutputProcessor: # LLMEngine: return list of RequestOutputs. request_outputs.append(request_output) - # Free completed requests. 
- if request_output.finished: - self.request_states.pop(req_id) - if not engine_core_output.finished: - # If req not finished in EngineCore, but Detokenizer - # detected stop string, abort needed in EngineCore. - reqs_to_abort.append(req_id) + # Free completed requests. + if finish_reason is not None: + self.request_states.pop(req_id) + if not engine_core_output.finished: + # If req not finished in EngineCore, but Detokenizer + # detected stop string, abort needed in EngineCore. + reqs_to_abort.append(req_id) - # Track per-request stats - self._update_stats_from_finished(req_state, request_output, - finish_reason, - iteration_stats) + # Track per-request stats + self._update_stats_from_finished(req_state, finish_reason, + iteration_stats) self.lora_states.update_iteration_stats(iteration_stats) @@ -249,7 +340,6 @@ class OutputProcessor: req_state.stats, lora_stats) def _update_stats_from_finished(self, req_state: RequestState, - request_output: RequestOutput, finish_reason: Optional[FinishReason], iteration_stats: Optional[IterationStats]): if iteration_stats is None: @@ -257,55 +347,8 @@ class OutputProcessor: assert finish_reason is not None assert req_state.stats is not None - iteration_stats.update_from_finished_request(finish_reason, - request_output, - req_state.stats) + iteration_stats.update_from_finished_request( + finish_reason=finish_reason, + num_prompt_tokens=len(req_state.prompt_token_ids), + req_stats=req_state.stats) self.lora_states.finish_request(req_state) - - @staticmethod - def _make_request_output( - request_state: RequestState, - new_token_ids: list[int], - finish_reason: Optional[FinishReason], - stop_reason: Union[int, str, None], - ) -> Optional[RequestOutput]: - - finished = finish_reason is not None - output_kind = request_state.output_kind - # In follow up, we will switch to invariant where EngineCore - # does not stream partial prefills. - if not finished and (request_state.is_prefilling - or output_kind == RequestOutputKind.FINAL_ONLY): - # Only the final output is required in FINAL_ONLY mode. 
- return None - - detokenizer = request_state.detokenizer - logprobs_processor = request_state.logprobs_processor - - delta = output_kind == RequestOutputKind.DELTA - logprobs = logprobs_processor.logprobs - if delta: - if logprobs: - logprobs = logprobs[-len(new_token_ids):] - # Side effect: logprobs processor forgets prompt logprobs - prompt_logprobs = logprobs_processor.pop_prompt_logprobs() - else: - prompt_logprobs = logprobs_processor.prompt_logprobs - - request_output = RequestOutput.new( - request_id=request_state.request_id, - prompt=request_state.prompt, - prompt_token_ids=request_state.prompt_token_ids, - text=detokenizer.get_next_output_text(finished, delta), - token_ids=new_token_ids if delta else detokenizer.output_token_ids, - logprobs=logprobs, - prompt_logprobs=prompt_logprobs, - cumulative_logprob=logprobs_processor.cumulative_logprob, - finished=finished, - ) - if finished: - completion_output = request_output.outputs[0] - completion_output.finish_reason = str(finish_reason) - completion_output.stop_reason = stop_reason - - return request_output diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 291360771b54f..adced8973b033 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -1,69 +1,46 @@ # SPDX-License-Identifier: Apache-2.0 -from collections.abc import AsyncGenerator, Mapping from copy import copy -from typing import Optional, Protocol, Union +from typing import Callable, Optional, Union -from vllm.inputs import PromptType -from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import CompletionOutput, RequestOutput from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.utils import merge_async_iterators +from vllm.sampling_params import SamplingParams -class AsyncGenerateMethodType(Protocol): - - def __call__(self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0) -> AsyncGenerator[RequestOutput, None]: - ... - - -class SyncAddRequestMethodType(Protocol): - - def __call__(self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0) -> None: - ... - - -class ParallelSamplingRequest: +class ParentRequest: """Info, state & processing for parallel sampling request. - + Store parent request ID and sampling params. Facilitate generating child request sampling params. - Transform child request outputs into parent request - outputs. - When stream mode is disabled, then `self.request_output` - aggregates child request completions. 
""" request_id: str sampling_params: SamplingParams + + # To aggregate child completions when not streaming + output_aggregator: Optional[RequestOutput] + + # To efficiently obtain child sampling params cached_child_sampling_params: Optional[SamplingParams] - request_output: Optional[RequestOutput] - num_finished_completions: int def __init__(self, request_id: str, sampling_params: SamplingParams) -> None: self.request_id = request_id self.sampling_params = sampling_params + + self.output_aggregator = None self.cached_child_sampling_params = None - self.request_output = None - self.num_finished_completions = 0 + + @classmethod + def from_params( + cls, + request_id: str, + params: Union[SamplingParams, PoolingParams], + ) -> Optional['ParentRequest']: + if not isinstance(params, SamplingParams) or params.n == 1: + return None + return cls(request_id, params) def _get_child_sampling_params( self, @@ -96,47 +73,6 @@ class ParallelSamplingRequest: child_sampling_params.seed = seed + index return child_sampling_params - def _add_output( - self, - child_req_output: RequestOutput, - index: int, - ) -> None: - """Aggregate a parallel sampling child - request output. - - Non-stream-mode (`output_kind == FINAL_ONLY`) - only. Inject correct parent request ID and - completion index. - - Args: - child_req_output: a single request output - from a parallel sampling - child request. - index: index within `n` child - """ - self.num_finished_completions += 1 - new_completion = child_req_output.outputs[0] - new_completion.index = index - if self.request_output is None: - # Save the first request output; reinstate - # original request ID; metrics are not - # supported for parallel sampling - child_req_output.request_id = self.request_id - child_req_output.metrics = None - self.request_output = child_req_output - else: - # Aggregate additional completion into request output - # Note: will be sorted by index later - self.request_output.outputs.append(new_completion) - - def _get_final_request_output(self) -> RequestOutput: - """Invariant: parent completion outputs sorted by index""" - assert self.request_output is not None - self.request_output.finished = True - self.request_output.outputs = sorted(self.request_output.outputs, - key=lambda x: x.index) - return self.request_output - def get_child_info(self, index: int) -> tuple[str, SamplingParams]: """Get child request ID and sampling params. @@ -149,227 +85,35 @@ class ParallelSamplingRequest: return (f"{index}_{self.request_id}", self._get_child_sampling_params(index)) - def process_output( - self, - child_req_output: RequestOutput, - index: int, - ) -> Optional[RequestOutput]: - """Filter, aggregate and transform parallel sampling - child request outputs. - - If the parent request has `stream=false` - (`output_kind == FINAL_ONLY`), each child will also have - `output_kind == FINAL_ONLY`. All child request outputs - must be aggregated into a single request output, with - multiple completions. This request output is only returned - once `n` completions are aggregated. - - If the parent request has `stream=true` - (`output_kind == DELTA`), each child will also have - `output_kind == DELTA`. All child request outputs - must be streamed directly to the caller. - - Args: - child_req_output: a single child request output - index: index within `n` child requests - - Returns: - `None`, unless a processed request output is ready to - send back to the caller. 
- """ - if self.output_kind != RequestOutputKind.FINAL_ONLY: - # stream=true: return child completions immediately - child_req_output.request_id = self.request_id - child_req_output.outputs[0].index = index - if child_req_output.finished: - # Parent request is complete if all child requests are - # complete. - self.num_finished_completions += 1 - child_req_output.finished = ( - self.num_finished_completions == self.n) - return child_req_output - - # stream=false: aggregate child completions - self._add_output(child_req_output, index) - if self.num_finished_completions == self.n: - # Return aggregated request output after obtaining - # all completions - return self._get_final_request_output() - return None - - async def wrap_child_async_generator( - self, - child_gen: AsyncGenerator[RequestOutput, None], - index: int, - ) -> AsyncGenerator[RequestOutput, None]: - """Output generator for a single parallel sampling - child request. - - Each parallel sampling request triggers at - least two child requests. This generator - yields zero or more request outputs to - return to the caller, as they become - available. - - Args: - child_gen: generator for child request - outputs. - index: index within the `n` child requests - - Returns: - Yields zero or more request outputs to return - to the caller. - """ - async for out in child_gen: - if req_out := self.process_output(out, index): - yield req_out - @property def n(self) -> int: return self.sampling_params.n - @property - def output_kind(self) -> RequestOutputKind: - return self.sampling_params.output_kind - - -class SyncParallelSamplingManager: - - def __init__(self): - # Parent req ID -> parent request manager - self.parent_reqs: dict[str, ParallelSamplingRequest] = {} - # Child req ID -> (child req index, parent req ID) - self.child_reqs: dict[str, tuple[int, str]] = {} - - def _register_parent_request(self, req: ParallelSamplingRequest) -> None: - """Register parallel sampling parent request.""" - self.parent_reqs[req.request_id] = req - - def _register_child_request(self, req_id: str, child_req_id: str, - index: int) -> None: - """Register parallel sampling child request with parent. - - Args: - req_id: parent request ID - child_req_id: child request ID - index: child request index within `n` child requests - """ - self.child_reqs[child_req_id] = (index, req_id) - - def get_num_unfinished_requests(self, num_core_reqs: int) -> int: - """Get the number of unfinished requests, correcting for parallel - sampling. - - Args: - num_core_reqs: The number of unfinished requests in the engine core. 
- - Returns: - Number of unfinished requests, where each parallel sampling req - counts as 1 - """ - return num_core_reqs + len(self.parent_reqs) - len(self.child_reqs) - - def add_request_parallel_sampling( + def make_request_output( self, - add_request: SyncAddRequestMethodType, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> None: - """Add sync parallel sampling request.""" - req = ParallelSamplingRequest(request_id, params) - self._register_parent_request(req) - # Add n child requests with unique request IDs & random seeds and n=1 - for idx in range(req.n): - child_req_id, child_params = req.get_child_info(idx) - self._register_child_request(request_id, child_req_id, idx) - add_request(request_id=child_req_id, - prompt=prompt, - params=child_params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) # type: ignore + final_only: bool, + completion_output: CompletionOutput, + new_request_output: Callable[[str], RequestOutput], + ) -> Optional[RequestOutput]: + # Use an existing RequestOutput if we're aggregating + request_output = self.output_aggregator - def step( - self, - outputs: list[RequestOutput], - ) -> list[RequestOutput]: - """Build parallel sampling request outputs. - - Extract child request outputs, aggregate them - into parent request output, and return parent - output when complete. + # Make new RequestOutput otherwise + if request_output is None: + request_output = new_request_output(self.request_id) - Do not modify `n=1` requests. + # Add a new completion + request_output.outputs.append(completion_output) - Args: - outputs: step request outputs. Mix of child request - outputs & `n=1` request outputs. + # If not streaming, aggregate until all child requests complete + if final_only and len(request_output.outputs) != self.n: + self.output_aggregator = request_output + return None - Return: - List of parallel sampling parent request outputs & - unmodified `n=1` request outputs passed-thru from input. - """ - if not (self.parent_reqs and outputs): - # Return unmodified - return outputs - agg_outputs = [] - for output in outputs: - req_id = output.request_id - if child_req_entry := self.child_reqs.get(req_id, None): - # For each parallel sampling child request output: - (index, parent_req_id) = child_req_entry - req = self.parent_reqs[parent_req_id] - # Update parallel sampling request - if out := req.process_output(output, index): - # Return parent request output if complete; - # cleanup parent request bookkeeping. - agg_outputs.append(out) - del self.parent_reqs[parent_req_id] - # Cleanup child request bookkeeping. 
- del self.child_reqs[req_id] - else: - # Not a parallel sampling request output - agg_outputs.append(output) - return agg_outputs + # We're done aggregating + self.output_aggregator = None - -async def generate_parallel_sampling_async( - generate: AsyncGenerateMethodType, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, -) -> AsyncGenerator[RequestOutput, None]: - """Generate completions for async parallel sampling requests.""" - parent_req = ParallelSamplingRequest(request_id, sampling_params) - - # Aggregate generators for n child requests - gens: list[AsyncGenerator[RequestOutput, None]] = [] - for idx in range(parent_req.n): - child_req_id, child_params = parent_req.get_child_info(idx) - child_gen = generate( - prompt=prompt, - sampling_params=child_params, - request_id=child_req_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority, - ) # type: ignore - gen = parent_req.wrap_child_async_generator(child_gen, idx) - gens.append(gen) - - # Merge generators - async for _, out in merge_async_iterators(*gens): - yield out + # Parent completion output list must be sorted by index + request_output.outputs = sorted(request_output.outputs, + key=lambda x: x.index) + return request_output diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 625edb607467b..abdca95670e11 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -5,7 +5,6 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: - from vllm.outputs import RequestOutput from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason from vllm.v1.output_processor import RequestState @@ -150,7 +149,7 @@ class IterationStats: self.num_preempted_reqs += 1 def update_from_finished_request(self, finish_reason: "FinishReason", - request_output: "RequestOutput", + num_prompt_tokens: int, req_stats: RequestStateStats): e2e_latency = self._time_since(req_stats.arrival_time) @@ -172,7 +171,7 @@ class IterationStats: finished_req = \ FinishedRequestStats(finish_reason=finish_reason, e2e_latency=e2e_latency, - num_prompt_tokens=len(request_output.prompt_token_ids), + num_prompt_tokens=num_prompt_tokens, num_generation_tokens=req_stats.num_generation_tokens, queued_time=queued_time, prefill_time=prefill_time,
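To make the non-streaming (FINAL_ONLY) aggregation in `ParentRequest.make_request_output` above concrete, here is a trimmed, self-contained sketch with stub output classes in place of vllm.outputs.RequestOutput / CompletionOutput (`ParentAggregator` is a renamed stand-in, not vLLM code):

    # Sketch of the FINAL_ONLY aggregation path with stub output types.
    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class CompletionOutput:                 # stub for vllm.outputs.CompletionOutput
        index: int
        text: str

    @dataclass
    class RequestOutput:                    # stub for vllm.outputs.RequestOutput
        request_id: str
        outputs: list = field(default_factory=list)

    class ParentAggregator:
        def __init__(self, request_id: str, n: int):
            self.request_id, self.n = request_id, n
            self.output_aggregator: Optional[RequestOutput] = None

        def make_request_output(self, final_only: bool,
                                completion: CompletionOutput):
            # Reuse the in-progress parent output if we are mid-aggregation.
            out = self.output_aggregator or RequestOutput(self.request_id)
            out.outputs.append(completion)
            if final_only and len(out.outputs) != self.n:
                # Not all n children have finished yet; keep aggregating.
                self.output_aggregator = out
                return None
            self.output_aggregator = None
            out.outputs.sort(key=lambda c: c.index)   # parent outputs sorted by index
            return out

    agg = ParentAggregator("req-7", n=3)
    # Children can finish out of order; only the last one yields the parent output.
    assert agg.make_request_output(True, CompletionOutput(2, "c")) is None
    assert agg.make_request_output(True, CompletionOutput(0, "a")) is None
    final = agg.make_request_output(True, CompletionOutput(1, "b"))
    print(final.request_id, [c.index for c in final.outputs])   # req-7 [0, 1, 2]

In streaming (DELTA) mode, `final_only` is False, so each child delta is returned immediately under the parent request ID and `output_aggregator` is left unset.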