[Hardware][TPU] Initial support of model parallelism with single worker using SPMD (#18011)

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Co-authored-by: Hossein Sarshar <hossein.sarshar@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Siyuan Liu 2025-06-02 17:06:20 -07:00 committed by GitHub
parent c57d577e8d
commit 9112b443a0
11 changed files with 605 additions and 72 deletions

View File

@ -155,6 +155,10 @@ run_and_track_test 12 "test_moe_pallas.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
run_and_track_test 13 "test_lora.py" \
"VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py"
run_and_track_test 14 "test_tpu_qkv_linear.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
run_and_track_test 15 "test_spmd_model_weight_loading.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then

View File

@ -1,5 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
from vllm import LLM, SamplingParams
prompts = [
@ -18,14 +21,28 @@ sampling_params = SamplingParams(temperature=0, top_p=1.0, n=N, max_tokens=16)
def main():
parser = argparse.ArgumentParser(description="TPU offline inference example")
parser.add_argument("--use-spmd", action="store_true", help="Enable SPMD mode")
args = parser.parse_args()
llm_args = {
"model": "Qwen/Qwen2-1.5B-Instruct",
"max_num_batched_tokens": 64,
"max_num_seqs": 4,
"max_model_len": 128,
}
if args.use_spmd:
os.environ["VLLM_XLA_USE_SPMD"] = "1"
# Can only hardcode the number of chips for now:
# calling xr.global_runtime_device_count() before initializing the SPMD env
# in torch_xla will mess up the distributed env.
llm_args["tensor_parallel_size"] = 8
# Use Llama, which has num_kv_heads = 8.
llm_args["model"] = "meta-llama/Llama-3.1-8B-Instruct"
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforce_eager` should be `False`.
llm = LLM(
model="Qwen/Qwen2-1.5B-Instruct",
max_num_batched_tokens=64,
max_num_seqs=4,
max_model_len=128,
)
llm = LLM(**llm_args)
outputs = llm.generate(prompts, sampling_params)
print("-" * 50)
for output, answer in zip(outputs, answers):
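
# A minimal usage sketch of the SPMD path above, assuming an 8-chip TPU host,
# access to meta-llama/Llama-3.1-8B-Instruct, and that VLLM_XLA_USE_SPMD is
# set before vLLM initializes torch_xla. The prompt and variable names here
# are illustrative only.
import os

os.environ["VLLM_XLA_USE_SPMD"] = "1"

from vllm import LLM, SamplingParams

spmd_llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    tensor_parallel_size=8,  # one shard per local TPU chip
    max_num_batched_tokens=64,
    max_num_seqs=4,
    max_model_len=128,
)
spmd_outputs = spmd_llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0, top_p=1.0, max_tokens=16))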

View File

@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
import gc
import tempfile
import numpy as np
import pytest
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tpu import TPUModelLoader
def _setup_environment(model):
engine_args = EngineArgs(model=model, )
vllm_config = engine_args.create_engine_config()
with set_current_vllm_config(vllm_config):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
1,
0,
local_rank=0,
distributed_init_method=f"file://{temp_file}",
backend="gloo")
# Under single-worker mode, the full model is initialized first and then
# partitioned using GSPMD.
ensure_model_parallel_initialized(1, 1)
return vllm_config
MESH = None
def _get_spmd_mesh():
global MESH
if MESH is None:
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
return MESH
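# For illustration (assuming an 8-chip host): the mesh built above is
# equivalent to xs.Mesh(np.arange(8), (8, 1), ('x', 'y')), i.e. axis 'x'
# spans all chips and 'y' is trivial, so a partition spec of ('x', None)
# splits a 2-D weight's first (output) dimension across every chip while
# leaving the second (input) dimension replicated.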
@pytest.mark.parametrize("model", [
"Qwen/Qwen2-1.5B-Instruct",
"meta-llama/Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct",
])
def test_tpu_model_loader(model):
# Skip the 70B test if there are fewer than 8 chips.
# TODO: Query the chip count via the torch_xla API; the query API does not
# work under SPMD yet, even though this test itself runs in SPMD mode.
if '70B' in model and xr.global_runtime_device_count() < 8:
pytest.skip("Skipping the 70B model on TPU VMs with fewer than 8 "
"chips to avoid OOM.")
vllm_config = _setup_environment(model)
loader = TPUModelLoader(load_config=vllm_config.load_config)
mesh = _get_spmd_mesh()
model = loader.load_model(vllm_config, vllm_config.model_config, mesh)
del model
gc.collect()

View File

@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
import tempfile
import numpy as np
import pytest
import torch
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.linear import QKVParallelLinear
@pytest.fixture(autouse=True)
def setup_environment():
# This is a dummy config used to initialize the dist env;
# QKVParallelLinear requires the dist env to be initialized.
engine_args = EngineArgs(
model="Qwen/Qwen2-1.5B-Instruct",
max_model_len=64,
max_num_batched_tokens=64,
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
with set_current_vllm_config(vllm_config):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
1,
0,
local_rank=0,
distributed_init_method=f"file://{temp_file}",
backend="gloo")
ensure_model_parallel_initialized(1, 1)
yield
MESH = None
def _get_spmd_mesh():
global MESH
if MESH is None:
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
return MESH
@pytest.mark.parametrize("bias", [False, True])
# `xr.use_spmd()` will set a global state, and this state is not reversible.
# Therefore, non-SPMD tests should be run before SPMD tests.
@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()])
@pytest.mark.parametrize("device", ['cpu', 'xla'])
@torch.no_grad()
def test_xla_qkv_linear(bias, mesh, device):
torch.manual_seed(123)
qkv_linear = QKVParallelLinear(
hidden_size=4096,
head_size=128,
total_num_heads=32,
total_num_kv_heads=8,
bias=bias,
params_dtype=torch.bfloat16,
return_bias=False,
)
qkv_linear.weight.data = torch.rand_like(qkv_linear.weight.data) / 10
if bias:
qkv_linear.bias.data = torch.rand_like(qkv_linear.bias.data)
xla_qkv_linear = XlaQKVParallelLinear(qkv_linear, mesh=mesh)
qkv_linear = qkv_linear.to(device)
xla_qkv_linear = xla_qkv_linear.to(device)
input_tensor = torch.rand(10, 4096, dtype=torch.bfloat16) / 10
input_tensor = input_tensor.to(device)
output = qkv_linear(input_tensor)
xla_output = xla_qkv_linear(input_tensor)
assert torch.allclose(output.cpu(), xla_output.cpu())

View File

@ -1901,6 +1901,8 @@ class ParallelConfig:
if current_platform.is_neuron():
# neuron uses single process to control multiple devices
backend = "uni"
elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
backend = "uni"
elif (current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size):
if not ray_found:

View File

@ -0,0 +1,177 @@
# SPDX-License-Identifier: Apache-2.0
from collections import OrderedDict
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.distributed.spmd as xs
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
logger = init_logger(__name__)
class XlaQKVParallelLinear(nn.Module):
def __init__(self,
qkv_linear: nn.Module,
mesh: Optional["xs.Mesh"] = None):
super().__init__()
assert isinstance(qkv_linear, QKVParallelLinear)
self.skip_bias_add = qkv_linear.skip_bias_add
self.return_bias = qkv_linear.return_bias
assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD."
self.q_weight: Parameter
self.k_weight: Parameter
self.v_weight: Parameter
self.q_bias: Optional[Parameter]
self.k_bias: Optional[Parameter]
self.v_bias: Optional[Parameter]
self._load_weights_from_qkv_linear(qkv_linear)
if mesh is not None:
self._shard_weight(mesh)
def _shard_weight(self, mesh: "xs.Mesh"):
self.q_weight = Parameter(self.q_weight.to('xla'), requires_grad=False)
self.k_weight = Parameter(self.k_weight.to('xla'), requires_grad=False)
self.v_weight = Parameter(self.v_weight.to('xla'), requires_grad=False)
xs.mark_sharding(self.q_weight, mesh, ('x', None))
xs.mark_sharding(self.k_weight, mesh, ('x', None))
xs.mark_sharding(self.v_weight, mesh, ('x', None))
if self.q_bias is not None:
assert self.k_bias is not None and self.v_bias is not None, \
"QKVParallelLinear should have q, k, and v biases together."
self.q_bias = Parameter(self.q_bias.to('xla'), requires_grad=False)
xs.mark_sharding(self.q_bias, mesh, ('x', ))
self.k_bias = Parameter(self.k_bias.to('xla'), requires_grad=False)
xs.mark_sharding(self.k_bias, mesh, ('x', ))
self.v_bias = Parameter(self.v_bias.to('xla'), requires_grad=False)
xs.mark_sharding(self.v_bias, mesh, ('x', ))
def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module):
q_proj_size, k_proj_size, _ = qkv_linear.output_sizes
# The weight of qkv linear is a concatenation of q, k, and v weights
# along the output dimension.
qkv_weight = qkv_linear.weight.data.cpu()
q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False)
k_weight = Parameter(qkv_weight[q_proj_size:q_proj_size + k_proj_size],
requires_grad=False)
v_weight = Parameter(qkv_weight[q_proj_size + k_proj_size:],
requires_grad=False)
self.register_parameter("q_weight", q_weight)
self.register_parameter("k_weight", k_weight)
self.register_parameter("v_weight", v_weight)
if qkv_linear.bias is not None:
q_bias = Parameter(qkv_linear.bias[:q_proj_size],
requires_grad=False)
k_bias = Parameter(qkv_linear.bias[q_proj_size:q_proj_size +
k_proj_size],
requires_grad=False)
v_bias = Parameter(qkv_linear.bias[q_proj_size + k_proj_size:],
requires_grad=False)
self.register_parameter("q_bias", q_bias)
self.register_parameter("k_bias", k_bias)
self.register_parameter("v_bias", v_bias)
else:
self.register_parameter("q_bias", None)
self.register_parameter("k_bias", None)
self.register_parameter("v_bias", None)
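# Worked example of the slicing above, using the shapes from the unit test
# in this commit (hidden_size=4096, head_size=128, total_num_heads=32,
# total_num_kv_heads=8, tp_size=1):
#   output_sizes == [4096, 1024, 1024], so qkv_weight has shape (6144, 4096)
#   q_weight = qkv_weight[:4096]      -> (4096, 4096)
#   k_weight = qkv_weight[4096:5120]  -> (1024, 4096)
#   v_weight = qkv_weight[5120:]      -> (1024, 4096)
# The optional bias is split at the same offsets along its only dimension.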
def forward(self, input):
# Same forward functionality as QKVParallelLinear, but the q, k, and v
# projections are computed separately.
q_bias = self.q_bias if not self.skip_bias_add else None
k_bias = self.k_bias if not self.skip_bias_add else None
v_bias = self.v_bias if not self.skip_bias_add else None
q_proj = F.linear(input, self.q_weight, q_bias)
k_proj = F.linear(input, self.k_weight, k_bias)
v_proj = F.linear(input, self.v_weight, v_bias)
# The q/k/v projections are split again outside of QKVParallelLinear.
# Because XlaQKVParallelLinear replaces QKVParallelLinear, we need to
# concatenate the q, k, and v projections to match QKVParallelLinear's
# output shape, even though this looks redundant.
# The concat and the subsequent split are no-ops and should be
# optimized away by the compiler.
qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1)
# Mirror QKVParallelLinear: when skip_bias_add is set, return the bias
# alongside the projection instead of folding it into the matmul above.
output_bias = torch.cat([self.q_bias, self.k_bias, self.v_bias], dim=-1) \
if self.skip_bias_add and self.q_bias is not None else None
if not self.return_bias:
return qkv_proj
return qkv_proj, output_bias
def partition_column_parallel_linear(layer: torch.nn.Module,
mesh: xs.Mesh) -> torch.nn.Module:
assert isinstance(layer, ColumnParallelLinear)
xs.mark_sharding(layer.weight, mesh, ('x', None))
logger.debug("Applied column-parallel sharding to %s", layer)
return layer
def partition_row_parallel_linear(layer: torch.nn.Module,
mesh: xs.Mesh) -> torch.nn.Module:
assert isinstance(layer, RowParallelLinear)
xs.mark_sharding(layer.weight, mesh, (None, 'x'))
logger.debug("Applied row-parallel sharding to %s", layer)
return layer
def partition_qkv_parallel_linear(layer: torch.nn.Module,
mesh: xs.Mesh) -> torch.nn.Module:
assert isinstance(layer, QKVParallelLinear)
xla_layer = XlaQKVParallelLinear(layer, mesh)
logger.debug("Applied qkv parallel sharding to %s", layer)
return xla_layer
MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict([
("QKVParallelLinear", partition_qkv_parallel_linear),
("ColumnParallelLinear", partition_column_parallel_linear),
("RowParallelLinear", partition_row_parallel_linear),
])
def get_fqn(module):
# Get the fully qualified name of the module
return module.__class__.__qualname__
def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None:
"""
Recursively check a PyTorch model and apply appropriate sharding based on
the MODULE_TYPE_TO_WRAPPING_FUNC mapping.
Args:
model: torch.nn.Module to process
mesh: An XLA SPMD mesh object used for sharding
"""
def _process_module(module, name=None, parent=None):
for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items():
if get_fqn(module) == module_type:
wrapped_module = wrapping_func(module, mesh)
assert parent is not None and name is not None, (
"Top Level module is not expected to be wrapped.")
if wrapped_module is not module:
# The wrapped module and the original module are different Python
# objects, so the original module must be replaced by the
# wrapped_module in its parent.
logger.debug("replace %s with %s", module, wrapped_module)
setattr(parent, name, wrapped_module)
module = wrapped_module
break
for child_name, child_module in list(module.named_children()):
_process_module(child_module, child_name, module)
_process_module(model)
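# A minimal sketch of how this module is typically driven, mirroring
# TPUModelLoader in this commit. The helper name `shard_for_spmd` is
# hypothetical; it assumes an SPMD-capable TPU runtime and that the
# single-worker distributed env has already been initialized (as in the
# tests added by this commit). `torch` and `xs` are imported at the top of
# this file.
import numpy as np
import torch_xla.runtime as xr


def shard_for_spmd(model: torch.nn.Module) -> torch.nn.Module:
    # Build a 1-D mesh: axis 'x' spans every visible chip, 'y' is trivial.
    xr.use_spmd()
    num_devices = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ('x', 'y'))
    # Move to the XLA device first; the sharding hints are torch ops that
    # must see XLA tensors.
    model = model.to('xla')
    # Swaps QKVParallelLinear for XlaQKVParallelLinear and marks
    # column-/row-parallel weights with their partition specs, in place.
    shard_model(model, mesh)
    return model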

View File

@ -51,6 +51,7 @@ if TYPE_CHECKING:
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto"
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
VLLM_XLA_USE_SPMD: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
@ -513,6 +514,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, assert on XLA recompilation after each execution step.
"VLLM_XLA_CHECK_RECOMPILATION":
lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))),
# Enable SPMD mode for TPU backend.
"VLLM_XLA_USE_SPMD":
lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
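# For example, launching with VLLM_XLA_USE_SPMD=1 turns single-worker SPMD
# on; leaving it unset (or setting it to "0") keeps the default TPU path.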
"VLLM_FUSED_MOE_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")),

View File

@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
import time
from typing import Optional
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
from vllm.logger import init_logger
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
logger = init_logger(__name__)
class TPUModelLoader(DefaultModelLoader):
"""
A model loader for loading models onto TPU under SPMD mode.
"""
def load_model(
self,
vllm_config: VllmConfig,
model_config: ModelConfig,
mesh: Optional[xs.Mesh] = None,
) -> nn.Module:
# Initialize model and load weights on CPU. Then, during SPMD partition,
# weights are sharded and transferred to TPUs.
self.counter_before_loading_weights = time.perf_counter()
model_config = vllm_config.model_config
assert model_config.quantization is None, "Quantization not supported"
target_device = torch.device('cpu')
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config)
load_format = vllm_config.load_config.load_format
if load_format != "dummy":
weights_to_load = {
name
for name, _ in model.named_parameters()
}
all_weights = self.get_all_weights(model_config, model)
loaded_weights = model.load_weights(all_weights)
self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",
self.counter_after_loading_weights -
self.counter_before_loading_weights)
# For now, we only enable the strict check for non-quantized models
# that track which weights were loaded.
if model_config.quantization is None and \
loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
else:
logger.info("Using dummy weights; skipping checkpoint loading.")
process_weights_after_loading(model, model_config, target_device)
counter_before_partition = time.perf_counter()
model = model.eval()
model = model.to('xla')
shard_model(model, mesh)
counter_after_partition = time.perf_counter()
logger.info("Partition model took %.2f seconds",
counter_after_partition - counter_before_partition)
# Ensure the model is properly loaded.
self._check_model_is_loaded(mesh, model)
# torch.compile must run after model sharding is done, because the
# sharding hints (`xs.mark_sharding`) are torch ops.
if not model_config.is_multimodal_model:
model.model = torch.compile(model.model, backend="openxla")
else:
model.language_model.model = \
torch.compile(model.language_model.model, backend="openxla")
return model
def _check_model_is_loaded(self, mesh: Optional[xs.Mesh],
model: nn.Module) -> None:
"""
Ensure the model is properly loaded.
1. All model parameters and buffers are on the XLA device.
2. Non-SPMD-friendly layers have been replaced as expected.
"""
device = xm.xla_device()
device_type = str(device.type)
# Check parameters
for name, param in model.named_parameters():
assert param.device.type == device_type, (
f"Parameter {name} is on {param.device.type} "
f"instead of {device_type}")
# Check buffers
for name, buffer in model.named_buffers():
assert buffer.device.type == device_type, (
f"Buffer {name} is on {buffer.device.type} "
f"instead of {device_type}")
for module in model.modules():
if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'):
raise AssertionError(
"QKVParallelLinear should be replaced by "
"XlaQKVParallelLinear under SPMD mode.")
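# Usage sketch, mirroring tests/v1/tpu/test_spmd_model_weight_loading.py
# added in this commit (assumes the single-worker dist env and the SPMD mesh
# have already been set up the way that test does):
#
#   loader = TPUModelLoader(load_config=vllm_config.load_config)
#   model = loader.load_model(vllm_config, vllm_config.model_config, mesh)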

View File

@ -49,7 +49,9 @@ def _make_synced_weight_loader(original_weight_loader):
def _synced_weight_loader(param, *args, **kwargs):
original_weight_loader(param, *args, **kwargs)
torch._sync(param)
# torch._sync is not supported for, and not needed on, CPU tensors.
if param.device != torch.device("cpu"):
torch._sync(param)
return _synced_weight_loader

View File

@ -7,21 +7,22 @@ from unittest.mock import patch
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
# TPU XLA related
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
PlaceholderRange)
@ -98,6 +99,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self,
vllm_config: VllmConfig,
device: torch.device,
original_parallel_config: Optional[ParallelConfig] = None,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
@ -105,6 +107,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.original_parallel_config = original_parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
@ -118,6 +121,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.device = device
self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
# SPMD Related
self.use_spmd = envs.VLLM_XLA_USE_SPMD
if self.use_spmd:
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
self.mesh = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
self.enforce_eager = model_config.enforce_eager
self.num_xla_graphs = 0
@ -271,6 +282,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
max_num_mm_items_decoder_budget)
self.max_num_mm_items_by_modality[modality] = max_num_mm_items
if not self.use_spmd:
self.sample_from_logits_func = torch.compile(
self.sample_from_logits,
backend="openxla",
fullgraph=True,
dynamic=False)
else:
self.sample_from_logits_func = self.sample_from_logits
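# Note: under SPMD, sample_from_logits is intentionally left eager; see
# the TODO on sample_from_logits below about the torch.compile
# correctness issue in torch_xla.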
def _update_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
if not check_comp:
@ -825,9 +845,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
logits = self.structured_decode(require_struct_decoding,
grammar_bitmask_padded, logits,
arange)
selected_token_ids = self.sample_from_logits(logits,
tpu_sampling_metadata)
selected_token_ids = self.sample_from_logits_func(
logits, tpu_sampling_metadata)
# NOTE (NickLucche) Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs. We can't enforce it due
# to recompilations outside torch.compiled code, so just make sure
@ -935,18 +954,26 @@ class TPUModelRunner(LoRAModelRunnerMixin):
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
# model = get_model(vllm_config=self.vllm_config)
model_loader = get_model_loader(self.load_config)
if not hasattr(self, "model"):
logger.info("Loading model from scratch...")
model = model_loader.load_model(vllm_config=self.vllm_config,
model_config=self.model_config)
if self.use_spmd:
tpu_loader = TPUModelLoader(
load_config=self.vllm_config.load_config)
model = tpu_loader.load_model(
vllm_config=self.vllm_config,
model_config=self.vllm_config.model_config,
mesh=self.mesh)
else:
logger.info(
"Model was already initialized. Loading weights inplace..."
)
model_loader.load_weights(self.model,
model_config=self.model_config)
# model = get_model(vllm_config=self.vllm_config)
model_loader = get_model_loader(self.load_config)
if not hasattr(self, "model"):
logger.info("Loading model from scratch...")
model = model_loader.load_model(
vllm_config=self.vllm_config,
model_config=self.model_config)
else:
logger.info("Model was already initialized. "
"Loading weights inplace...")
model_loader.load_weights(self.model,
model_config=self.model_config)
if self.lora_config is not None:
model = self.load_lora_model(model, self.model_config,
self.scheduler_config,
@ -970,31 +997,25 @@ class TPUModelRunner(LoRAModelRunnerMixin):
device=self.device)
else:
input_ids = torch.zeros((num_tokens),
dtype=torch.int32,
device=self.device)
dtype=torch.int32).to(self.device)
inputs_embeds = None
actual_num_reqs = min(num_tokens, self.max_num_reqs)
position_ids = torch.zeros(num_tokens,
dtype=torch.int32,
device=self.device)
dtype=torch.int32).to(self.device)
slot_mapping = torch.zeros(num_tokens,
dtype=torch.int64,
device=self.device)
dtype=torch.int64).to(self.device)
block_tables = torch.zeros(
(self.max_num_reqs, self.block_table_cpu.shape[1]),
dtype=torch.int32,
device=self.device)
dtype=torch.int32).to(self.device)
query_lens = [1] * self.max_num_reqs
query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
dtype=torch.int32),
dim=0,
dtype=torch.int32).to(self.device)
context_lens = torch.ones((self.max_num_reqs, ),
dtype=torch.int32,
device=self.device)
dtype=torch.int32).to(self.device)
num_seqs = torch.tensor([actual_num_reqs],
dtype=torch.int32,
device=self.device)
dtype=torch.int32).to(self.device)
attn_metadata = PallasMetadata(
slot_mapping=slot_mapping,
block_tables=block_tables,
@ -1198,7 +1219,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
with self.maybe_select_dummy_loras(
self.lora_config, np.array([num_reqs],
dtype=np.int32)):
self.sample_from_logits(dummy_logits, sampling_metadata)
self.sample_from_logits_func(dummy_logits,
sampling_metadata)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
@ -1332,14 +1354,22 @@ class TPUModelRunner(LoRAModelRunnerMixin):
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, AttentionSpec):
if self.use_spmd:
num_kv_heads = kv_cache_spec.num_kv_heads
assert self.original_parallel_config is not None
tp_size = \
self.original_parallel_config.tensor_parallel_size
# TODO: Handle kv cache duplication under SPMD mode.
assert num_kv_heads % tp_size == 0, (
f"num_kv_heads {num_kv_heads} must be divisible by "
f"tp_size {tp_size} under SPMD mode")
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
tpu_kv_cache = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
dtype=dtype).to(self.device)
kv_caches[layer_name] = tpu_kv_cache
else:
@ -1350,6 +1380,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
if self.use_spmd:
# Shard KV Cache
for cache in self.kv_caches:
xs.mark_sharding(cache, self.mesh, (None, 'x', None, None))
def reset_dynamo_cache(self):
if self.is_multimodal_model:
compiled_model = self.model.get_language_model().model
@ -1370,7 +1405,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
sample_hidden_states: torch.Tensor) -> torch.Tensor:
return self.model.compute_logits(sample_hidden_states, None)
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
# TODO: Under SPMD mode, sample_from_logits has a correctness issue.
# Re-enable torch.compile once the issue is fixed in torch_xla.
# @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def sample_from_logits(
self, logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:

View File

@ -45,6 +45,15 @@ class TPUWorker:
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.use_spmd = envs.VLLM_XLA_USE_SPMD
self.original_parallel_config = None
if self.use_spmd:
# Under SPMD mode, the distributed env is initialized as if there were
# only one worker/device.
self.original_parallel_config = self.parallel_config
self.parallel_config.tensor_parallel_size = 1
self.parallel_config.pipeline_parallel_size = 1
self.parallel_config.world_size = 1
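# For example, with VLLM_XLA_USE_SPMD=1 and an engine configured with
# tensor_parallel_size=8, the distributed env below is initialized with
# world_size == 1, while self.original_parallel_config still reports
# tensor_parallel_size == 8 so the model runner can check num_kv_heads
# divisibility against the requested TP degree.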
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
@ -95,10 +104,9 @@ class TPUWorker:
torch.set_default_dtype(self.model_config.dtype)
# Initialize the distributed environment.
init_tpu_worker_distributed_environment(self.parallel_config,
self.rank,
self.distributed_init_method,
self.local_rank)
self._init_tpu_worker_distributed_environment(
self.parallel_config, self.rank, self.distributed_init_method,
self.local_rank)
# Device initialization should happen after initializing
# the distributed runtime.
@ -132,7 +140,9 @@ class TPUWorker:
xr.initialize_cache(per_rank_path, readonly=False)
# Init ModelRunner here, so that we have access to self.device.
self.model_runner = TPUModelRunner(self.vllm_config, self.device)
self.model_runner = \
TPUModelRunner(self.vllm_config, self.device,
self.original_parallel_config)
if rank == 0:
# If usage stat is enabled, collect relevant info.
@ -147,9 +157,7 @@ class TPUWorker:
# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than specializing on the value `None`.
tpu_kv_cache = torch.tensor([],
dtype=dtype,
device=self.device)
tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
kv_caches[layer_name] = tpu_kv_cache
else:
raise NotImplementedError(
@ -178,9 +186,20 @@ class TPUWorker:
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
current_mem = m["bytes_used"]
if self.use_spmd:
# This is a workaround for the TPU SPMD mode. The get_memory_info
# API doesn't work with SPMD mode in PyTorch/XLA.
# TODO: use xm.get_memory_info for SPMD once it's supported in
# PyTorch/XLA.
import tpu_info
chip_type, _ = tpu_info.device.get_local_chips()
device_usage = tpu_info.metrics.get_chip_usage(chip_type)
total_memory_size = device_usage[0].total_memory
current_mem = device_usage[0].memory_usage
else:
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
current_mem = m["bytes_used"]
# Ideally we would use profiled = m["peak_bytes_used"] to
# get weights + activations. But there is memory used during
# compilation / weight loading that impacts the peak and
@ -241,28 +260,30 @@ class TPUWorker:
# worker will always be healthy as long as it's running.
return
def init_tpu_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
# NOTE(woosuk): This is just to initialize the TP group and broadcast
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
local_rank=local_rank,
distributed_init_method=distributed_init_method,
backend="gloo",
)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def _init_tpu_worker_distributed_environment(
self,
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
if self.use_spmd:
xr.use_spmd()
# NOTE(woosuk): This is just to initialize the TP group and broadcast
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
local_rank=local_rank,
distributed_init_method=distributed_init_method,
backend="gloo",
)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
try: