[Bugfix]: fix issue with n>1 sampling on v1 requests overriding each other (#16863)
Signed-off-by: Jeffrey Li <jeffrey.dot.li@gmail.com>
parent 1311913f55
commit 0e4254492f
@@ -921,3 +921,84 @@ async def test_request_output_collector():
     # Cumulative logprobs should be the last one.
     cumulative_logprob_expected = 1.0 * num_to_put
     assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected
+
+
+@pytest.mark.asyncio
+async def test_cumulative_output_collector_n():
+    """Test collector correctly handles multiple outputs by index."""
+    collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE)
+    outputs = [
+        RequestOutput(
+            request_id="my-request-id",
+            prompt=None,
+            prompt_token_ids=[1, 2, 3],
+            prompt_logprobs=None,
+            outputs=[
+                CompletionOutput(
+                    index=0,
+                    text="a",
+                    token_ids=[0],
+                    cumulative_logprob=None,
+                    logprobs=None,
+                    finish_reason=None,
+                ),
+                CompletionOutput(
+                    index=1,
+                    text="b",
+                    token_ids=[1],
+                    cumulative_logprob=None,
+                    logprobs=None,
+                    finish_reason=None,
+                ),
+            ],
+            finished=False,
+        ),
+        RequestOutput(
+            request_id="my-request-id",
+            prompt=None,
+            prompt_token_ids=[1, 2, 3],
+            prompt_logprobs=None,
+            outputs=[
+                CompletionOutput(
+                    index=0,
+                    text="ab",
+                    token_ids=[0, 1],
+                    cumulative_logprob=None,
+                    logprobs=None,
+                    finish_reason=None,
+                ),
+                CompletionOutput(
+                    index=2,
+                    text="c",
+                    token_ids=[2],
+                    cumulative_logprob=None,
+                    logprobs=None,
+                    finish_reason=None,
+                ),
+            ],
+            finished=False,
+        ),
+    ]
+    for output in outputs:
+        collector.put(output)
+
+    # Get the output and check that the text and token_ids are correct.
+    result = await collector.get()
+    # We are expecting
+    # [{index: 0, text: "ab"}, {index: 1, text: "b"}, {index: 2, text: "c"}]
+    assert len(result.outputs) == 3
+    # First is the one where index is 0
+    first = [k for k in result.outputs if k.index == 0]
+    assert len(first) == 1
+    assert first[0].text == "ab"
+
+    # Second is the one where index is 1
+    second = [k for k in result.outputs if k.index == 1]
+    assert len(second) == 1
+    assert second[0].text == "b"
+    assert second[0].token_ids == [1]
+
+    # Third is the one where index is 2
+    third = [k for k in result.outputs if k.index == 2]
+    assert len(third) == 1
+    assert third[0].text == "c"
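The expected list in the test above follows from per-index bookkeeping: index 0 appears in both puts, so its later cumulative snapshot "ab" wins; index 1 arrives only in the first put and index 2 only in the second, so both must survive. A rough sketch of that bookkeeping with plain dicts (not vLLM's RequestOutput/CompletionOutput objects; assuming CUMULATIVE snapshots replace, rather than extend, the buffered completion for the same index):

# Rough sketch of the per-index bookkeeping the test above asserts,
# using plain dicts instead of vLLM's RequestOutput/CompletionOutput.
puts = [
    [{"index": 0, "text": "a"}, {"index": 1, "text": "b"}],
    [{"index": 0, "text": "ab"}, {"index": 2, "text": "c"}],
]

buffered = {}  # sample index -> latest cumulative completion
for outputs in puts:
    for completion in outputs:
        # A later snapshot for the same index supersedes the earlier one;
        # completions for other indexes are left untouched.
        buffered[completion["index"]] = completion

assert [buffered[i]["text"] for i in sorted(buffered)] == ["ab", "b", "c"]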
@@ -134,26 +134,32 @@ class RequestOutput:
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
         self.num_cached_tokens = num_cached_tokens
 
-    def add(self, next_output: "RequestOutput") -> None:
+    def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
         """Merge subsequent RequestOutput into this one"""
 
         self.finished |= next_output.finished
 
         for next_completion in next_output.outputs:
-            for completion in self.outputs:
+            for i, completion in enumerate(self.outputs):
                 if completion.index == next_completion.index:
-                    # Merge outputs with same index
-                    completion.text += next_completion.text
-                    if not isinstance(completion.token_ids, MutableSequence):
-                        completion.token_ids = list(completion.token_ids)
-                    completion.token_ids.extend(next_completion.token_ids)
-                    if next_completion.logprobs:
-                        assert completion.logprobs is not None
-                        completion.logprobs.extend(next_completion.logprobs)
-                    completion.cumulative_logprob = (
-                        next_completion.cumulative_logprob)
-                    completion.finish_reason = next_completion.finish_reason
-                    completion.stop_reason = next_completion.stop_reason
+                    if aggregate:
+                        # Merge outputs with same index
+                        completion.text += next_completion.text
+                        if not isinstance(completion.token_ids,
+                                          MutableSequence):
+                            completion.token_ids = list(completion.token_ids)
+                        completion.token_ids.extend(next_completion.token_ids)
+                        if next_completion.logprobs:
+                            assert completion.logprobs is not None
+                            completion.logprobs.extend(
+                                next_completion.logprobs)
+                        completion.cumulative_logprob = (
+                            next_completion.cumulative_logprob)
+                        completion.finish_reason = next_completion.finish_reason
+                        completion.stop_reason = next_completion.stop_reason
+                    else:
+                        # Replace the output with the new one
+                        self.outputs[i] = next_completion
                     break
             else:
                 self.outputs.append(next_completion)
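The new aggregate flag lets add() serve both streaming modes: in delta mode (aggregate=True) the incoming fragment is appended to the buffered completion with the same index, while in cumulative/final mode (aggregate=False) the incoming snapshot replaces it; in both modes completions for other indexes are left alone and unseen indexes are appended. A minimal sketch of that distinction, using hypothetical stand-in types rather than vLLM's own classes:

# Minimal sketch of merge-vs-replace by sample index; Piece and add() here
# are hypothetical stand-ins, not vLLM's RequestOutput.add.
from dataclasses import dataclass


@dataclass
class Piece:
    index: int
    text: str


def add(buffered: list, incoming: list, aggregate: bool) -> None:
    for new in incoming:
        for i, old in enumerate(buffered):
            if old.index == new.index:
                if aggregate:
                    # Delta mode: append the fragment to what is buffered.
                    old.text += new.text
                else:
                    # Cumulative/final mode: the new snapshot supersedes the old.
                    buffered[i] = new
                break
        else:
            # First time this sample index is seen: keep it as-is.
            buffered.append(new)


delta_buf = [Piece(0, "a")]
add(delta_buf, [Piece(0, "b"), Piece(1, "x")], aggregate=True)
assert [p.text for p in delta_buf] == ["ab", "x"]

cumulative_buf = [Piece(0, "a"), Piece(1, "b")]
add(cumulative_buf, [Piece(0, "ab"), Piece(2, "c")], aggregate=False)
assert [p.text for p in cumulative_buf] == ["ab", "b", "c"]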
@@ -37,12 +37,9 @@ class RequestOutputCollector:
             self.output = output
             self.ready.set()
         elif isinstance(self.output, RequestOutput):
-            if self.aggregate:
-                # Coalesce the outputs in delta case.
-                self.output.add(output)
-            else:
-                # Just replace latest in non-delta case.
-                self.output = output
+            # This ensures that request outputs with different request indexes
+            # (if n > 1) do not override each other.
+            self.output.add(output, aggregate=self.aggregate)
 
     async def get(self) -> RequestOutput:
         """Get operation blocks on put event."""
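To see why put() must merge rather than replace even when aggregate is False, note that several put() calls can land between two consumer get() calls, each carrying completions for only some sample indexes. A rough standalone model of that handoff (TinyCollector is a hypothetical stand-in for illustration, not vLLM's RequestOutputCollector):

# Rough standalone model of the producer/consumer handoff. Several put()
# calls can arrive before the consumer wakes up; replacing the buffer
# wholesale would drop completions for other sample indexes.
import asyncio


class TinyCollector:

    def __init__(self) -> None:
        self.ready = asyncio.Event()
        self.buffered = None  # dict mapping sample index -> latest text

    def put(self, outputs: dict) -> None:
        if self.buffered is None:
            self.buffered = dict(outputs)
        else:
            # Merge by index instead of overwriting the whole buffer.
            self.buffered.update(outputs)
        self.ready.set()

    async def get(self) -> dict:
        await self.ready.wait()
        self.ready.clear()
        out, self.buffered = self.buffered, None
        return out


async def main() -> None:
    collector = TinyCollector()
    # Two puts arrive before the consumer is scheduled (e.g. n=3 sampling).
    collector.put({0: "a", 1: "b"})
    collector.put({0: "ab", 2: "c"})
    assert await collector.get() == {0: "ab", 1: "b", 2: "c"}


asyncio.run(main())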