diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
index 6102431456210..3212b660ec356 100755
--- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -155,6 +155,10 @@ run_and_track_test 12 "test_moe_pallas.py" \
   "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
 run_and_track_test 13 "test_lora.py" \
   "VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py"
+run_and_track_test 14 "test_tpu_qkv_linear.py" \
+  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
+run_and_track_test 15 "test_spmd_model_weight_loading.py" \
+  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
 
 # After all tests have been attempted, exit with the overall status.
 if [ "$overall_script_exit_code" -ne 0 ]; then
diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py
index e4a75b3f93803..f3c2859d44d17 100644
--- a/examples/offline_inference/tpu.py
+++ b/examples/offline_inference/tpu.py
@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
+import argparse
+import os
+
 from vllm import LLM, SamplingParams
 
 prompts = [
@@ -18,14 +21,28 @@ sampling_params = SamplingParams(temperature=0, top_p=1.0, n=N, max_tokens=16)
 
 
 def main():
+    parser = argparse.ArgumentParser(description="TPU offline inference example")
+    parser.add_argument("--use-spmd", action="store_true", help="Enable SPMD mode")
+    args = parser.parse_args()
+
+    llm_args = {
+        "model": "Qwen/Qwen2-1.5B-Instruct",
+        "max_num_batched_tokens": 64,
+        "max_num_seqs": 4,
+        "max_model_len": 128,
+    }
+    if args.use_spmd:
+        os.environ["VLLM_XLA_USE_SPMD"] = "1"
+        # Can only hardcode the number of chips for now.
+        # Calling xr.global_runtime_device_count() before initializing the
+        # SPMD env in torch_xla will mess up the distributed env.
+        llm_args["tensor_parallel_size"] = 8
+        # Use Llama, for num_kv_heads = 8.
+        llm_args["model"] = "meta-llama/Llama-3.1-8B-Instruct"
+
     # Set `enforce_eager=True` to avoid ahead-of-time compilation.
     # In real workloads, `enforce_eager` should be `False`.
-    llm = LLM(
-        model="Qwen/Qwen2-1.5B-Instruct",
-        max_num_batched_tokens=64,
-        max_num_seqs=4,
-        max_model_len=128,
-    )
+    llm = LLM(**llm_args)
     outputs = llm.generate(prompts, sampling_params)
     print("-" * 50)
     for output, answer in zip(outputs, answers):
diff --git a/tests/v1/tpu/test_spmd_model_weight_loading.py b/tests/v1/tpu/test_spmd_model_weight_loading.py
new file mode 100644
index 0000000000000..d36edfc3fb618
--- /dev/null
+++ b/tests/v1/tpu/test_spmd_model_weight_loading.py
@@ -0,0 +1,67 @@
+# SPDX-License-Identifier: Apache-2.0
+import gc
+import tempfile
+
+import numpy as np
+import pytest
+import torch_xla.distributed.spmd as xs
+import torch_xla.runtime as xr
+
+from vllm.config import set_current_vllm_config
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.model_executor.model_loader.tpu import TPUModelLoader
+
+
+def _setup_environment(model):
+    engine_args = EngineArgs(model=model, )
+    vllm_config = engine_args.create_engine_config()
+    with set_current_vllm_config(vllm_config):
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            1,
+            0,
+            local_rank=0,
+            distributed_init_method=f"file://{temp_file}",
+            backend="gloo")
+        # Under single-worker mode, the full model is initialized first and
+        # then partitioned using GSPMD.
+        ensure_model_parallel_initialized(1, 1)
+    return vllm_config
+
+
+MESH = None
+
+
+def _get_spmd_mesh():
+    global MESH
+    if MESH is None:
+        xr.use_spmd()
+        num_devices = xr.global_runtime_device_count()
+        mesh_shape = (num_devices, 1)
+        device_ids = np.array(range(num_devices))
+        MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
+    return MESH
+
+
+@pytest.mark.parametrize("model", [
+    "Qwen/Qwen2-1.5B-Instruct",
+    "meta-llama/Llama-3.1-8B-Instruct",
+    "meta-llama/Llama-3.1-70B-Instruct",
+])
+def test_tpu_model_loader(model):
+    # Skip the 70B test if there are fewer than 8 chips.
+    # TODO: Query this via the torch_xla API; the query API is not working
+    # with SPMD yet, even though this test runs under SPMD mode.
+    if '70B' in model and xr.global_runtime_device_count() < 8:
+        pytest.skip(
+            "Skipping 70B model if the TPU VM has fewer than 8 chips to \
+            avoid OOM.")
+
+    vllm_config = _setup_environment(model)
+    loader = TPUModelLoader(load_config=vllm_config.load_config)
+    mesh = _get_spmd_mesh()
+    model = loader.load_model(vllm_config, vllm_config.model_config, mesh)
+    del model
+    gc.collect()
diff --git a/tests/v1/tpu/test_tpu_qkv_linear.py b/tests/v1/tpu/test_tpu_qkv_linear.py
new file mode 100644
index 0000000000000..b98570f01a7f2
--- /dev/null
+++ b/tests/v1/tpu/test_tpu_qkv_linear.py
@@ -0,0 +1,89 @@
+# SPDX-License-Identifier: Apache-2.0
+import tempfile
+
+import numpy as np
+import pytest
+import torch
+import torch_xla.distributed.spmd as xs
+import torch_xla.runtime as xr
+
+from vllm.config import set_current_vllm_config
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear
+from vllm.engine.arg_utils import EngineArgs
+from vllm.model_executor.layers.linear import QKVParallelLinear
+
+
+@pytest.fixture(autouse=True)
+def setup_environment():
+    # This is a fake config, used only to init the dist env;
+    # QKVParallelLinear needs the dist env to be initialized.
+    engine_args = EngineArgs(
+        model="Qwen/Qwen2-1.5B-Instruct",
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+
+    vllm_config = engine_args.create_engine_config()
+
+    with set_current_vllm_config(vllm_config):
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            1,
+            0,
+            local_rank=0,
+            distributed_init_method=f"file://{temp_file}",
+            backend="gloo")
+        ensure_model_parallel_initialized(1, 1)
+        yield
+
+
+MESH = None
+
+
+def _get_spmd_mesh():
+    global MESH
+    if MESH is None:
+        xr.use_spmd()
+        num_devices = xr.global_runtime_device_count()
+        mesh_shape = (num_devices, 1)
+        device_ids = np.array(range(num_devices))
+        MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
+    return MESH
+
+
+@pytest.mark.parametrize("bias", [False, True])
+# `xr.use_spmd()` will set a global state, and this state is not reversible.
+# Therefore, non-SPMD tests should be run before SPMD tests.
+@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()])
+@pytest.mark.parametrize("device", ['cpu', 'xla'])
+@torch.no_grad()
+def test_xla_qkv_linear(bias, mesh, device):
+    torch.manual_seed(123)
+
+    qkv_linear = QKVParallelLinear(
+        hidden_size=4096,
+        head_size=128,
+        total_num_heads=32,
+        total_num_kv_heads=8,
+        bias=bias,
+        params_dtype=torch.bfloat16,
+        return_bias=False,
+    )
+
+    qkv_linear.weight.data = torch.rand_like(qkv_linear.weight.data) / 10
+    if bias:
+        qkv_linear.bias.data = torch.rand_like(qkv_linear.bias.data)
+
+    xla_qkv_linear = XlaQKVParallelLinear(qkv_linear, mesh=mesh)
+
+    qkv_linear = qkv_linear.to(device)
+    xla_qkv_linear = xla_qkv_linear.to(device)
+    input_tensor = torch.rand(10, 4096, dtype=torch.bfloat16) / 10
+    input_tensor = input_tensor.to(device)
+
+    output = qkv_linear(input_tensor)
+    xla_output = xla_qkv_linear(input_tensor)
+    assert torch.allclose(output.cpu(), xla_output.cpu())
diff --git a/vllm/config.py b/vllm/config.py
index d0891d670b76d..1bd53e35b0532 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1901,6 +1901,8 @@ class ParallelConfig:
             if current_platform.is_neuron():
                 # neuron uses single process to control multiple devices
                 backend = "uni"
+            elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
+                backend = "uni"
             elif (current_platform.is_cuda()
                   and cuda_device_count_stateless() < self.world_size):
                 if not ray_found:
diff --git a/vllm/distributed/tpu_distributed_utils.py b/vllm/distributed/tpu_distributed_utils.py
new file mode 100644
index 0000000000000..36ab2eb3a62f6
--- /dev/null
+++ b/vllm/distributed/tpu_distributed_utils.py
@@ -0,0 +1,177 @@
+# SPDX-License-Identifier: Apache-2.0
+from collections import OrderedDict
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch_xla.distributed.spmd as xs
+from torch.nn.parameter import Parameter
+
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
+
+logger = init_logger(__name__)
+
+
+class XlaQKVParallelLinear(nn.Module):
+
+    def __init__(self,
+                 qkv_linear: nn.Module,
+                 mesh: Optional["xs.Mesh"] = None):
+        super().__init__()
+        assert isinstance(qkv_linear, QKVParallelLinear)
+        self.skip_bias_add = qkv_linear.skip_bias_add
+        self.return_bias = qkv_linear.return_bias
+        assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD."
+
+        self.q_weight: Parameter
+        self.k_weight: Parameter
+        self.v_weight: Parameter
+        self.q_bias: Optional[Parameter]
+        self.k_bias: Optional[Parameter]
+        self.v_bias: Optional[Parameter]
+        self._load_weights_from_qkv_linear(qkv_linear)
+        if mesh is not None:
+            self._shard_weight(mesh)
+
+    def _shard_weight(self, mesh: "xs.Mesh"):
+        self.q_weight = Parameter(self.q_weight.to('xla'), requires_grad=False)
+        self.k_weight = Parameter(self.k_weight.to('xla'), requires_grad=False)
+        self.v_weight = Parameter(self.v_weight.to('xla'), requires_grad=False)
+        xs.mark_sharding(self.q_weight, mesh, ('x', None))
+        xs.mark_sharding(self.k_weight, mesh, ('x', None))
+        xs.mark_sharding(self.v_weight, mesh, ('x', None))
+        if self.q_bias is not None:
+            assert self.k_bias is not None and self.v_bias is not None, \
+                "QKVParallelLinear should have q, k, and v biases together."
+            self.q_bias = Parameter(self.q_bias.to('xla'), requires_grad=False)
+            xs.mark_sharding(self.q_bias, mesh, ('x', ))
+            self.k_bias = Parameter(self.k_bias.to('xla'), requires_grad=False)
+            xs.mark_sharding(self.k_bias, mesh, ('x', ))
+            self.v_bias = Parameter(self.v_bias.to('xla'), requires_grad=False)
+            xs.mark_sharding(self.v_bias, mesh, ('x', ))
+
+    def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module):
+        q_proj_size, k_proj_size, _ = qkv_linear.output_sizes
+        # The weight of the qkv linear is a concatenation of the q, k, and v
+        # weights along the output dimension.
+        qkv_weight = qkv_linear.weight.data.cpu()
+        q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False)
+        k_weight = Parameter(qkv_weight[q_proj_size:q_proj_size + k_proj_size],
+                             requires_grad=False)
+        v_weight = Parameter(qkv_weight[q_proj_size + k_proj_size:],
+                             requires_grad=False)
+        self.register_parameter("q_weight", q_weight)
+        self.register_parameter("k_weight", k_weight)
+        self.register_parameter("v_weight", v_weight)
+
+        if qkv_linear.bias is not None:
+            q_bias = Parameter(qkv_linear.bias[:q_proj_size],
+                               requires_grad=False)
+            k_bias = Parameter(qkv_linear.bias[q_proj_size:q_proj_size +
+                                               k_proj_size],
+                               requires_grad=False)
+            v_bias = Parameter(qkv_linear.bias[q_proj_size + k_proj_size:],
+                               requires_grad=False)
+            self.register_parameter("q_bias", q_bias)
+            self.register_parameter("k_bias", k_bias)
+            self.register_parameter("v_bias", v_bias)
+        else:
+            self.register_parameter("q_bias", None)
+            self.register_parameter("k_bias", None)
+            self.register_parameter("v_bias", None)
+
+    def forward(self, input):
+        # Same forward functionality as QKVParallelLinear, but doing the q, k,
+        # and v projections separately.
+        q_bias = self.q_bias if not self.skip_bias_add else None
+        k_bias = self.k_bias if not self.skip_bias_add else None
+        v_bias = self.v_bias if not self.skip_bias_add else None
+        q_proj = F.linear(input, self.q_weight, q_bias)
+        k_proj = F.linear(input, self.k_weight, k_bias)
+        v_proj = F.linear(input, self.v_weight, v_bias)
+        # The q/k/v projections will be split outside of the QKVParallelLinear.
+        # Because we are replacing QKVParallelLinear with XlaQKVParallelLinear,
+        # we need to concatenate the q, k, and v projections to match the
+        # output shape of the QKVParallelLinear implementation, even though it
+        # seems redundant.
+        # The concat and the following split will be no-ops and should be
+        # optimized away by the compiler.
+        qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1)
+        output_bias = torch.cat([q_bias, k_bias, v_bias], dim=-1) if \
+            self.skip_bias_add else None
+        if not self.return_bias:
+            return qkv_proj
+        return qkv_proj, output_bias
+
+
+def partition_column_parallel_linear(layer: torch.nn.Module,
+                                     mesh: xs.Mesh) -> torch.nn.Module:
+    assert isinstance(layer, ColumnParallelLinear)
+    xs.mark_sharding(layer.weight, mesh, ('x', None))
+    logger.debug("Applied column-parallel sharding to %s", layer)
+    return layer
+
+
+def partition_row_parallel_linear(layer: torch.nn.Module,
+                                  mesh: xs.Mesh) -> torch.nn.Module:
+    assert isinstance(layer, RowParallelLinear)
+    xs.mark_sharding(layer.weight, mesh, (None, 'x'))
+    logger.debug("Applied row-parallel sharding to %s", layer)
+    return layer
+
+
+def partition_qkv_parallel_linear(layer: torch.nn.Module,
+                                  mesh: xs.Mesh) -> torch.nn.Module:
+    assert isinstance(layer, QKVParallelLinear)
+    xla_layer = XlaQKVParallelLinear(layer, mesh)
+    logger.debug("Applied qkv parallel sharding to %s", layer)
+    return xla_layer
+
+
+MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict([
+    ("QKVParallelLinear", partition_qkv_parallel_linear),
+    ("ColumnParallelLinear", partition_column_parallel_linear),
+    ("RowParallelLinear", partition_row_parallel_linear),
+])
+
+
+def get_fqn(module):
+    # Get the fully qualified name of the module.
+    return module.__class__.__qualname__
+
+
+def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None:
+    """
+    Recursively check a PyTorch model and apply appropriate sharding based on
+    the MODULE_TYPE_TO_WRAPPING_FUNC mapping.
+
+    Args:
+        model: torch.nn.Module to process
+        mesh: An XLA SPMD mesh object used for sharding
+    """
+
+    def _process_module(module, name=None, parent=None):
+        for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items():
+            if get_fqn(module) == module_type:
+                wrapped_module = wrapping_func(module, mesh)
+
+                assert parent is not None and name is not None, (
+                    "Top-level module is not expected to be wrapped.")
+                if wrapped_module is not module:
+                    # The wrapped module and the original module are different
+                    # Python objects; the original module should be replaced
+                    # by the wrapped_module.
+                    logger.debug("replace %s with %s", module, wrapped_module)
+                    setattr(parent, name, wrapped_module)
+
+                module = wrapped_module
+                break
+
+        for child_name, child_module in list(module.named_children()):
+            _process_module(child_module, child_name, module)
+
+    _process_module(model)
diff --git a/vllm/envs.py b/vllm/envs.py
index 44baf5a189b43..3dd0d9045372f 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -51,6 +51,7 @@ if TYPE_CHECKING:
     VLLM_USE_RAY_COMPILED_DAG: bool = False
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto"
     VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
+    VLLM_XLA_USE_SPMD: bool = False
     VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
     VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
     VLLM_IMAGE_FETCH_TIMEOUT: int = 5
@@ -513,6 +514,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
 
     # If set, assert on XLA recompilation after each execution step.
     "VLLM_XLA_CHECK_RECOMPILATION":
     lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))),
+
+    # Enable SPMD mode for the TPU backend.
+ "VLLM_XLA_USE_SPMD": + lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))), "VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")), diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py new file mode 100644 index 0000000000000..6197bcdba826b --- /dev/null +++ b/vllm/model_executor/model_loader/tpu.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +import time +from typing import Optional + +import torch +import torch.nn as nn +import torch_xla.core.xla_model as xm +import torch_xla.distributed.spmd as xs + +from vllm.config import ModelConfig, VllmConfig +from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model +from vllm.logger import init_logger +from vllm.model_executor.model_loader.default_loader import DefaultModelLoader +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) + +logger = init_logger(__name__) + + +class TPUModelLoader(DefaultModelLoader): + """ + A TPU model loader for model loading under SPMD mode. + """ + + def load_model( + self, + vllm_config: VllmConfig, + model_config: ModelConfig, + mesh: Optional[xs.Mesh] = None, + ) -> nn.Module: + # Initialize model and load weights on CPU. Then, during SPMD partition, + # weights are sharded and transferred to TPUs. + self.counter_before_loading_weights = time.perf_counter() + model_config = vllm_config.model_config + assert model_config.quantization is None, "Quantization not supported" + target_device = torch.device('cpu') + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config) + + load_format = vllm_config.load_config.load_format + if load_format != "dummy": + weights_to_load = { + name + for name, _ in model.named_parameters() + } + all_weights = self.get_all_weights(model_config, model) + loaded_weights = model.load_weights(all_weights) + self.counter_after_loading_weights = time.perf_counter() + logger.info( + "Loading weights took %.2f seconds", + self.counter_after_loading_weights - + self.counter_before_loading_weights) + # We only enable strict check for non-quantized models + # that have loaded weights tracking currently. + if model_config.quantization is None and \ + loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") + else: + logger.info("Use dummy weight during weight loading.") + + process_weights_after_loading(model, model_config, target_device) + + counter_before_partition = time.perf_counter() + model = model.eval() + model = model.to('xla') + shard_model(model, mesh) + counter_after_partition = time.perf_counter() + logger.info("Partition model took %.2f seconds", + counter_after_partition - counter_before_partition) + + # Ensure the model is properly loaded. + self._check_model_is_loaded(mesh, model) + + # Need to torch compile after model sharding are done. Because the + # compiler hints ('xs.mark_sharding') are torch ops. + if not model_config.is_multimodal_model: + model.model = torch.compile(model.model, backend="openxla") + else: + model.language_model.model = \ + torch.compile(model.language_model.model, backend="openxla") + return model + + def _check_model_is_loaded(self, mesh: Optional[xs.Mesh], + model: nn.Module) -> None: + """ + Ensure the model is properly loaded. + 1. 
+        2. Non-SPMD-friendly layers are replaced as expected.
+        """
+        device = xm.xla_device()
+        device_type = str(device.type)
+
+        # Check parameters
+        for name, param in model.named_parameters():
+            assert param.device.type == device_type, f"Parameter {name} is on \
+                {param.device.type} instead of {device_type}"
+
+        # Check buffers
+        for name, buffer in model.named_buffers():
+            assert buffer.device.type == device_type, \
+                f"Buffer {name} is on {buffer.device.type} instead of \
+                {device_type}"
+
+        for module in model.modules():
+            if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'):
+                raise AssertionError("QKVParallelLinear should be replaced by \
+                    XlaQKVParallelLinear under SPMD mode.")
diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py
index 1b120c3545a56..27cea65217875 100644
--- a/vllm/model_executor/utils.py
+++ b/vllm/model_executor/utils.py
@@ -49,7 +49,9 @@ def _make_synced_weight_loader(original_weight_loader):
 
     def _synced_weight_loader(param, *args, **kwargs):
         original_weight_loader(param, *args, **kwargs)
-        torch._sync(param)
+        # torch._sync doesn't support, and is not needed for, CPU tensors.
+        if param.device != torch.device("cpu"):
+            torch._sync(param)
 
     return _synced_weight_loader
 
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 5de92351e24ba..c5171b9736b36 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -7,21 +7,22 @@ from unittest.mock import patch
 
 import numpy as np
 import torch
-import torch.distributed
 import torch.nn as nn
 # TPU XLA related
 import torch_xla.core.xla_model as xm
+import torch_xla.distributed.spmd as xs
 import torch_xla.runtime as xr
 
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA
 from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.tpu import TPUModelLoader
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
                                     PlaceholderRange)
@@ -98,6 +99,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self,
         vllm_config: VllmConfig,
         device: torch.device,
+        original_parallel_config: Optional[ParallelConfig] = None,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -105,6 +107,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.lora_config = vllm_config.lora_config
         self.load_config = vllm_config.load_config
         self.parallel_config = vllm_config.parallel_config
+        self.original_parallel_config = original_parallel_config
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
@@ -118,6 +121,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.device = device
         self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
 
+        # SPMD Related
+        self.use_spmd = envs.VLLM_XLA_USE_SPMD
+        if self.use_spmd:
+            num_devices = xr.global_runtime_device_count()
+            mesh_shape = (num_devices, 1)
+            device_ids = np.array(range(num_devices))
+            self.mesh = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
+
         self.enforce_eager = model_config.enforce_eager
 
         self.num_xla_graphs = 0
@@ -271,6 +282,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                     max_num_mm_items_decoder_budget)
 
             self.max_num_mm_items_by_modality[modality] = max_num_mm_items
+        if not self.use_spmd:
+            self.sample_from_logits_func = torch.compile(
+                self.sample_from_logits,
+                backend="openxla",
+                fullgraph=True,
+                dynamic=False)
+        else:
+            self.sample_from_logits_func = self.sample_from_logits
+
     def _update_num_xla_graphs(self, case_str):
         check_comp = self.check_recompilation and not self.enforce_eager
         if not check_comp:
@@ -825,9 +845,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             logits = self.structured_decode(require_struct_decoding,
                                             grammar_bitmask_padded, logits,
                                             arange)
-        selected_token_ids = self.sample_from_logits(logits,
-                                                     tpu_sampling_metadata)
-
+        selected_token_ids = self.sample_from_logits_func(
+            logits, tpu_sampling_metadata)
         # NOTE (NickLucche) Use the original logits (before any penalties or
         # temperature scaling) for the top-k logprobs. We can't enforce it due
         # to recompilations outside torch.compiled code, so just make sure
@@ -935,18 +954,26 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                         "vllm.model_executor.layers.vocab_parallel_embedding."
                         "get_tensor_model_parallel_rank",
                         return_value=xm_tp_rank):
-            # model = get_model(vllm_config=self.vllm_config)
-            model_loader = get_model_loader(self.load_config)
-            if not hasattr(self, "model"):
-                logger.info("Loading model from scratch...")
-                model = model_loader.load_model(vllm_config=self.vllm_config,
-                                                model_config=self.model_config)
+            if self.use_spmd:
+                tpu_loader = TPUModelLoader(
+                    load_config=self.vllm_config.load_config)
+                model = tpu_loader.load_model(
+                    vllm_config=self.vllm_config,
+                    model_config=self.vllm_config.model_config,
+                    mesh=self.mesh)
             else:
-                logger.info(
-                    "Model was already initialized. Loading weights inplace..."
-                )
-                model_loader.load_weights(self.model,
-                                          model_config=self.model_config)
+                # model = get_model(vllm_config=self.vllm_config)
+                model_loader = get_model_loader(self.load_config)
+                if not hasattr(self, "model"):
+                    logger.info("Loading model from scratch...")
+                    model = model_loader.load_model(
+                        vllm_config=self.vllm_config,
+                        model_config=self.model_config)
+                else:
+                    logger.info("Model was already initialized. "
+                                "Loading weights inplace...")
+                    model_loader.load_weights(self.model,
+                                              model_config=self.model_config)
 
         if self.lora_config is not None:
             model = self.load_lora_model(model, self.model_config,
                                          self.scheduler_config,
@@ -970,31 +997,25 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                                      device=self.device)
         else:
             input_ids = torch.zeros((num_tokens),
-                                    dtype=torch.int32,
-                                    device=self.device)
+                                    dtype=torch.int32).to(self.device)
             inputs_embeds = None
         actual_num_reqs = min(num_tokens, self.max_num_reqs)
         position_ids = torch.zeros(num_tokens,
-                                   dtype=torch.int32,
-                                   device=self.device)
+                                   dtype=torch.int32).to(self.device)
         slot_mapping = torch.zeros(num_tokens,
-                                   dtype=torch.int64,
-                                   device=self.device)
+                                   dtype=torch.int64).to(self.device)
         block_tables = torch.zeros(
             (self.max_num_reqs, self.block_table_cpu.shape[1]),
-            dtype=torch.int32,
-            device=self.device)
+            dtype=torch.int32).to(self.device)
         query_lens = [1] * self.max_num_reqs
         query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                     dtype=torch.int32),
                                        dim=0,
                                        dtype=torch.int32).to(self.device)
         context_lens = torch.ones((self.max_num_reqs, ),
-                                  dtype=torch.int32,
-                                  device=self.device)
+                                  dtype=torch.int32).to(self.device)
         num_seqs = torch.tensor([actual_num_reqs],
-                                dtype=torch.int32,
-                                device=self.device)
+                                dtype=torch.int32).to(self.device)
         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
             block_tables=block_tables,
@@ -1198,7 +1219,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             with self.maybe_select_dummy_loras(
                     self.lora_config, np.array([num_reqs], dtype=np.int32)):
-                self.sample_from_logits(dummy_logits, sampling_metadata)
+                self.sample_from_logits_func(dummy_logits,
+                                             sampling_metadata)
             logger.info(" -- num_seqs: %d", num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
@@ -1332,14 +1354,22 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                 assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
                 num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
                 if isinstance(kv_cache_spec, AttentionSpec):
+                    if self.use_spmd:
+                        num_kv_heads = kv_cache_spec.num_kv_heads
+                        assert self.original_parallel_config is not None
+                        tp_size = \
+                            self.original_parallel_config.tensor_parallel_size
+                        # TODO: Handle kv cache duplication under SPMD mode.
+                        assert num_kv_heads % tp_size == 0, (
+                            f"num_kv_heads {num_kv_heads} must be divisible by "
+                            f"tp_size {tp_size} under SPMD mode")
                     kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                         num_blocks, kv_cache_spec.block_size,
                         kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                     dtype = kv_cache_spec.dtype
                     tpu_kv_cache = torch.zeros(kv_cache_shape,
-                                               dtype=dtype,
-                                               device=self.device)
+                                               dtype=dtype).to(self.device)
 
                     kv_caches[layer_name] = tpu_kv_cache
                 else:
@@ -1350,6 +1380,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             self.vllm_config.compilation_config.static_forward_context,
             self.kv_caches)
 
+        if self.use_spmd:
+            # Shard KV Cache
+            for cache in self.kv_caches:
+                xs.mark_sharding(cache, self.mesh, (None, 'x', None, None))
+
     def reset_dynamo_cache(self):
         if self.is_multimodal_model:
             compiled_model = self.model.get_language_model().model
@@ -1370,7 +1405,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                          sample_hidden_states: torch.Tensor) -> torch.Tensor:
         return self.model.compute_logits(sample_hidden_states, None)
 
-    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    # TODO: Under SPMD mode, sample_from_logits has a correctness issue.
+    # Re-enable torch.compile once the issue is fixed in torch_xla.
+    # @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def sample_from_logits(
             self, logits: torch.Tensor,
             sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 0707e17afe7a7..bf0a5777cb3ff 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -45,6 +45,15 @@ class TPUWorker:
         self.lora_config = vllm_config.lora_config
         self.load_config = vllm_config.load_config
         self.parallel_config = vllm_config.parallel_config
+        self.use_spmd = envs.VLLM_XLA_USE_SPMD
+        self.original_parallel_config = None
+        if self.use_spmd:
+            # Under SPMD mode, distributed env is initialized as if there is
+            # only one worker/device.
+            self.original_parallel_config = self.parallel_config
+            self.parallel_config.tensor_parallel_size = 1
+            self.parallel_config.pipeline_parallel_size = 1
+            self.parallel_config.world_size = 1
         self.scheduler_config = vllm_config.scheduler_config
         self.device_config = vllm_config.device_config
         self.speculative_config = vllm_config.speculative_config
@@ -95,10 +104,9 @@ class TPUWorker:
             torch.set_default_dtype(self.model_config.dtype)
 
         # Initialize the distributed environment.
-        init_tpu_worker_distributed_environment(self.parallel_config,
-                                                self.rank,
-                                                self.distributed_init_method,
-                                                self.local_rank)
+        self._init_tpu_worker_distributed_environment(
+            self.parallel_config, self.rank, self.distributed_init_method,
+            self.local_rank)
 
         # Device initialization should happen after initializing
         # the distributed runtime.
@@ -132,7 +140,9 @@ class TPUWorker:
             xr.initialize_cache(per_rank_path, readonly=False)
 
         # Init ModelRunner here, so that we have access to self.device.
-        self.model_runner = TPUModelRunner(self.vllm_config, self.device)
+        self.model_runner = \
+            TPUModelRunner(self.vllm_config, self.device,
+                           self.original_parallel_config)
 
         if rank == 0:
             # If usage stat is enabled, collect relevant info.
@@ -147,9 +157,7 @@ class TPUWorker:
 
                 # Use an empty tensor instead of `None`` to force Dynamo to pass
                 # it by reference, rather by specializing on the value ``None``.
-                tpu_kv_cache = torch.tensor([],
-                                            dtype=dtype,
-                                            device=self.device)
+                tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
                 kv_caches[layer_name] = tpu_kv_cache
             else:
                 raise NotImplementedError(
@@ -178,9 +186,20 @@ class TPUWorker:
 
         # Get the maximum amount of memory used by the model weights and
         # intermediate activations.
-        m = xm.get_memory_info(self.device)
-        total_memory_size = m["bytes_limit"]
-        current_mem = m["bytes_used"]
+        if self.use_spmd:
+            # This is a workaround for the TPU SPMD mode. The get_memory_info
+            # API doesn't work with SPMD mode in PyTorch/XLA.
+            # TODO: use xm.get_memory_info for SPMD once it's supported in
+            # PyTorch/XLA.
+            import tpu_info
+            chip_type, _ = tpu_info.device.get_local_chips()
+            device_usage = tpu_info.metrics.get_chip_usage(chip_type)
+            total_memory_size = device_usage[0].total_memory
+            current_mem = device_usage[0].memory_usage
+        else:
+            m = xm.get_memory_info(self.device)
+            total_memory_size = m["bytes_limit"]
+            current_mem = m["bytes_used"]
         # Ideally we would use profiled = m["peak_bytes_used"] to
         # get weights + activations. But there is memory used during
         # compilation / weight loading that impacts the peak and
@@ -241,28 +260,30 @@ class TPUWorker:
         # worker will always be healthy as long as it's running.
         return
 
-
-def init_tpu_worker_distributed_environment(
-    parallel_config: ParallelConfig,
-    rank: int,
-    distributed_init_method: Optional[str] = None,
-    local_rank: int = -1,
-) -> None:
-    """Initialize the distributed environment."""
-
-    # NOTE(woosuk): This is just to initialize the TP group and broadcast
-    # the input objects on CPU. The all-reduce and all-gather ops on TPU
-    # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
-    # own context.
-    init_distributed_environment(
-        world_size=parallel_config.world_size,
-        rank=rank,
-        local_rank=local_rank,
-        distributed_init_method=distributed_init_method,
-        backend="gloo",
-    )
-    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+    def _init_tpu_worker_distributed_environment(
+        self,
+        parallel_config: ParallelConfig,
+        rank: int,
+        distributed_init_method: Optional[str] = None,
+        local_rank: int = -1,
+    ) -> None:
+        """Initialize the distributed environment."""
+        if self.use_spmd:
+            xr.use_spmd()
+        # NOTE(woosuk): This is just to initialize the TP group and broadcast
+        # the input objects on CPU. The all-reduce and all-gather ops on TPU
+        # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
+        # own context.
+        init_distributed_environment(
+            world_size=parallel_config.world_size,
+            rank=rank,
+            local_rank=local_rank,
+            distributed_init_method=distributed_init_method,
+            backend="gloo",
+        )
+        ensure_model_parallel_initialized(
+            parallel_config.tensor_parallel_size,
+            parallel_config.pipeline_parallel_size)
 
 try:
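Usage note, not part of the patch: a minimal sketch of how the SPMD path introduced above can be exercised, using only the flag, example script, and tests added in this diff. Paths assume a local vLLM checkout rather than the CI /workspace/vllm prefix.

    # Offline example with GSPMD enabled; --use-spmd sets VLLM_XLA_USE_SPMD=1
    # internally and switches to Llama-3.1-8B with tensor_parallel_size=8.
    python examples/offline_inference/tpu.py --use-spmd

    # New TPU tests added by this change (mirroring CI steps 14 and 15).
    python3 -m pytest -s -v tests/v1/tpu/test_tpu_qkv_linear.py
    python3 -m pytest -s -v tests/v1/tpu/test_spmd_model_weight_loading.py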