[BugFix] bad_words filtering ineffective when n > 1 (#29313)

Signed-off-by: GOavi101 <1704178@kiit.ac.in>
This commit is contained in:
Avishek Goswami 2025-11-25 15:06:34 +05:30 committed by GitHub
parent db2906108a
commit 32c40b95e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 6 deletions

View File

@ -72,6 +72,14 @@ class EngineCoreRequest(
trace_headers: Mapping[str, str] | None = None
@property
def params(self) -> SamplingParams | PoolingParams:
    """Return whichever params object this request carries.

    ``sampling_params`` is returned when it is set; otherwise the request
    must carry ``pooling_params``, which is returned instead.

    Raises:
        AssertionError: if both ``sampling_params`` and ``pooling_params``
            are ``None``.
    """
    sampling = self.sampling_params
    if sampling is None:
        # With no sampling params, pooling params must be present.
        pooling = self.pooling_params
        assert pooling is not None
        return pooling
    return sampling
class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event."""

View File

@ -321,14 +321,15 @@ class AsyncLLM(EngineClient):
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
# Use cloned params that may have been updated in process_inputs()
params = request.params
if is_pooling or params.n == 1:
await self._add_request(request, prompt_text, None, 0, queue)
return queue
# Get the updated SamplingParams from the request, which
# were cloned/updated in processor.process_inputs above.
parent_params = request.sampling_params
assert parent_params is not None
parent_params = params
assert isinstance(parent_params, SamplingParams)
# Fan out child requests (for n>1).
parent_request = ParentRequest(request_id, parent_params)

View File

@ -250,6 +250,9 @@ class LLMEngine:
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
# Use cloned params that may have been updated in process_inputs()
params = request.params
n = params.n if isinstance(params, SamplingParams) else 1
if n == 1:
@ -262,10 +265,10 @@ class LLMEngine:
# Fan out child requests (for n>1).
parent_req = ParentRequest(request_id, params)
for idx in range(n):
request_id, params = parent_req.get_child_info(idx)
request_id, child_params = parent_req.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)
child_request.request_id = request_id
child_request.sampling_params = params
child_request.sampling_params = child_params
# Make a new RequestState and queue.
self.output_processor.add_request(