diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 3f621d77c024..ce2aae77108d 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -72,6 +72,14 @@ class EngineCoreRequest( trace_headers: Mapping[str, str] | None = None + @property + def params(self) -> SamplingParams | PoolingParams: + """Return the processed params (sampling or pooling).""" + if self.sampling_params is not None: + return self.sampling_params + assert self.pooling_params is not None + return self.pooling_params + class EngineCoreEventType(enum.IntEnum): """The type of engine core request event.""" diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index c64b3cccfc65..55087baadff9 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -321,14 +321,15 @@ class AsyncLLM(EngineClient): elif isinstance(prompt, Mapping): prompt_text = cast(str | None, prompt.get("prompt")) + # Use cloned params that may have been updated in process_inputs() + params = request.params + if is_pooling or params.n == 1: await self._add_request(request, prompt_text, None, 0, queue) return queue - # Get the updated SamplingParams from the request, which - # were cloned/updated in processor.process_inputs above. - parent_params = request.sampling_params - assert parent_params is not None + parent_params = params + assert isinstance(parent_params, SamplingParams) # Fan out child requests (for n>1). parent_request = ParentRequest(request_id, parent_params) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index e403cea87788..dffe05445ee4 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -250,6 +250,9 @@ class LLMEngine: elif isinstance(prompt, Mapping): prompt_text = cast(str | None, prompt.get("prompt")) + # Use cloned params that may have been updated in process_inputs() + params = request.params + n = params.n if isinstance(params, SamplingParams) else 1 if n == 1: @@ -262,10 +265,10 @@ class LLMEngine: # Fan out child requests (for n>1). parent_req = ParentRequest(request_id, params) for idx in range(n): - request_id, params = parent_req.get_child_info(idx) + request_id, child_params = parent_req.get_child_info(idx) child_request = request if idx == n - 1 else copy(request) child_request.request_id = request_id - child_request.sampling_params = params + child_request.sampling_params = child_params # Make a new RequestState and queue. self.output_processor.add_request(