diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py
index 9ac42dbc34a4..f8d96caf1a27 100644
--- a/tests/v1/engine/test_output_processor.py
+++ b/tests/v1/engine/test_output_processor.py
@@ -921,3 +921,86 @@ async def test_request_output_collector():
     # Cumulative logprobs should be the last one.
     cumulative_logprob_expected = 1.0 * num_to_put
     assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected
+
+
+@pytest.mark.asyncio
+async def test_cumulative_output_collector_n():
+    """Test collector correctly handles multiple outputs by index."""
+    collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE)
+    outputs = [
+        RequestOutput(
+            request_id="my-request-id",
+            prompt=None,
+            prompt_token_ids=[1, 2, 3],
+            prompt_logprobs=None,
+            outputs=[
+                CompletionOutput(
+                    index=0,
+                    text="a",
+                    token_ids=[0],
+                    cumulative_logprob=None,
+                    logprobs=None,
+                    finish_reason=None,
+                ),
+                CompletionOutput(
+                    index=1,
+                    text="b",
+                    token_ids=[1],
+                    cumulative_logprob=None,
+                    logprobs=None,
+                    finish_reason=None,
+                ),
+            ],
+            finished=False,
+        ),
+        RequestOutput(
+            request_id="my-request-id",
+            prompt=None,
+            prompt_token_ids=[1, 2, 3],
+            prompt_logprobs=None,
+            outputs=[
+                CompletionOutput(
+                    index=0,
+                    text="ab",
+                    token_ids=[0, 1],
+                    cumulative_logprob=None,
+                    logprobs=None,
+                    finish_reason=None,
+                ),
+                CompletionOutput(
+                    index=2,
+                    text="c",
+                    token_ids=[2],
+                    cumulative_logprob=None,
+                    logprobs=None,
+                    finish_reason=None,
+                ),
+            ],
+            finished=False,
+        ),
+    ]
+    for output in outputs:
+        collector.put(output)
+
+    # Get the output and check that the text and token_ids are correct.
+    result = await collector.get()
+    # We are expecting
+    # [{index: 0, text: "ab"}, {index: 1, text: "b"}, {index: 2, text: "c"}]
+    assert len(result.outputs) == 3
+    # First is the one where index is 0
+    first = [k for k in result.outputs if k.index == 0]
+    assert len(first) == 1
+    assert first[0].text == "ab"
+    assert first[0].token_ids == [0, 1]
+
+    # Second is the one where index is 1
+    second = [k for k in result.outputs if k.index == 1]
+    assert len(second) == 1
+    assert second[0].text == "b"
+    assert second[0].token_ids == [1]
+
+    # Third is the one where index is 2
+    third = [k for k in result.outputs if k.index == 2]
+    assert len(third) == 1
+    assert third[0].text == "c"
+    assert third[0].token_ids == [2]
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 014e8d5d8823..c8b9be5424e4 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -134,26 +134,44 @@ class RequestOutput:
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
         self.num_cached_tokens = num_cached_tokens
 
-    def add(self, next_output: "RequestOutput") -> None:
-        """Merge subsequent RequestOutput into this one"""
+    def add(self, next_output: "RequestOutput", aggregate: bool = True) -> None:
+        """Merge subsequent RequestOutput into this one.
+
+        ``finished`` is OR-ed with ``next_output.finished``; completions
+        whose index is not yet present are appended to ``self.outputs``.
+
+        Args:
+            next_output: The later output to fold into this one.
+            aggregate: If True (the default), a completion with a matching
+                index has the new text, token ids and logprobs appended,
+                and takes the new cumulative_logprob, finish_reason and
+                stop_reason. If False, the new completion replaces the
+                existing one.
+        """
 
         self.finished |= next_output.finished
 
         for next_completion in next_output.outputs:
-            for completion in self.outputs:
+            for i, completion in enumerate(self.outputs):
                 if completion.index == next_completion.index:
-                    # Merge outputs with same index
-                    completion.text += next_completion.text
-                    if not isinstance(completion.token_ids, MutableSequence):
-                        completion.token_ids = list(completion.token_ids)
-                    completion.token_ids.extend(next_completion.token_ids)
-                    if next_completion.logprobs:
-                        assert completion.logprobs is not None
-                        completion.logprobs.extend(next_completion.logprobs)
-                    completion.cumulative_logprob = (
-                        next_completion.cumulative_logprob)
-                    completion.finish_reason = next_completion.finish_reason
-                    completion.stop_reason = next_completion.stop_reason
+                    if aggregate:
+                        # Merge outputs with same index
+                        completion.text += next_completion.text
+                        if not isinstance(completion.token_ids,
+                                          MutableSequence):
+                            completion.token_ids = list(completion.token_ids)
+                        completion.token_ids.extend(next_completion.token_ids)
+                        if next_completion.logprobs:
+                            assert completion.logprobs is not None
+                            completion.logprobs.extend(
+                                next_completion.logprobs)
+                        completion.cumulative_logprob = (
+                            next_completion.cumulative_logprob)
+                        completion.finish_reason = next_completion.finish_reason
+                        completion.stop_reason = next_completion.stop_reason
+                    else:
+                        # Replace the output with the new one
+                        self.outputs[i] = next_completion
                     break
             else:
                 self.outputs.append(next_completion)
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 21e2a1aee4e2..d652b17e55b3 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -37,12 +37,9 @@ class RequestOutputCollector:
             self.output = output
             self.ready.set()
         elif isinstance(self.output, RequestOutput):
-            if self.aggregate:
-                # Coalesce the outputs in delta case.
-                self.output.add(output)
-            else:
-                # Just replace latest in non-delta case.
-                self.output = output
+            # Merge by completion index so that, when n > 1, outputs for
+            # different indexes do not overwrite each other.
+            self.output.add(output, aggregate=self.aggregate)
 
     async def get(self) -> RequestOutput:
         """Get operation blocks on put event."""