mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:25:01 +08:00
[BugFix] bad_words filtering ineffective when n > 1 (#29313)
Signed-off-by: GOavi101 <1704178@kiit.ac.in>
This commit is contained in:
parent
db2906108a
commit
32c40b95e0
@@ -72,6 +72,14 @@ class EngineCoreRequest(
|
||||
|
||||
trace_headers: Mapping[str, str] | None = None
|
||||
|
||||
@property
|
||||
def params(self) -> SamplingParams | PoolingParams:
|
||||
"""Return the processed params (sampling or pooling)."""
|
||||
if self.sampling_params is not None:
|
||||
return self.sampling_params
|
||||
assert self.pooling_params is not None
|
||||
return self.pooling_params
|
||||
|
||||
|
||||
class EngineCoreEventType(enum.IntEnum):
|
||||
"""The type of engine core request event."""
|
||||
|
||||
@@ -321,14 +321,15 @@ class AsyncLLM(EngineClient):
|
||||
elif isinstance(prompt, Mapping):
|
||||
prompt_text = cast(str | None, prompt.get("prompt"))
|
||||
|
||||
# Use cloned params that may have been updated in process_inputs()
|
||||
params = request.params
|
||||
|
||||
if is_pooling or params.n == 1:
|
||||
await self._add_request(request, prompt_text, None, 0, queue)
|
||||
return queue
|
||||
|
||||
# Get the updated SamplingParams from the request, which
|
||||
# were cloned/updated in processor.process_inputs above.
|
||||
parent_params = request.sampling_params
|
||||
assert parent_params is not None
|
||||
parent_params = params
|
||||
assert isinstance(parent_params, SamplingParams)
|
||||
|
||||
# Fan out child requests (for n>1).
|
||||
parent_request = ParentRequest(request_id, parent_params)
|
||||
|
||||
@@ -250,6 +250,9 @@ class LLMEngine:
|
||||
elif isinstance(prompt, Mapping):
|
||||
prompt_text = cast(str | None, prompt.get("prompt"))
|
||||
|
||||
# Use cloned params that may have been updated in process_inputs()
|
||||
params = request.params
|
||||
|
||||
n = params.n if isinstance(params, SamplingParams) else 1
|
||||
|
||||
if n == 1:
|
||||
@@ -262,10 +265,10 @@ class LLMEngine:
|
||||
# Fan out child requests (for n>1).
|
||||
parent_req = ParentRequest(request_id, params)
|
||||
for idx in range(n):
|
||||
request_id, params = parent_req.get_child_info(idx)
|
||||
request_id, child_params = parent_req.get_child_info(idx)
|
||||
child_request = request if idx == n - 1 else copy(request)
|
||||
child_request.request_id = request_id
|
||||
child_request.sampling_params = params
|
||||
child_request.sampling_params = child_params
|
||||
|
||||
# Make a new RequestState and queue.
|
||||
self.output_processor.add_request(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user