[XPU] Initialize the pp group for pipeline parallelism (#11648)

Signed-off-by: yisheng <yi.sheng@intel.com>
YiSheng5 2025-01-07 11:09:58 +08:00 committed by GitHub
parent d0169e1b0f
commit d93d2d74fd


@@ -11,6 +11,7 @@ import torch.distributed
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
@@ -176,3 +177,8 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
             parallel_config.pipeline_parallel_size)
         # global all_reduce needed for overall oneccl warm up
         torch.distributed.all_reduce(torch.zeros(1).xpu())
+        if parallel_config.pipeline_parallel_size > 1:
+            # Add pp group init to avoid
+            # p2p communication as the first call
+            get_pp_group().all_reduce(torch.zeros(1).xpu())
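
For context, the change follows a common pattern on backends whose communicators are created lazily: issue a cheap collective on a process group before the first point-to-point call, so the communicator is built eagerly instead of during a p2p send/recv. The sketch below is illustrative only and is not vLLM code; it uses plain torch.distributed rather than vLLM's parallel_state wrappers, warm_up_group is a hypothetical helper, and it assumes the distributed environment has already been initialized (e.g. via torchrun).

    # Illustrative sketch (not vLLM code): warm up a process group with a
    # cheap collective so that a later p2p send/recv is not the first call
    # that has to build the underlying communicator.
    import torch
    import torch.distributed as dist

    def warm_up_group(group: dist.ProcessGroup, device: torch.device) -> None:
        # An all_reduce on a 1-element tensor forces the backend (e.g.
        # oneCCL or NCCL) to create the communicator for `group` now.
        dist.all_reduce(torch.zeros(1, device=device), group=group)

In the patch above, get_pp_group().all_reduce(torch.zeros(1).xpu()) plays the same role for the pipeline-parallel group on XPU, ensuring the communicator exists before the first p2p communication between pipeline stages.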