mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:25:00 +08:00
[Bugfix]: fix issue with n>1 sampling on v1 requests overriding each other (#16863)
Signed-off-by: Jeffrey Li <jeffrey.dot.li@gmail.com>
This commit is contained in:
parent
1311913f55
commit
0e4254492f
@ -921,3 +921,84 @@ async def test_request_output_collector():
|
|||||||
# Cumulative logprobs should be the last one.
|
# Cumulative logprobs should be the last one.
|
||||||
cumulative_logprob_expected = 1.0 * num_to_put
|
cumulative_logprob_expected = 1.0 * num_to_put
|
||||||
assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected
|
assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_cumulative_output_collector_n():
    """Check per-index handling of completions in CUMULATIVE mode.

    Two RequestOutputs for the same request are queued before anything is
    read. The first carries completions with index 0 and 1, the second
    carries completions with index 0 and 2. The collected result is expected
    to hold exactly one completion per index: the newer index-0 entry, the
    untouched index-1 entry, and the newly seen index-2 entry.
    """

    def make_completion(index, text, token_ids):
        # Only index/text/token_ids vary between the fixtures.
        return CompletionOutput(
            index=index,
            text=text,
            token_ids=token_ids,
            cumulative_logprob=None,
            logprobs=None,
            finish_reason=None,
        )

    def make_request_output(completions):
        # Every fixture shares the same request id and prompt tokens.
        return RequestOutput(
            request_id="my-request-id",
            prompt=None,
            prompt_token_ids=[1, 2, 3],
            prompt_logprobs=None,
            outputs=completions,
            finished=False,
        )

    collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE)
    queued = (
        make_request_output([
            make_completion(0, "a", [0]),
            make_completion(1, "b", [1]),
        ]),
        make_request_output([
            make_completion(0, "ab", [0, 1]),
            make_completion(2, "c", [2]),
        ]),
    )
    for request_output in queued:
        collector.put(request_output)

    # Both puts happened before this single get, so the result reflects the
    # combination of the two queued outputs.
    result = await collector.get()

    # Expected contents:
    # [{index: 0, text: "ab"}, {index: 1, text: "b"}, {index: 2, text: "c"}]
    assert len(result.outputs) == 3

    def with_index(index):
        return [entry for entry in result.outputs if entry.index == index]

    # Index 0: present once, carrying the text of the later output.
    first = with_index(0)
    assert len(first) == 1
    assert first[0].text == "ab"

    # Index 1: only in the earlier output, must still be present unchanged.
    second = with_index(1)
    assert len(second) == 1
    assert second[0].text == "b"
    assert second[0].token_ids == [1]

    # Index 2: only in the later output, must have been added.
    third = with_index(2)
    assert len(third) == 1
    assert third[0].text == "c"
|
||||||
|
|||||||
@ -134,26 +134,32 @@ class RequestOutput:
|
|||||||
self.encoder_prompt_token_ids = encoder_prompt_token_ids
|
self.encoder_prompt_token_ids = encoder_prompt_token_ids
|
||||||
self.num_cached_tokens = num_cached_tokens
|
self.num_cached_tokens = num_cached_tokens
|
||||||
|
|
||||||
def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
    """Fold a later RequestOutput for the same request into this one.

    Completions are matched by their ``index``. A completion in
    ``next_output`` whose index is not yet present is appended. For one
    whose index is already present, ``aggregate`` selects what happens:

    * ``True``  - the existing completion is extended in place: text, token
      ids and logprobs are concatenated, and cumulative_logprob,
      finish_reason and stop_reason are taken from the newer completion.
    * ``False`` - the existing completion is replaced by the newer one.

    ``self.finished`` becomes true if either output is finished.
    """
    self.finished |= next_output.finished

    for incoming in next_output.outputs:
        # Position of the first existing completion with the same index.
        slot = next(
            (pos for pos, existing in enumerate(self.outputs)
             if existing.index == incoming.index),
            None,
        )

        if slot is None:
            # Index not seen before: keep the completion as a new entry.
            self.outputs.append(incoming)
            continue

        if not aggregate:
            # Replace the output with the new one
            self.outputs[slot] = incoming
            continue

        # Merge outputs with same index
        current = self.outputs[slot]
        current.text += incoming.text
        if not isinstance(current.token_ids, MutableSequence):
            # Convert to a list so the ids can be extended in place.
            current.token_ids = list(current.token_ids)
        current.token_ids.extend(incoming.token_ids)
        if incoming.logprobs:
            assert current.logprobs is not None
            current.logprobs.extend(incoming.logprobs)
        # These fields are overwritten with the newer completion's values.
        current.cumulative_logprob = incoming.cumulative_logprob
        current.finish_reason = incoming.finish_reason
        current.stop_reason = incoming.stop_reason
|
||||||
|
|||||||
@ -37,12 +37,9 @@ class RequestOutputCollector:
|
|||||||
self.output = output
|
self.output = output
|
||||||
self.ready.set()
|
self.ready.set()
|
||||||
elif isinstance(self.output, RequestOutput):
|
elif isinstance(self.output, RequestOutput):
|
||||||
if self.aggregate:
|
# This ensures that request outputs with different request indexes
|
||||||
# Coalesce the outputs in delta case.
|
# (if n > 1) do not override each other.
|
||||||
self.output.add(output)
|
self.output.add(output, aggregate=self.aggregate)
|
||||||
else:
|
|
||||||
# Just replace latest in non-delta case.
|
|
||||||
self.output = output
|
|
||||||
|
|
||||||
async def get(self) -> RequestOutput:
|
async def get(self) -> RequestOutput:
|
||||||
"""Get operation blocks on put event."""
|
"""Get operation blocks on put event."""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user