[Hardware][TPU] Initial support of model parallelism with single worker using SPMD (#18011)

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Co-authored-by: Hossein Sarshar <hossein.sarshar@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Authored by Siyuan Liu on 2025-06-02 17:06:20 -07:00; committed by GitHub
parent c57d577e8d
commit 9112b443a0
11 changed files with 605 additions and 72 deletions


@ -155,6 +155,10 @@ run_and_track_test 12 "test_moe_pallas.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
run_and_track_test 13 "test_lora.py" \ run_and_track_test 13 "test_lora.py" \
"VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py" "VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py"
run_and_track_test 14 "test_tpu_qkv_linear.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
run_and_track_test 15 "test_spmd_model_weight_loading.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
# After all tests have been attempted, exit with the overall status. # After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then if [ "$overall_script_exit_code" -ne 0 ]; then


@ -1,5 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
from vllm import LLM, SamplingParams
prompts = [
@ -18,14 +21,28 @@ sampling_params = SamplingParams(temperature=0, top_p=1.0, n=N, max_tokens=16)
def main():
parser = argparse.ArgumentParser(description="TPU offline inference example")
parser.add_argument("--use-spmd", action="store_true", help="Enable SPMD mode")
args = parser.parse_args()
llm_args = {
"model": "Qwen/Qwen2-1.5B-Instruct",
"max_num_batched_tokens": 64,
"max_num_seqs": 4,
"max_model_len": 128,
}
if args.use_spmd:
os.environ["VLLM_XLA_USE_SPMD"] = "1"
# The number of chips must be hardcoded for now: calling
# xr.global_runtime_device_count() before initializing the SPMD
# environment in torch_xla would mess up the distributed environment.
llm_args["tensor_parallel_size"] = 8
# Use Llama so that num_kv_heads = 8.
llm_args["model"] = "meta-llama/Llama-3.1-8B-Instruct"
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforce_eager` should be `False`.
llm = LLM(**llm_args)
outputs = llm.generate(prompts, sampling_params)
print("-" * 50)
for output, answer in zip(outputs, answers):
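For reference, the SPMD path of this example reduces to the following standalone sketch. It assumes a TPU VM with 8 chips and access to the Llama checkpoint; the prompt is illustrative. The environment variable is read when the engine is constructed, so it must be set first:

import os

os.environ["VLLM_XLA_USE_SPMD"] = "1"  # must be set before the engine is built

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # num_kv_heads = 8
    tensor_parallel_size=8,  # hardcoded chip count, see the comment above
    max_num_batched_tokens=64,
    max_num_seqs=4,
    max_model_len=128,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0, top_p=1.0, max_tokens=16))
print(outputs[0].outputs[0].text)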


@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
import gc
import tempfile
import numpy as np
import pytest
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tpu import TPUModelLoader
def _setup_environment(model):
engine_args = EngineArgs(model=model, )
vllm_config = engine_args.create_engine_config()
with set_current_vllm_config(vllm_config):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
1,
0,
local_rank=0,
distributed_init_method=f"file://{temp_file}",
backend="gloo")
# Under single-worker mode, the full model is initialized first and
# then partitioned using GSPMD.
ensure_model_parallel_initialized(1, 1)
return vllm_config
MESH = None
def _get_spmd_mesh():
global MESH
if MESH is None:
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
return MESH
@pytest.mark.parametrize("model", [
"Qwen/Qwen2-1.5B-Instruct",
"meta-llama/Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct",
])
def test_tpu_model_loader(model):
# Skip the 70B test if there are fewer than 8 chips.
# TODO: Query the chip count via the torch_xla API; the query API does
# not work with SPMD yet, even though this test runs under SPMD mode.
if '70B' in model and xr.global_runtime_device_count() < 8:
pytest.skip(
"Skipping 70B model if the TPU VM has less than 8 chips to \
avoid OOM.")
vllm_config = _setup_environment(model)
loader = TPUModelLoader(load_config=vllm_config.load_config)
mesh = _get_spmd_mesh()
model = loader.load_model(vllm_config, vllm_config.model_config, mesh)
del model
gc.collect()


@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
import tempfile
import numpy as np
import pytest
import torch
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.linear import QKVParallelLinear
@pytest.fixture(autouse=True)
def setup_environment():
# This is a fake config used only to initialize the distributed env;
# QKVParallelLinear needs the distributed env to be initialized.
engine_args = EngineArgs(
model="Qwen/Qwen2-1.5B-Instruct",
max_model_len=64,
max_num_batched_tokens=64,
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
with set_current_vllm_config(vllm_config):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
1,
0,
local_rank=0,
distributed_init_method=f"file://{temp_file}",
backend="gloo")
ensure_model_parallel_initialized(1, 1)
yield
MESH = None
def _get_spmd_mesh():
global MESH
if MESH is None:
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
return MESH
@pytest.mark.parametrize("bias", [False, True])
# `xr.use_spmd()` will set a global state, and this state is not reversible.
# Therefore, non-SPMD tests should be run before SPMD tests.
@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()])
@pytest.mark.parametrize("device", ['cpu', 'xla'])
@torch.no_grad()
def test_xla_qkv_linear(bias, mesh, device):
torch.manual_seed(123)
qkv_linear = QKVParallelLinear(
hidden_size=4096,
head_size=128,
total_num_heads=32,
total_num_kv_heads=8,
bias=bias,
params_dtype=torch.bfloat16,
return_bias=False,
)
qkv_linear.weight.data = torch.rand_like(qkv_linear.weight.data) / 10
if bias:
qkv_linear.bias.data = torch.rand_like(qkv_linear.bias.data)
xla_qkv_linear = XlaQKVParallelLinear(qkv_linear, mesh=mesh)
qkv_linear = qkv_linear.to(device)
xla_qkv_linear = xla_qkv_linear.to(device)
input_tensor = torch.rand(10, 4096, dtype=torch.bfloat16) / 10
input_tensor = input_tensor.to(device)
output = qkv_linear(input_tensor)
xla_output = xla_qkv_linear(input_tensor)
assert torch.allclose(output.cpu(), xla_output.cpu())


@ -1901,6 +1901,8 @@ class ParallelConfig:
if current_platform.is_neuron():
# neuron uses single process to control multiple devices
backend = "uni"
elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
backend = "uni"
elif (current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size):
if not ray_found:


@ -0,0 +1,177 @@
# SPDX-License-Identifier: Apache-2.0
from collections import OrderedDict
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.distributed.spmd as xs
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
logger = init_logger(__name__)
class XlaQKVParallelLinear(nn.Module):
def __init__(self,
qkv_linear: nn.Module,
mesh: Optional["xs.Mesh"] = None):
super().__init__()
assert isinstance(qkv_linear, QKVParallelLinear)
self.skip_bias_add = qkv_linear.skip_bias_add
self.return_bias = qkv_linear.return_bias
assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD."
self.q_weight: Parameter
self.k_weight: Parameter
self.v_weight: Parameter
self.q_bias: Optional[Parameter]
self.k_bias: Optional[Parameter]
self.v_bias: Optional[Parameter]
self._load_weights_from_qkv_linear(qkv_linear)
if mesh is not None:
self._shard_weight(mesh)
def _shard_weight(self, mesh: "xs.Mesh"):
self.q_weight = Parameter(self.q_weight.to('xla'), requires_grad=False)
self.k_weight = Parameter(self.k_weight.to('xla'), requires_grad=False)
self.v_weight = Parameter(self.v_weight.to('xla'), requires_grad=False)
xs.mark_sharding(self.q_weight, mesh, ('x', None))
xs.mark_sharding(self.k_weight, mesh, ('x', None))
xs.mark_sharding(self.v_weight, mesh, ('x', None))
if self.q_bias is not None:
assert self.k_bias is not None and self.v_bias is not None, \
"QKVParallelLinear should have q, k, and v biases together."
self.q_bias = Parameter(self.q_bias.to('xla'), requires_grad=False)
xs.mark_sharding(self.q_bias, mesh, ('x', ))
self.k_bias = Parameter(self.k_bias.to('xla'), requires_grad=False)
xs.mark_sharding(self.k_bias, mesh, ('x', ))
self.v_bias = Parameter(self.v_bias.to('xla'), requires_grad=False)
xs.mark_sharding(self.v_bias, mesh, ('x', ))
def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module):
q_proj_size, k_proj_size, _ = qkv_linear.output_sizes
# The weight of qkv linear is a concatenation of q, k, and v weights
# along the output dimension.
qkv_weight = qkv_linear.weight.data.cpu()
q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False)
k_weight = Parameter(qkv_weight[q_proj_size:q_proj_size + k_proj_size],
requires_grad=False)
v_weight = Parameter(qkv_weight[q_proj_size + k_proj_size:],
requires_grad=False)
self.register_parameter("q_weight", q_weight)
self.register_parameter("k_weight", k_weight)
self.register_parameter("v_weight", v_weight)
if qkv_linear.bias is not None:
q_bias = Parameter(qkv_linear.bias[:q_proj_size],
requires_grad=False)
k_bias = Parameter(qkv_linear.bias[q_proj_size:q_proj_size +
k_proj_size],
requires_grad=False)
v_bias = Parameter(qkv_linear.bias[q_proj_size + k_proj_size:],
requires_grad=False)
self.register_parameter("q_bias", q_bias)
self.register_parameter("k_bias", k_bias)
self.register_parameter("v_bias", v_bias)
else:
self.register_parameter("q_bias", None)
self.register_parameter("k_bias", None)
self.register_parameter("v_bias", None)
def forward(self, input):
# Same forward functionality as QKVParallelLinear, but the q/k/v
# projections are computed separately.
q_bias = self.q_bias if not self.skip_bias_add else None
k_bias = self.k_bias if not self.skip_bias_add else None
v_bias = self.v_bias if not self.skip_bias_add else None
q_proj = F.linear(input, self.q_weight, q_bias)
k_proj = F.linear(input, self.k_weight, k_bias)
v_proj = F.linear(input, self.v_weight, v_bias)
# The q/k/v projections will be split again outside of this module.
# Because XlaQKVParallelLinear replaces QKVParallelLinear, the q, k, and
# v projections must be concatenated to match the output shape of the
# QKVParallelLinear implementation, even though this seems redundant.
# The concat and the following split are no-ops and should be optimized
# away by the compiler.
qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1)
output_bias = torch.cat([q_bias, k_bias, v_bias], dim=-1) if \
self.skip_bias_add else None
if not self.return_bias:
return qkv_proj
return qkv_proj, output_bias
def partition_column_parallel_linear(layer: torch.nn.Module,
mesh: xs.Mesh) -> torch.nn.Module:
assert isinstance(layer, ColumnParallelLinear)
xs.mark_sharding(layer.weight, mesh, ('x', None))
logger.debug("Applied column-parallel sharding to %s", layer)
return layer
def partition_row_parallel_linear(layer: torch.nn.Module,
mesh: xs.Mesh) -> torch.nn.Module:
assert isinstance(layer, RowParallelLinear)
xs.mark_sharding(layer.weight, mesh, (None, 'x'))
logger.debug("Applied row-parallel sharding to %s", layer)
return layer
def partition_qkv_parallel_linear(layer: torch.nn.Module,
mesh: xs.Mesh) -> torch.nn.Module:
assert isinstance(layer, QKVParallelLinear)
xla_layer = XlaQKVParallelLinear(layer, mesh)
logger.debug("Applied qkv parallel sharding to %s", layer)
return xla_layer
MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict([
("QKVParallelLinear", partition_qkv_parallel_linear),
("ColumnParallelLinear", partition_column_parallel_linear),
("RowParallelLinear", partition_row_parallel_linear),
])
def get_fqn(module):
# Get the fully qualified name of the module
return module.__class__.__qualname__
def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None:
"""
Recursively traverse a PyTorch model and apply the appropriate sharding
based on the MODULE_TYPE_TO_WRAPPING_FUNC mapping.
Args:
model: torch.nn.Module to process
mesh: An XLA SPMD mesh object used for sharding
"""
def _process_module(module, name=None, parent=None):
for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items():
if get_fqn(module) == module_type:
wrapped_module = wrapping_func(module, mesh)
assert parent is not None and name is not None, (
"Top Level module is not expected to be wrapped.")
if wrapped_module is not module:
# The wrapped module is a different Python object from the original
# module, so the original must be replaced with the wrapped module
# in its parent.
logger.debug("replace %s with %s", module, wrapped_module)
setattr(parent, name, wrapped_module)
module = wrapped_module
break
for child_name, child_module in list(module.named_children()):
_process_module(child_module, child_name, module)
_process_module(model)
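For orientation, a minimal sketch of how these helpers compose, mirroring the mesh setup used in the tests and the model runner elsewhere in this commit. The `model` object here is hypothetical and assumed to be a vLLM model already placed on the XLA device:

import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

from vllm.distributed.tpu_distributed_utils import shard_model

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# 2-D (num_devices, 1) mesh; only the 'x' axis is used for sharding.
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ('x', 'y'))

# shard_model walks the module tree: ColumnParallelLinear weights get the
# ('x', None) spec, RowParallelLinear weights get (None, 'x'), and each
# QKVParallelLinear is swapped for XlaQKVParallelLinear so the q/k/v
# weights can be sharded along the head dimension.
shard_model(model, mesh)  # `model`: assumed vLLM model already on the XLA device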


@ -51,6 +51,7 @@ if TYPE_CHECKING:
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto"
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
VLLM_XLA_USE_SPMD: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
@ -513,6 +514,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, assert on XLA recompilation after each execution step.
"VLLM_XLA_CHECK_RECOMPILATION":
lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))),
# Enable SPMD mode for the TPU backend.
"VLLM_XLA_USE_SPMD":
lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
"VLLM_FUSED_MOE_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")),


@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
import time
from typing import Optional
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
from vllm.logger import init_logger
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
logger = init_logger(__name__)
class TPUModelLoader(DefaultModelLoader):
"""
A TPU model loader for model loading under SPMD mode.
"""
def load_model(
self,
vllm_config: VllmConfig,
model_config: ModelConfig,
mesh: Optional[xs.Mesh] = None,
) -> nn.Module:
# Initialize the model and load weights on CPU. Then, during SPMD
# partitioning, the weights are sharded and transferred to the TPUs.
self.counter_before_loading_weights = time.perf_counter()
model_config = vllm_config.model_config
assert model_config.quantization is None, "Quantization not supported"
target_device = torch.device('cpu')
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config)
load_format = vllm_config.load_config.load_format
if load_format != "dummy":
weights_to_load = {
name
for name, _ in model.named_parameters()
}
all_weights = self.get_all_weights(model_config, model)
loaded_weights = model.load_weights(all_weights)
self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",
self.counter_after_loading_weights -
self.counter_before_loading_weights)
# For now, we only enable the strict check for non-quantized models
# that track their loaded weights.
if model_config.quantization is None and \
loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
else:
logger.info("Use dummy weight during weight loading.")
process_weights_after_loading(model, model_config, target_device)
counter_before_partition = time.perf_counter()
model = model.eval()
model = model.to('xla')
shard_model(model, mesh)
counter_after_partition = time.perf_counter()
logger.info("Partition model took %.2f seconds",
counter_after_partition - counter_before_partition)
# Ensure the model is properly loaded.
self._check_model_is_loaded(mesh, model)
# torch.compile must be applied after model sharding is done, because
# the compiler hints ('xs.mark_sharding') are torch ops.
if not model_config.is_multimodal_model:
model.model = torch.compile(model.model, backend="openxla")
else:
model.language_model.model = \
torch.compile(model.language_model.model, backend="openxla")
return model
def _check_model_is_loaded(self, mesh: Optional[xs.Mesh],
model: nn.Module) -> None:
"""
Ensure the model is properly loaded.
1. All model parameters and buffers are on XLA device.
2. Non-SPMD friendly layers are replaced as expected.
"""
device = xm.xla_device()
device_type = str(device.type)
# Check parameters
for name, param in model.named_parameters():
assert param.device.type == device_type, f"Parameter {name} is on \
{param.device.type} instead of {device_type}"
# Check buffers
for name, buffer in model.named_buffers():
assert buffer.device.type == device_type, \
f"Buffer {name} is on {buffer.device.type} instead of \
{device_type}"
for module in model.modules():
if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'):
raise AssertionError("QKVParallelLinear should be replaced by \
XlaQKVParallelLinear under SPMD mode.")
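For orientation, the loader is exercised end to end in test_spmd_model_weight_loading.py above; a condensed sketch of that flow follows (the distributed-environment setup is abbreviated, and the model name is only an example):

import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

from vllm.config import set_current_vllm_config
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tpu import TPUModelLoader

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ('x', 'y'))

vllm_config = EngineArgs(model="Qwen/Qwen2-1.5B-Instruct").create_engine_config()
with set_current_vllm_config(vllm_config):
    # init_distributed_environment(...) and ensure_model_parallel_initialized(1, 1)
    # would run here, exactly as in the test above.
    loader = TPUModelLoader(load_config=vllm_config.load_config)
    # Loads weights on CPU, moves them to the XLA device, applies GSPMD
    # sharding annotations, then wraps the decoder in torch.compile(openxla).
    model = loader.load_model(vllm_config, vllm_config.model_config, mesh)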


@ -49,7 +49,9 @@ def _make_synced_weight_loader(original_weight_loader):
def _synced_weight_loader(param, *args, **kwargs):
original_weight_loader(param, *args, **kwargs)
# torch._sync doesn't support, and is not needed for, CPU tensors.
if param.device != torch.device("cpu"):
torch._sync(param)
return _synced_weight_loader


@ -7,21 +7,22 @@ from unittest.mock import patch
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
# TPU XLA related
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
PlaceholderRange)
@ -98,6 +99,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self,
vllm_config: VllmConfig,
device: torch.device,
original_parallel_config: Optional[ParallelConfig] = None,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
@ -105,6 +107,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.original_parallel_config = original_parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
@ -118,6 +121,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.device = device
self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
# SPMD related
self.use_spmd = envs.VLLM_XLA_USE_SPMD
if self.use_spmd:
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
self.mesh = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
self.enforce_eager = model_config.enforce_eager
self.num_xla_graphs = 0
@ -271,6 +282,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
max_num_mm_items_decoder_budget)
self.max_num_mm_items_by_modality[modality] = max_num_mm_items
if not self.use_spmd:
self.sample_from_logits_func = torch.compile(
self.sample_from_logits,
backend="openxla",
fullgraph=True,
dynamic=False)
else:
self.sample_from_logits_func = self.sample_from_logits
def _update_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
if not check_comp:
@ -825,9 +845,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
logits = self.structured_decode(require_struct_decoding,
grammar_bitmask_padded, logits,
arange)
selected_token_ids = self.sample_from_logits_func(
logits, tpu_sampling_metadata)
# NOTE (NickLucche) Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs. We can't enforce it due
# to recompilations outside torch.compiled code, so just make sure
@ -935,18 +954,26 @@ class TPUModelRunner(LoRAModelRunnerMixin):
"vllm.model_executor.layers.vocab_parallel_embedding." "vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank", "get_tensor_model_parallel_rank",
return_value=xm_tp_rank): return_value=xm_tp_rank):
# model = get_model(vllm_config=self.vllm_config) if self.use_spmd:
model_loader = get_model_loader(self.load_config) tpu_loader = TPUModelLoader(
if not hasattr(self, "model"): load_config=self.vllm_config.load_config)
logger.info("Loading model from scratch...") model = tpu_loader.load_model(
model = model_loader.load_model(vllm_config=self.vllm_config, vllm_config=self.vllm_config,
model_config=self.model_config) model_config=self.vllm_config.model_config,
mesh=self.mesh)
else: else:
logger.info( # model = get_model(vllm_config=self.vllm_config)
"Model was already initialized. Loading weights inplace..." model_loader = get_model_loader(self.load_config)
) if not hasattr(self, "model"):
model_loader.load_weights(self.model, logger.info("Loading model from scratch...")
model_config=self.model_config) model = model_loader.load_model(
vllm_config=self.vllm_config,
model_config=self.model_config)
else:
logger.info("Model was already initialized. \
Loading weights inplace...")
model_loader.load_weights(self.model,
model_config=self.model_config)
if self.lora_config is not None: if self.lora_config is not None:
model = self.load_lora_model(model, self.model_config, model = self.load_lora_model(model, self.model_config,
self.scheduler_config, self.scheduler_config,
@ -970,31 +997,25 @@ class TPUModelRunner(LoRAModelRunnerMixin):
device=self.device)
else:
input_ids = torch.zeros((num_tokens),
dtype=torch.int32).to(self.device)
inputs_embeds = None
actual_num_reqs = min(num_tokens, self.max_num_reqs)
position_ids = torch.zeros(num_tokens,
dtype=torch.int32).to(self.device)
slot_mapping = torch.zeros(num_tokens,
dtype=torch.int64).to(self.device)
block_tables = torch.zeros(
(self.max_num_reqs, self.block_table_cpu.shape[1]),
dtype=torch.int32).to(self.device)
query_lens = [1] * self.max_num_reqs
query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
dtype=torch.int32),
dim=0,
dtype=torch.int32).to(self.device)
context_lens = torch.ones((self.max_num_reqs, ),
dtype=torch.int32).to(self.device)
num_seqs = torch.tensor([actual_num_reqs],
dtype=torch.int32).to(self.device)
attn_metadata = PallasMetadata(
slot_mapping=slot_mapping,
block_tables=block_tables,
@ -1198,7 +1219,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
with self.maybe_select_dummy_loras(
self.lora_config, np.array([num_reqs],
dtype=np.int32)):
self.sample_from_logits_func(dummy_logits,
sampling_metadata)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
@ -1332,14 +1354,22 @@ class TPUModelRunner(LoRAModelRunnerMixin):
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, AttentionSpec):
if self.use_spmd:
num_kv_heads = kv_cache_spec.num_kv_heads
assert self.original_parallel_config is not None
tp_size = \
self.original_parallel_config.tensor_parallel_size
# TODO: Handle kv cache duplication under SPMD mode.
assert num_kv_heads % tp_size == 0, (
f"num_kv_heads {num_kv_heads} must be divisible by "
f"tp_size {tp_size} under SPMD mode")
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
tpu_kv_cache = torch.zeros(kv_cache_shape,
dtype=dtype).to(self.device)
kv_caches[layer_name] = tpu_kv_cache
else:
@ -1350,6 +1380,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
if self.use_spmd:
# Shard KV Cache
for cache in self.kv_caches:
xs.mark_sharding(cache, self.mesh, (None, 'x', None, None))
def reset_dynamo_cache(self):
if self.is_multimodal_model:
compiled_model = self.model.get_language_model().model
@ -1370,7 +1405,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
sample_hidden_states: torch.Tensor) -> torch.Tensor:
return self.model.compute_logits(sample_hidden_states, None)
# TODO: Under SPMD mode, sample_from_logits has a correctness issue.
# Re-enable torch.compile once the issue is fixed in torch_xla.
# @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def sample_from_logits(
self, logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:


@ -45,6 +45,15 @@ class TPUWorker:
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.use_spmd = envs.VLLM_XLA_USE_SPMD
self.original_parallel_config = None
if self.use_spmd:
# Under SPMD mode, distributed env is initialized as if there is
# only one worker/device.
self.original_parallel_config = self.parallel_config
self.parallel_config.tensor_parallel_size = 1
self.parallel_config.pipeline_parallel_size = 1
self.parallel_config.world_size = 1
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
@ -95,10 +104,9 @@ class TPUWorker:
torch.set_default_dtype(self.model_config.dtype)
# Initialize the distributed environment.
self._init_tpu_worker_distributed_environment(
self.parallel_config, self.rank, self.distributed_init_method,
self.local_rank)
# Device initialization should happen after initializing
# the distributed runtime.
@ -132,7 +140,9 @@ class TPUWorker:
xr.initialize_cache(per_rank_path, readonly=False)
# Init ModelRunner here, so that we have access to self.device.
self.model_runner = \
TPUModelRunner(self.vllm_config, self.device,
self.original_parallel_config)
if rank == 0:
# If usage stat is enabled, collect relevant info.
@ -147,9 +157,7 @@ class TPUWorker:
# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
kv_caches[layer_name] = tpu_kv_cache
else:
raise NotImplementedError(
@ -178,9 +186,20 @@ class TPUWorker:
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
if self.use_spmd:
# This is a workaround for the TPU SPMD mode. The get_memory_info
# API doesn't work with SPMD mode in PyTorch/XLA.
# TODO: use xm.get_memory_info for SPMD once it's supported in
# PyTorch/XLA.
import tpu_info
chip_type, _ = tpu_info.device.get_local_chips()
device_usage = tpu_info.metrics.get_chip_usage(chip_type)
total_memory_size = device_usage[0].total_memory
current_mem = device_usage[0].memory_usage
else:
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
current_mem = m["bytes_used"]
# Ideally we would use profiled = m["peak_bytes_used"] to
# get weights + activations. But there is memory used during
# compilation / weight loading that impacts the peak and
@ -241,28 +260,30 @@ class TPUWorker:
# worker will always be healthy as long as it's running.
return
def _init_tpu_worker_distributed_environment(
self,
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
if self.use_spmd:
xr.use_spmd()
# NOTE(woosuk): This is just to initialize the TP group and broadcast
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
local_rank=local_rank,
distributed_init_method=distributed_init_method,
backend="gloo",
)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
try: