mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:25:44 +08:00
[XPU] support data parallel for MoE models on XPU (#22887)
Signed-off-by: chzhang <chaojun.zhang@intel.com>
This commit is contained in:
parent
b668055a11
commit
235c9db8a7
@ -7,8 +7,13 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from .base_device_communicator import DeviceCommunicatorBase
|
from .base_device_communicator import DeviceCommunicatorBase
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class XpuCommunicator(DeviceCommunicatorBase):
|
class XpuCommunicator(DeviceCommunicatorBase):
|
||||||
|
|
||||||
@ -18,6 +23,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
|||||||
device_group: Optional[ProcessGroup] = None,
|
device_group: Optional[ProcessGroup] = None,
|
||||||
unique_name: str = ""):
|
unique_name: str = ""):
|
||||||
super().__init__(cpu_group, device, device_group, unique_name)
|
super().__init__(cpu_group, device, device_group, unique_name)
|
||||||
|
if self.use_all2all:
|
||||||
|
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
|
||||||
|
if all2all_backend == "naive":
|
||||||
|
from .all2all import NaiveAll2AllManager
|
||||||
|
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||||
|
logger.info("Using naive all2all manager.")
|
||||||
|
|
||||||
def all_reduce(self, input_) -> torch.Tensor:
|
def all_reduce(self, input_) -> torch.Tensor:
|
||||||
dist.all_reduce(input_, group=self.device_group)
|
dist.all_reduce(input_, group=self.device_group)
|
||||||
|
|||||||
@ -655,6 +655,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
forward_native = forward_tpu
|
forward_native = forward_tpu
|
||||||
elif current_platform.is_cpu():
|
elif current_platform.is_cpu():
|
||||||
forward_native = forward_cpu
|
forward_native = forward_cpu
|
||||||
|
elif current_platform.is_xpu():
|
||||||
|
forward_native = forward_xpu
|
||||||
else:
|
else:
|
||||||
forward_native = forward_cuda
|
forward_native = forward_cuda
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user