[Neuron] Add Neuron device communicator for vLLM v1 (#14085)

This commit is contained in:
gnovack 2025-03-10 18:37:04 -07:00 committed by GitHub
parent 485afdd3cb
commit d6123170d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 127 additions and 0 deletions

View File

@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
import functools
from typing import Callable
from unittest.mock import patch
import pytest
import torch
import torch_xla.distributed.xla_multiprocessing as xmp
from typing_extensions import ParamSpec
from vllm.distributed.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.utils import get_distributed_init_method, get_open_port
_P = ParamSpec("_P")
def reinitialize_neuron_runtime(f: Callable[_P, None]) -> Callable[_P, None]:
"""Decorator to reinitialize the Neuron Runtime before executing a test.
This is necessary for distributed tests which need to reallocate Neuron
Cores to separate subprocesses.
"""
@functools.wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
runtime = torch.classes.neuron.Runtime()
runtime.initialize()
runtime.unsafe_close()
f(*args, **kwargs)
runtime.initialize()
return wrapper
def all_gather_test_worker(index, tp_degree, distributed_init_method):
init_distributed_environment(tp_degree,
index,
distributed_init_method,
index,
backend="xla")
ensure_model_parallel_initialized(tp_degree, 1)
num_dimensions = 3
tensor_size = list(range(2, num_dimensions + 2))
total_size = 1
for s in tensor_size:
total_size *= s
all_gather_dimension = -1
all_tensors = [
torch.arange(total_size, dtype=torch.float32,
device="xla").reshape(tensor_size) * (r + 1)
for r in range(tp_degree)
]
expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[index % tp_degree]
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
torch.testing.assert_close(t, expected)
def all_reduce_test_worker(index, tp_degree, distributed_init_method):
init_distributed_environment(tp_degree,
index,
distributed_init_method,
index,
backend="xla")
ensure_model_parallel_initialized(tp_degree, 1)
num_elements = 8
all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="xla") * (r + 1)
for r in range(tp_degree)
]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[index % tp_degree]
t = tensor_model_parallel_all_reduce(t)
torch.testing.assert_close(t, expected)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target",
[all_reduce_test_worker, all_gather_test_worker])
@reinitialize_neuron_runtime
def test_neuron_multi_process_tensor_parallel(monkeypatch, tp_size,
test_target):
with patch('torch_xla._XLAC._xla_runtime_is_initialized',
return_value=False):
distributed_init_method = get_distributed_init_method(
"127.0.0.1", get_open_port())
monkeypatch.setenv("VLLM_USE_V1", "1")
monkeypatch.setenv("NEURONCORE_NUM_DEVICES", str(tp_size))
monkeypatch.setenv("NEURON_PJRT_PROCESSES_NUM_DEVICES",
','.join(['1' for _ in range(tp_size)]))
xmp.spawn(test_target, args=(tp_size, distributed_init_method))

View File

@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
import torch
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from vllm.platforms import current_platform
if current_platform.is_neuron():
import torch_xla.core.xla_model as xm
class NeuronCommunicator(DeviceCommunicatorBase):
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, x)
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "Neuron only supports dim=-1 for all-gather."
return xm.all_gather(x, dim=dim)

View File

@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Optional
from vllm import envs
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum
@ -56,6 +57,13 @@ class NeuronPlatform(Platform):
logger.warning("Pin memory is not supported on Neuron.")
return False
@classmethod
def get_device_communicator_cls(cls) -> str:
if envs.VLLM_USE_V1:
return "vllm.distributed.device_communicators.neuron_communicator.NeuronCommunicator" # noqa
else:
return Platform.get_device_communicator_cls()
@classmethod
def use_all_gather(cls) -> bool:
return True