mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:55:40 +08:00
[V1] [KVConnector] Fix MultiprocExecutor worker output aggregation (#21048)
Signed-off-by: David Ben-David <davidb@pliops.com> Co-authored-by: David Ben-David <davidb@pliops.com>
This commit is contained in:
parent
8a4e5c5f3c
commit
4fcef49ec4
127
tests/v1/executor/test_multiproc_executor.py
Normal file
127
tests/v1/executor/test_multiproc_executor.py
Normal file
@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from typing import Optional
|
||||
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
|
||||
class DummyMultiprocExecutor(MultiprocExecutor):
|
||||
|
||||
def __init__(self, output_rank, world_size):
|
||||
# Manually initialize minimal required fields
|
||||
self.output_rank = output_rank
|
||||
self.world_size = world_size
|
||||
self._send_remaining_count = defaultdict[str,
|
||||
int](lambda: self.world_size)
|
||||
self._recv_remaining_count = defaultdict[str,
|
||||
int](lambda: self.world_size)
|
||||
self.io_thread_pool = None
|
||||
self.shutdown_event = threading.Event()
|
||||
|
||||
|
||||
class DummyModelRunnerOutput(ModelRunnerOutput):
|
||||
|
||||
def __init__(self,
|
||||
finished_sending: Optional[set[str]] = None,
|
||||
finished_recving: Optional[set[str]] = None):
|
||||
self.finished_sending = finished_sending
|
||||
self.finished_recving = finished_recving
|
||||
|
||||
|
||||
def test_aggregate_workers_output():
|
||||
executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
|
||||
|
||||
output1 = DummyModelRunnerOutput(finished_sending={'req1'},
|
||||
finished_recving={'req2'})
|
||||
output2 = DummyModelRunnerOutput(finished_sending=None,
|
||||
finished_recving=None)
|
||||
|
||||
aggregated = executor._aggregate_workers_output([output1, output2])
|
||||
|
||||
assert aggregated is output1
|
||||
assert aggregated.finished_sending is None
|
||||
assert aggregated.finished_recving is None
|
||||
|
||||
output1 = DummyModelRunnerOutput(finished_sending=None,
|
||||
finished_recving=None)
|
||||
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
|
||||
finished_recving=None)
|
||||
|
||||
aggregated = executor._aggregate_workers_output([output1, output2])
|
||||
|
||||
assert aggregated is output1
|
||||
assert aggregated.finished_sending == {'req1'}
|
||||
assert aggregated.finished_recving is None
|
||||
|
||||
output1 = DummyModelRunnerOutput(finished_sending=None,
|
||||
finished_recving=None)
|
||||
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
|
||||
finished_recving={'req2'})
|
||||
|
||||
aggregated = executor._aggregate_workers_output([output1, output2])
|
||||
|
||||
assert aggregated is output1
|
||||
assert aggregated.finished_sending is None
|
||||
assert aggregated.finished_recving == {'req2'}
|
||||
|
||||
|
||||
def test_async_aggregate_workers_output():
|
||||
executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
|
||||
|
||||
future1: Future[DummyModelRunnerOutput] = Future()
|
||||
future2: Future[DummyModelRunnerOutput] = Future()
|
||||
result_future = executor._async_aggregate_workers_output(
|
||||
[future1, future2])
|
||||
|
||||
output1 = DummyModelRunnerOutput(finished_sending={'req1'},
|
||||
finished_recving={'req2'})
|
||||
output2 = DummyModelRunnerOutput(finished_sending=None,
|
||||
finished_recving=None)
|
||||
future1.set_result(output1)
|
||||
future2.set_result(output2)
|
||||
|
||||
assert result_future.done()
|
||||
aggregated = result_future.result()
|
||||
assert aggregated is output1
|
||||
assert aggregated.finished_sending is None
|
||||
assert aggregated.finished_recving is None
|
||||
|
||||
future1 = Future()
|
||||
future2 = Future()
|
||||
result_future = executor._async_aggregate_workers_output(
|
||||
[future1, future2])
|
||||
|
||||
output1 = DummyModelRunnerOutput(finished_sending=None,
|
||||
finished_recving=None)
|
||||
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
|
||||
finished_recving=None)
|
||||
future1.set_result(output1)
|
||||
future2.set_result(output2)
|
||||
|
||||
assert result_future.done()
|
||||
aggregated = result_future.result()
|
||||
assert aggregated is output1
|
||||
assert aggregated.finished_sending == {'req1'}
|
||||
assert aggregated.finished_recving is None
|
||||
|
||||
future1 = Future()
|
||||
future2 = Future()
|
||||
result_future = executor._async_aggregate_workers_output(
|
||||
[future1, future2])
|
||||
|
||||
output1 = DummyModelRunnerOutput(finished_sending=None,
|
||||
finished_recving=None)
|
||||
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
|
||||
finished_recving={'req2'})
|
||||
future1.set_result(output1)
|
||||
future2.set_result(output2)
|
||||
|
||||
assert result_future.done()
|
||||
aggregated = result_future.result()
|
||||
assert aggregated is output1
|
||||
assert aggregated.finished_sending is None
|
||||
assert aggregated.finished_recving == {'req2'}
|
||||
@ -273,10 +273,8 @@ class MultiprocExecutor(Executor):
|
||||
output = outputs[self.output_rank]
|
||||
|
||||
# set the aggregated finished_sending / finished_recving
|
||||
if finished_sending:
|
||||
output.finished_sending = finished_sending
|
||||
if finished_recving:
|
||||
output.finished_recving = finished_recving
|
||||
output.finished_sending = finished_sending if finished_sending else None
|
||||
output.finished_recving = finished_recving if finished_recving else None
|
||||
|
||||
return output
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user