[Bugfix]: fix issue with n>1 sampling on v1 requests overriding each other (#16863)

Signed-off-by: Jeffrey Li <jeffrey.dot.li@gmail.com>
This commit is contained in:
Jeffrey Li 2025-04-21 23:40:19 -04:00 committed by GitHub
parent 1311913f55
commit 0e4254492f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 104 additions and 20 deletions

View File

@ -921,3 +921,84 @@ async def test_request_output_collector():
# Cumulative logprobs should be the last one.
cumulative_logprob_expected = 1.0 * num_to_put
assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected
@pytest.mark.asyncio
async def test_cumulative_output_collector_n():
"""Test collector correctly handles multiple outputs by index."""
collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE)
outputs = [
RequestOutput(
request_id="my-request-id",
prompt=None,
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[
CompletionOutput(
index=0,
text="a",
token_ids=[0],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
),
CompletionOutput(
index=1,
text="b",
token_ids=[1],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
),
],
finished=False,
),
RequestOutput(
request_id="my-request-id",
prompt=None,
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[
CompletionOutput(
index=0,
text="ab",
token_ids=[0, 1],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
),
CompletionOutput(
index=2,
text="c",
token_ids=[2],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
),
],
finished=False,
),
]
for output in outputs:
collector.put(output)
# Get the output and check that the text and token_ids are correct.
result = await collector.get()
# We are expecting
# [{index: 0, text: "ab"}, {index: 1, text: "b"}, {index: 2, text: "c"}]
assert len(result.outputs) == 3
# First is the one where index is 0
first = [k for k in result.outputs if k.index == 0]
assert len(first) == 1
assert first[0].text == "ab"
# Second is the one where index is 1
second = [k for k in result.outputs if k.index == 1]
assert len(second) == 1
assert second[0].text == "b"
assert second[0].token_ids == [1]
# Third is the one where index is 2
third = [k for k in result.outputs if k.index == 2]
assert len(third) == 1
assert third[0].text == "c"

View File

@ -134,26 +134,32 @@ class RequestOutput:
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens
def add(self, next_output: "RequestOutput") -> None:
def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
"""Merge subsequent RequestOutput into this one"""
self.finished |= next_output.finished
for next_completion in next_output.outputs:
for completion in self.outputs:
for i, completion in enumerate(self.outputs):
if completion.index == next_completion.index:
# Merge outputs with same index
completion.text += next_completion.text
if not isinstance(completion.token_ids, MutableSequence):
completion.token_ids = list(completion.token_ids)
completion.token_ids.extend(next_completion.token_ids)
if next_completion.logprobs:
assert completion.logprobs is not None
completion.logprobs.extend(next_completion.logprobs)
completion.cumulative_logprob = (
next_completion.cumulative_logprob)
completion.finish_reason = next_completion.finish_reason
completion.stop_reason = next_completion.stop_reason
if aggregate:
# Merge outputs with same index
completion.text += next_completion.text
if not isinstance(completion.token_ids,
MutableSequence):
completion.token_ids = list(completion.token_ids)
completion.token_ids.extend(next_completion.token_ids)
if next_completion.logprobs:
assert completion.logprobs is not None
completion.logprobs.extend(
next_completion.logprobs)
completion.cumulative_logprob = (
next_completion.cumulative_logprob)
completion.finish_reason = next_completion.finish_reason
completion.stop_reason = next_completion.stop_reason
else:
# Replace the output with the new one
self.outputs[i] = next_completion
break
else:
self.outputs.append(next_completion)

View File

@ -37,12 +37,9 @@ class RequestOutputCollector:
self.output = output
self.ready.set()
elif isinstance(self.output, RequestOutput):
if self.aggregate:
# Coalesce the outputs in delta case.
self.output.add(output)
else:
# Just replace latest in non-delta case.
self.output = output
# This ensures that request outputs with different request indexes
# (if n > 1) do not override each other.
self.output.add(output, aggregate=self.aggregate)
async def get(self) -> RequestOutput:
"""Get operation blocks on put event."""