[V1] [KVConnector] Fix MultiprocExecutor worker output aggregation (#21048)

Signed-off-by: David Ben-David <davidb@pliops.com>
Co-authored-by: David Ben-David <davidb@pliops.com>
This commit is contained in:
David Ben-David 2025-07-17 08:29:45 +03:00 committed by GitHub
parent 8a4e5c5f3c
commit 4fcef49ec4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 129 additions and 4 deletions

View File

@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from collections import defaultdict
from concurrent.futures import Future
from typing import Optional
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
from vllm.v1.outputs import ModelRunnerOutput
class DummyMultiprocExecutor(MultiprocExecutor):
def __init__(self, output_rank, world_size):
# Manually initialize minimal required fields
self.output_rank = output_rank
self.world_size = world_size
self._send_remaining_count = defaultdict[str,
int](lambda: self.world_size)
self._recv_remaining_count = defaultdict[str,
int](lambda: self.world_size)
self.io_thread_pool = None
self.shutdown_event = threading.Event()
class DummyModelRunnerOutput(ModelRunnerOutput):
def __init__(self,
finished_sending: Optional[set[str]] = None,
finished_recving: Optional[set[str]] = None):
self.finished_sending = finished_sending
self.finished_recving = finished_recving
def test_aggregate_workers_output():
executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
output1 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
output2 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
aggregated = executor._aggregate_workers_output([output1, output2])
assert aggregated is output1
assert aggregated.finished_sending is None
assert aggregated.finished_recving is None
output1 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving=None)
aggregated = executor._aggregate_workers_output([output1, output2])
assert aggregated is output1
assert aggregated.finished_sending == {'req1'}
assert aggregated.finished_recving is None
output1 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
aggregated = executor._aggregate_workers_output([output1, output2])
assert aggregated is output1
assert aggregated.finished_sending is None
assert aggregated.finished_recving == {'req2'}
def test_async_aggregate_workers_output():
executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
future1: Future[DummyModelRunnerOutput] = Future()
future2: Future[DummyModelRunnerOutput] = Future()
result_future = executor._async_aggregate_workers_output(
[future1, future2])
output1 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
output2 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
future1.set_result(output1)
future2.set_result(output2)
assert result_future.done()
aggregated = result_future.result()
assert aggregated is output1
assert aggregated.finished_sending is None
assert aggregated.finished_recving is None
future1 = Future()
future2 = Future()
result_future = executor._async_aggregate_workers_output(
[future1, future2])
output1 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving=None)
future1.set_result(output1)
future2.set_result(output2)
assert result_future.done()
aggregated = result_future.result()
assert aggregated is output1
assert aggregated.finished_sending == {'req1'}
assert aggregated.finished_recving is None
future1 = Future()
future2 = Future()
result_future = executor._async_aggregate_workers_output(
[future1, future2])
output1 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
future1.set_result(output1)
future2.set_result(output2)
assert result_future.done()
aggregated = result_future.result()
assert aggregated is output1
assert aggregated.finished_sending is None
assert aggregated.finished_recving == {'req2'}

View File

@ -273,10 +273,8 @@ class MultiprocExecutor(Executor):
output = outputs[self.output_rank]
# set the aggregated finished_sending / finished_recving
if finished_sending:
output.finished_sending = finished_sending
if finished_recving:
output.finished_recving = finished_recving
output.finished_sending = finished_sending if finished_sending else None
output.finished_recving = finished_recving if finished_recving else None
return output