mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 21:55:42 +08:00
[BugFix] bad_words filtering ineffective when n > 1 (#29313)
Signed-off-by: GOavi101 <1704178@kiit.ac.in>
This commit is contained in:
parent
db2906108a
commit
32c40b95e0
@@ -72,6 +72,14 @@ class EngineCoreRequest(
|
|||||||
|
|
||||||
trace_headers: Mapping[str, str] | None = None
|
trace_headers: Mapping[str, str] | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def params(self) -> SamplingParams | PoolingParams:
|
||||||
|
"""Return the processed params (sampling or pooling)."""
|
||||||
|
if self.sampling_params is not None:
|
||||||
|
return self.sampling_params
|
||||||
|
assert self.pooling_params is not None
|
||||||
|
return self.pooling_params
|
||||||
|
|
||||||
|
|
||||||
class EngineCoreEventType(enum.IntEnum):
|
class EngineCoreEventType(enum.IntEnum):
|
||||||
"""The type of engine core request event."""
|
"""The type of engine core request event."""
|
||||||
|
|||||||
@@ -321,14 +321,15 @@ class AsyncLLM(EngineClient):
|
|||||||
elif isinstance(prompt, Mapping):
|
elif isinstance(prompt, Mapping):
|
||||||
prompt_text = cast(str | None, prompt.get("prompt"))
|
prompt_text = cast(str | None, prompt.get("prompt"))
|
||||||
|
|
||||||
|
# Use cloned params that may have been updated in process_inputs()
|
||||||
|
params = request.params
|
||||||
|
|
||||||
if is_pooling or params.n == 1:
|
if is_pooling or params.n == 1:
|
||||||
await self._add_request(request, prompt_text, None, 0, queue)
|
await self._add_request(request, prompt_text, None, 0, queue)
|
||||||
return queue
|
return queue
|
||||||
|
|
||||||
# Get the updated SamplingParams from the request, which
|
parent_params = params
|
||||||
# were cloned/updated in processor.process_inputs above.
|
assert isinstance(parent_params, SamplingParams)
|
||||||
parent_params = request.sampling_params
|
|
||||||
assert parent_params is not None
|
|
||||||
|
|
||||||
# Fan out child requests (for n>1).
|
# Fan out child requests (for n>1).
|
||||||
parent_request = ParentRequest(request_id, parent_params)
|
parent_request = ParentRequest(request_id, parent_params)
|
||||||
|
|||||||
@@ -250,6 +250,9 @@ class LLMEngine:
|
|||||||
elif isinstance(prompt, Mapping):
|
elif isinstance(prompt, Mapping):
|
||||||
prompt_text = cast(str | None, prompt.get("prompt"))
|
prompt_text = cast(str | None, prompt.get("prompt"))
|
||||||
|
|
||||||
|
# Use cloned params that may have been updated in process_inputs()
|
||||||
|
params = request.params
|
||||||
|
|
||||||
n = params.n if isinstance(params, SamplingParams) else 1
|
n = params.n if isinstance(params, SamplingParams) else 1
|
||||||
|
|
||||||
if n == 1:
|
if n == 1:
|
||||||
@@ -262,10 +265,10 @@ class LLMEngine:
|
|||||||
# Fan out child requests (for n>1).
|
# Fan out child requests (for n>1).
|
||||||
parent_req = ParentRequest(request_id, params)
|
parent_req = ParentRequest(request_id, params)
|
||||||
for idx in range(n):
|
for idx in range(n):
|
||||||
request_id, params = parent_req.get_child_info(idx)
|
request_id, child_params = parent_req.get_child_info(idx)
|
||||||
child_request = request if idx == n - 1 else copy(request)
|
child_request = request if idx == n - 1 else copy(request)
|
||||||
child_request.request_id = request_id
|
child_request.request_id = request_id
|
||||||
child_request.sampling_params = params
|
child_request.sampling_params = child_params
|
||||||
|
|
||||||
# Make a new RequestState and queue.
|
# Make a new RequestState and queue.
|
||||||
self.output_processor.add_request(
|
self.output_processor.add_request(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user