[Bugfix] Fix for multinode crash on 4 PP (#6495)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Authored by Murali Andoorveedu on 2024-07-17 04:25:10 -04:00 (committed by GitHub)
parent 5bf35a91e4
commit 5fa6e9876e
2 changed files with 17 additions and 5 deletions


@@ -4,14 +4,12 @@ from ..utils import RemoteOpenAIServer
 @pytest.mark.parametrize(
-    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME",
-    [
+    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME", [
         (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"),
         (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"),
         (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"),
-        # TODO: figure out why PP=4 tests are flaky
-        # (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
-        # (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
     ])
 def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
     pp_args = [
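
For context, a minimal sketch (not part of this diff; the values are illustrative, though the names mirror the executor code below): in the newly enabled TP_SIZE=1, PP_SIZE=4 cases every non-driver rank satisfies rank % tensor_parallel_size == 0, so all of ranks 1..3 become TP "driver" workers, and their ordering decides which pipeline stage hands back the final output.

# Sketch: classify ranks the same way the executor does,
# for the TP_SIZE=1, PP_SIZE=4 case enabled above.
tensor_parallel_size = 1
worker_ranks = [0, 1, 2, 3]  # rank 0 is the driver worker itself

tp_driver_ranks = [
    rank for rank in worker_ranks[1:]
    if rank % tensor_parallel_size == 0
]
non_driver_ranks = [
    rank for rank in worker_ranks[1:]
    if rank % tensor_parallel_size != 0
]
print(tp_driver_ranks)   # [1, 2, 3] -- must end up in rank order
print(non_driver_ranks)  # [] whenever TP_SIZE == 1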


@@ -224,13 +224,27 @@ class RayGPUExecutor(DistributedGPUExecutor):
         # broadcasted to.
         self.non_driver_workers: List[RayWorkerWrapper] = []
+        tp_driver_worker_ranks = []
+        non_driver_worker_ranks = []
         for idx, rank in enumerate(worker_ranks[1:]):
             # We need to skip the driver worker, which we
             # do by skipping worker_ranks[0] which is always 0.
             if rank % self.parallel_config.tensor_parallel_size == 0:
                 self.tp_driver_workers.append(self.workers[idx])
+                tp_driver_worker_ranks.append(rank)
             else:
                 self.non_driver_workers.append(self.workers[idx])
+                non_driver_worker_ranks.append(rank)
+
+        # Enforce rank order for correct rank to return final output.
+        self.tp_driver_workers = [
+            worker for _, worker in sorted(
+                zip(tp_driver_worker_ranks, self.tp_driver_workers))
+        ]
+        self.non_driver_workers = [
+            worker for _, worker in sorted(
+                zip(non_driver_worker_ranks, self.non_driver_workers))
+        ]
 
     def _driver_execute_model(
         self, execute_model_req: Optional[ExecuteModelRequest]
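
For clarity, a standalone sketch of the ordering trick introduced above (the worker handles and ranks here are placeholders, not vLLM objects): sorting the (rank, worker) pairs restores rank order no matter the order in which the workers were appended, and because ranks are unique, sorted() never has to compare the worker objects themselves, so non-comparable handles such as RayWorkerWrapper are fine.

# Standalone sketch of the rank-ordering fix; handles are placeholders.
workers = ["stage_3_worker", "stage_1_worker", "stage_2_worker"]
ranks = [3, 1, 2]  # ranks observed while iterating worker_ranks[1:]

# Sort (rank, worker) pairs by rank; ranks are unique, so the worker
# objects themselves are never compared.
workers_in_rank_order = [
    worker for _, worker in sorted(zip(ranks, workers))
]
print(workers_in_rank_order)
# ['stage_1_worker', 'stage_2_worker', 'stage_3_worker']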