[Bugfix] Fix for multinode crash on 4 PP (#6495)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
parent 5bf35a91e4
commit 5fa6e9876e
@@ -4,14 +4,12 @@ from ..utils import RemoteOpenAIServer
 
 
 @pytest.mark.parametrize(
-    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME",
-    [
+    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME", [
         (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"),
         (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"),
         (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"),
-        # TODO: figure out why PP=4 tests are flaky
-        # (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
-        # (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
     ])
 def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
     pp_args = [
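The hunk above only re-enables the two previously commented-out PP=4 cases (and drops the stale TODO). As a rough illustration of why PP=4 with TP=1 is the configuration that stresses the executor fix below, here is a small standalone sketch (not part of the diff; it assumes the rank layout implied by the `rank % tensor_parallel_size` check in RayGPUExecutor):

```python
# Standalone sketch: which non-driver ranks the executor treats as TP drivers
# for each test configuration. With TP=1, PP=4 there are three such ranks,
# so the order in which they are stored matters.
for tp_size, pp_size in [(2, 2), (1, 3), (1, 4)]:
    world_size = tp_size * pp_size
    non_driver_ranks = range(1, world_size)  # rank 0 is the driver worker
    tp_driver_ranks = [r for r in non_driver_ranks if r % tp_size == 0]
    print(f"TP={tp_size}, PP={pp_size} -> TP-driver ranks: {tp_driver_ranks}")
```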
@@ -224,13 +224,27 @@ class RayGPUExecutor(DistributedGPUExecutor):
         # broadcasted to.
         self.non_driver_workers: List[RayWorkerWrapper] = []
 
+        tp_driver_worker_ranks = []
+        non_driver_worker_ranks = []
         for idx, rank in enumerate(worker_ranks[1:]):
             # We need to skip the driver worker, which we
             # do by skipping worker_ranks[0] which is always 0.
             if rank % self.parallel_config.tensor_parallel_size == 0:
                 self.tp_driver_workers.append(self.workers[idx])
+                tp_driver_worker_ranks.append(rank)
             else:
                 self.non_driver_workers.append(self.workers[idx])
+                non_driver_worker_ranks.append(rank)
+
+        # Enforce rank order for correct rank to return final output.
+        self.tp_driver_workers = [
+            worker for _, worker in sorted(
+                zip(tp_driver_worker_ranks, self.tp_driver_workers))
+        ]
+        self.non_driver_workers = [
+            worker for _, worker in sorted(
+                zip(non_driver_worker_ranks, self.non_driver_workers))
+        ]
 
     def _driver_execute_model(
         self, execute_model_req: Optional[ExecuteModelRequest]
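The substance of the fix is the re-sort at the end of the hunk: since `worker_ranks[1:]` need not be visited in ascending rank order, each worker's rank is now recorded alongside it and both lists are rebuilt sorted by rank, so that (per the added comment) the correct rank returns the final output. A minimal self-contained sketch of the same sort-by-rank technique, with plain strings standing in for the RayWorkerWrapper handles (an illustration, not the executor code itself):

```python
# Illustration only: dummy handles in place of RayWorkerWrapper objects.
workers = ["worker@rank3", "worker@rank1", "worker@rank2"]  # discovery order
ranks = [3, 1, 2]                                           # matching ranks

# Pair each handle with its rank, sort the pairs by rank, keep the handles.
ordered = [worker for _, worker in sorted(zip(ranks, workers))]
assert ordered == ["worker@rank1", "worker@rank2", "worker@rank3"]
print(ordered)
```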