From bf1343cc1d258d5df993e022da480c62bf9244f5 Mon Sep 17 00:00:00 2001 From: Wei-Yu Lin Date: Fri, 12 Dec 2025 22:08:04 +0000 Subject: [PATCH] Remove torch_xla related code path excluding test files Signed-off-by: Wei-Yu Lin --- .../device_communicators/tpu_communicator.py | 85 - vllm/distributed/tpu_distributed_utils.py | 188 -- vllm/lora/ops/xla_ops/__init__.py | 6 - vllm/lora/ops/xla_ops/lora_ops.py | 141 -- vllm/lora/punica_wrapper/punica_tpu.py | 358 --- vllm/model_executor/layers/fused_moe/layer.py | 5 - .../layers/fused_moe/moe_pallas.py | 83 - .../fused_moe/unquantized_fused_moe_method.py | 52 +- .../layers/quantization/__init__.py | 2 - .../kernels/scaled_mm/__init__.py | 4 - .../quantization/kernels/scaled_mm/xla.py | 106 - .../layers/quantization/tpu_int8.py | 139 -- vllm/model_executor/model_loader/tpu.py | 118 - vllm/usage/usage_lib.py | 18 +- vllm/v1/attention/backends/pallas.py | 59 +- vllm/v1/worker/tpu_model_runner.py | 2191 ----------------- 16 files changed, 3 insertions(+), 3552 deletions(-) delete mode 100644 vllm/distributed/device_communicators/tpu_communicator.py delete mode 100644 vllm/distributed/tpu_distributed_utils.py delete mode 100644 vllm/lora/ops/xla_ops/__init__.py delete mode 100644 vllm/lora/ops/xla_ops/lora_ops.py delete mode 100644 vllm/lora/punica_wrapper/punica_tpu.py delete mode 100644 vllm/model_executor/layers/fused_moe/moe_pallas.py delete mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py delete mode 100644 vllm/model_executor/layers/quantization/tpu_int8.py delete mode 100644 vllm/model_executor/model_loader/tpu.py delete mode 100644 vllm/v1/worker/tpu_model_runner.py diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py deleted file mode 100644 index 9581a3dbc7b74..0000000000000 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ /dev/null @@ -1,85 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os - -import torch -from torch.distributed import ProcessGroup - -from vllm.config import get_current_vllm_config -from vllm.logger import init_logger -from vllm.platforms import current_platform -from vllm.platforms.tpu import USE_TPU_INFERENCE - -from .base_device_communicator import DeviceCommunicatorBase - -USE_RAY = parallel_config = ( - get_current_vllm_config().parallel_config.distributed_executor_backend == "ray" -) - -logger = init_logger(__name__) - - -class TpuCommunicator(DeviceCommunicatorBase): - def __init__( - self, - cpu_group: ProcessGroup, - device: torch.device | None = None, - device_group: ProcessGroup | None = None, - unique_name: str = "", - ): - super().__init__(cpu_group, device, device_group, unique_name) - - # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node - # must be used together. Therefore, the local rank and world size can - # be simply calculated as follows. - global_rank = self.global_rank - global_world_size = self.global_world_size - - if USE_RAY: - logger.info("TpuCommunicator initialized with RAY") - # Calculate how many TPU nodes are in the current deployment. This - # is the Ray placement group if it is deployed with Ray. Default - # to the number of TPU nodes in the Ray cluster. The number of TPU - # nodes is computed by the total number of TPUs divided by the - # number of TPU accelerators per node, to account for clusters - # with both CPUs and TPUs. - num_nodes = ray_utils.get_num_tpu_nodes() - num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() - if num_nodes_in_pg > 0: - num_nodes = num_nodes_in_pg - - local_world_size = global_world_size // num_nodes - local_rank = global_rank % local_world_size - else: - logger.info("TpuCommunicator initialized with MP") - # Sanity: Verify we run on a single host - num_hosts = torch_xla.tpu.num_tpu_workers() - assert num_hosts == 1 - - # Get the current number of TPUs (we have locally) - local_world_size = torch_xla.tpu.num_available_chips() - - # Get current rank - local_rank = global_rank % local_world_size - - # Ensure environment variables are set for multihost deployments. - # On GKE, this is needed for libtpu and TPU driver to know which TPU - # chip is actually visible. Otherwise the TPU driver will fail to - # initialize because the number of devices would be different from - # the number of visible worker addresses. - os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) - os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) - - pjrt.initialize_multiprocess(local_rank, local_world_size) - xr._init_world_size_ordinal() - self.groups = create_optimized_replica_groups() - - def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: - # TODO: Remove the groups specification after XLA compiler can support - # auto-reordering the ring order for all-reduce. - return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups) - - def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: - assert dim == -1, "TPUs only support dim=-1 for all-gather." - return xm.all_gather(input_, dim=dim) diff --git a/vllm/distributed/tpu_distributed_utils.py b/vllm/distributed/tpu_distributed_utils.py deleted file mode 100644 index 4ff1f0ce4410a..0000000000000 --- a/vllm/distributed/tpu_distributed_utils.py +++ /dev/null @@ -1,188 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import OrderedDict -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch_xla.distributed.spmd as xs -from torch.nn.parameter import Parameter - -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) - -logger = init_logger(__name__) - - -class XlaQKVParallelLinear(nn.Module): - def __init__(self, qkv_linear: nn.Module, mesh: Optional["xs.Mesh"] = None): - super().__init__() - assert isinstance(qkv_linear, QKVParallelLinear) - self.skip_bias_add = qkv_linear.skip_bias_add - self.return_bias = qkv_linear.return_bias - assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD." - - self.q_weight: Parameter - self.k_weight: Parameter - self.v_weight: Parameter - self.q_bias: Parameter | None - self.k_bias: Parameter | None - self.v_bias: Parameter | None - self._load_weights_from_qkv_linear(qkv_linear) - if mesh is not None: - self._shard_weight(mesh) - - def _shard_weight(self, mesh: "xs.Mesh"): - self.q_weight = Parameter(self.q_weight.to("xla"), requires_grad=False) - self.k_weight = Parameter(self.k_weight.to("xla"), requires_grad=False) - self.v_weight = Parameter(self.v_weight.to("xla"), requires_grad=False) - xs.mark_sharding(self.q_weight, mesh, ("x", None)) - xs.mark_sharding(self.k_weight, mesh, ("x", None)) - xs.mark_sharding(self.v_weight, mesh, ("x", None)) - if self.q_bias is not None: - assert self.k_bias is not None and self.v_bias is not None, ( - "QKVParallelLinear should have q, k, and v biases together." - ) - self.q_bias = Parameter(self.q_bias.to("xla"), requires_grad=False) - xs.mark_sharding(self.q_bias, mesh, ("x",)) - self.k_bias = Parameter(self.k_bias.to("xla"), requires_grad=False) - xs.mark_sharding(self.k_bias, mesh, ("x",)) - self.v_bias = Parameter(self.v_bias.to("xla"), requires_grad=False) - xs.mark_sharding(self.v_bias, mesh, ("x",)) - - def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module): - q_proj_size, k_proj_size, _ = qkv_linear.output_sizes - # The weight of qkv linear is a concatenation of q, k, and v weights - # along the output dimension. - qkv_weight = qkv_linear.weight.data.cpu() - q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False) - k_weight = Parameter( - qkv_weight[q_proj_size : q_proj_size + k_proj_size], requires_grad=False - ) - v_weight = Parameter( - qkv_weight[q_proj_size + k_proj_size :], requires_grad=False - ) - self.register_parameter("q_weight", q_weight) - self.register_parameter("k_weight", k_weight) - self.register_parameter("v_weight", v_weight) - - if qkv_linear.bias is not None: - q_bias = Parameter(qkv_linear.bias[:q_proj_size], requires_grad=False) - k_bias = Parameter( - qkv_linear.bias[q_proj_size : q_proj_size + k_proj_size], - requires_grad=False, - ) - v_bias = Parameter( - qkv_linear.bias[q_proj_size + k_proj_size :], requires_grad=False - ) - self.register_parameter("q_bias", q_bias) - self.register_parameter("k_bias", k_bias) - self.register_parameter("v_bias", v_bias) - else: - self.register_parameter("q_bias", None) - self.register_parameter("k_bias", None) - self.register_parameter("v_bias", None) - - def forward(self, input): - # Same forward functionality as QKVParallelLinear, but doing qkv porj - # separately. - q_bias = self.q_bias if not self.skip_bias_add else None - k_bias = self.k_bias if not self.skip_bias_add else None - v_bias = self.v_bias if not self.skip_bias_add else None - q_proj = F.linear(input, self.q_weight, q_bias) - k_proj = F.linear(input, self.k_weight, k_bias) - v_proj = F.linear(input, self.v_weight, v_bias) - # The q/k/v projections will be split outside of the QKVParallelLinear. - # Because we are replacing XlaQKVParallelLinear with the - # QKVParallelLinear, we need to concatenate q, k, and v projections to - # match the output shape of the QKVParallelLinear implementation even if - # it seems to be redundant. - # The concat and the following split will be noop, and should be - # optimized away by the compiler. - qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1) - output_bias = ( - torch.cat([q_bias, k_bias, v_bias], dim=-1) if self.skip_bias_add else None - ) - if not self.return_bias: - return qkv_proj - return qkv_proj, output_bias - - -def partition_column_parallel_linear( - layer: torch.nn.Module, mesh: xs.Mesh -) -> torch.nn.Module: - assert isinstance(layer, ColumnParallelLinear) - xs.mark_sharding(layer.weight, mesh, ("x", None)) - logger.debug("Applied column-parallel sharding to %s", layer) - return layer - - -def partition_row_parallel_linear( - layer: torch.nn.Module, mesh: xs.Mesh -) -> torch.nn.Module: - assert isinstance(layer, RowParallelLinear) - xs.mark_sharding(layer.weight, mesh, (None, "x")) - logger.debug("Applied row-parallel sharding to %s", layer) - return layer - - -def partition_qkv_parallel_linear( - layer: torch.nn.Module, mesh: xs.Mesh -) -> torch.nn.Module: - assert isinstance(layer, QKVParallelLinear) - xla_layer = XlaQKVParallelLinear(layer, mesh) - logger.debug("Applied qkv parallel sharding to %s", layer) - return xla_layer - - -MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict( - [ - ("QKVParallelLinear", partition_qkv_parallel_linear), - ("ColumnParallelLinear", partition_column_parallel_linear), - ("RowParallelLinear", partition_row_parallel_linear), - ] -) - - -def get_fqn(module): - # Get the fully qualified name of the module - return module.__class__.__qualname__ - - -def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None: - """ - Recursively check a PyTorch model and apply appropriate sharding based on - the MODULE_TYPE_TO_WRAPPING_FUNC mapping. - - Args: - model: torch.nn.Module to process - mesh: An XLA SPMD mesh object used for sharding - """ - - def _process_module(module, name=None, parent=None): - for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items(): - if get_fqn(module) == module_type: - wrapped_module = wrapping_func(module, mesh) - - assert parent is not None and name is not None, ( - "Top Level module is not expected to be wrapped." - ) - if wrapped_module is not module: - # Wrapped module and module are different py object. - # The original module should be replaced by the - # wrapped_module. - logger.debug("replace %s with %s", module, wrapped_module) - setattr(parent, name, wrapped_module) - - module = wrapped_module - break - - for child_name, child_module in list(module.named_children()): - _process_module(child_module, child_name, module) - - _process_module(model) diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py deleted file mode 100644 index b5570ceca68ca..0000000000000 --- a/vllm/lora/ops/xla_ops/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink - -__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"] diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py deleted file mode 100644 index 4924890b388cb..0000000000000 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ /dev/null @@ -1,141 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import jax -import jax.numpy as jnp -import torch -import torch.nn.functional as F -import torch_xla.core.xla_builder as xb -from torch.library import impl -from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard - - -@jax.jit -def bgmv_jax(inputs, loras, idxs): - return jnp.einsum( - "td,tX,Xld->tl", - inputs, - jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype), - loras, - ) - - -XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") - - -@impl(XLA_LIB, "bgmv", "XLA") -def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): - if len(loras.shape) == 4: - loras = loras.squeeze(axis=1) - - jax_import_guard() - return xb.call_jax(bgmv_jax, (inputs, loras, idxs)) - - -@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") -def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): - T, _ = inputs.shape - if len(loras.shape) == 4: - loras = loras.squeeze(axis=1) - _, L, _ = loras.shape - - return torch.empty((T, L), device=inputs.device) - - -def bgmv_expand( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True, -): - """ - Args: - inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. - - lora_b_weights (torch.Tensor): LoRA weights of shape - [num_loras, lora_rank, hidden_size]. - - output_tensor (torch.Tensor): output tensor of shape - [num_tokens, hidden_size * num_slices]. - - lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] - indicating which LoRA matrix to use for each token. - add_inputs (bool): Whether or not to add the input tensor to the output - tensor. - """ - - outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) - - limit = output_tensor.shape[0] - if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: - limit = 1 - - if output_tensor.shape[1] > outputs.shape[1]: - outputs = F.pad(outputs, (0, output_tensor.shape[1] - outputs.shape[1], 0, 0)) - - if add_inputs: - return output_tensor + outputs[:limit, : output_tensor.shape[1]] - else: - return outputs[:limit, : output_tensor.shape[1]] - - -def bgmv_shrink( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0, -): - """ - Args: - inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. - lora_b_weights (torch.Tensor): LoRA weights of shape - [num_loras, lora_rank, hidden_size]. - lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] - indicating which LoRA matrix to use for each token. - scaling (float, optional): Scalar multiplier applied to the output. - """ - - return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) - - -def bgmv_expand_slice( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True, -): - """ - Args: - inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. - - lora_b_weights (torch.Tensor): LoRA weights of shape - [num_loras, lora_rank, hidden_size]. - - output_tensor (torch.Tensor): output tensor of shape - [num_tokens, hidden_size * num_slices]. - - lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] - indicating which LoRA matrix to use for each token. - add_inputs (bool): Whether or not to add the input tensor to the output - tensor. - """ - outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) - - outputs = F.pad( - outputs, - ( - slice_offset, - output_tensor.shape[1] - (slice_offset + slice_size), - 0, - 0, - ), - ) - - if add_inputs: - return output_tensor + outputs - else: - return outputs diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py deleted file mode 100644 index 0888772db54e7..0000000000000 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ /dev/null @@ -1,358 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math -from typing import TYPE_CHECKING - -import torch -import torch.nn.functional as F -import torch_xla - -from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink -from vllm.lora.punica_wrapper.utils import convert_mapping - -if TYPE_CHECKING: - # avoid circuit import - from vllm.lora.layers import LoRAMapping - -from .punica_base import PunicaWrapperBase - - -class PunicaWrapperTPU(PunicaWrapperBase): - """ - PunicaWrapperTPU is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for - Multi-LoRA, and to provide the interface for the pytorch punica ops. - """ - - def __init__( - self, - max_num_batched_tokens: int, - max_batches: int, - device: torch.device | str, - **kwargs, - ): - PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - - # PunicaWrapperBase defines some tensors with dtype=torch.int64, which - # isn't supported by the TPU. So convert those tensors to int32. - # Not all of them are used by the TPU so only convert the useful ones. - self._token_lora_indices = self._token_lora_indices.to(dtype=torch.int32) - self._sampler_indices = self._sampler_indices.to(dtype=torch.int32) - self._sampler_indices_padded = self._sampler_indices_padded.to( - dtype=torch.int32 - ) - - torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True) - torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True) - torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, True) - torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True) - torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, True) - - torch._dynamo.mark_dynamic(self._token_lora_indices, 0) - torch._dynamo.mark_dynamic(self._embeddings_indices, 1) - torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) - - def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor: - return torch.narrow(self._token_lora_indices, 0, 0, x.size(0)) - - @property - def embeddings_indices(self) -> torch.Tensor: - """ - This property provides access to the indices used for lora embeddings, - specifically for VocabParallelEmbeddingWithLoRA. - """ - return self._embeddings_indices[:] - - @property - def sampler_indices_padded(self) -> torch.Tensor: - """ - This property provides access to padded sampler indices. - """ - return self._sampler_indices_padded[:] - - def shrink( - self, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): - return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale) - - def expand( - self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, add_inputs: bool - ): - return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), add_inputs) - - def expand_slice( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: int, - y_slice_size: int, - add_inputs: bool, - ) -> torch.Tensor: - return bgmv_expand_slice( - x, - w_t_all, - y, - self._get_token_lora_indices(x), - y_offset, - y_slice_size, - add_inputs, - ) - - def add_shrink( - self, - y: tuple[torch.Tensor, ...] | torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - scale: float, - **kwargs, - ) -> torch.Tensor | None: - """ - Performs GEMM for multiple slices of lora_a. - - Semantics: - for i in range(len(lora_a_stacked)): - y[i] += (x @ lora_a_stacked[i]) * scale - - Args: - y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors - x (torch.Tensor): Input tensor - lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights - scale (float): Scaling factor for the operation - """ - - torch.ops.xla.dynamo_set_buffer_donor_(y, True) - x = x.view(-1, x.shape[-1]) - - for slice_idx in range(len(lora_a_stacked)): - lora_s = lora_a_stacked[slice_idx] - y_s = self.shrink(x, lora_s, scale) - y[slice_idx, :, :] = y_s # type: ignore[index] - return y - - def add_expand( - self, - y: torch.Tensor, - x: tuple[torch.Tensor, ...] | torch.Tensor, - lora_b_stacked: tuple[torch.Tensor, ...], - output_slices: tuple[int, ...], - offset_start: int = 0, - add_inputs=True, - **kwargs, - ) -> torch.Tensor: - """ - Performs GEMM for multiple slices of lora_b. - - Semantics: - for i in range(len(lora_b_stacked)): - slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] - offset += slice - - Args: - y (torch.Tensor): Output tensor. - x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors - lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - output_slices (tuple[int, ...]): Every slice's size - add_inputs (bool): Defaults to True. - """ - y_org = y - y = y.view(-1, y.shape[-1]) - offset_left = 0 - - for slice_idx in range(len(lora_b_stacked)): - y = self.expand_slice( - y, - x[slice_idx], - lora_b_stacked[slice_idx], - offset_left, - output_slices[slice_idx], - add_inputs=add_inputs, - ) - offset_left += output_slices[slice_idx] - return y.view_as(y_org) - - def add_lora_embedding( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_inputs: bool = True, - **kwargs, - ) -> torch.Tensor: - """ - Applies lora specifically for VocabParallelEmbeddingWithLoRA. - - Semantics: - y += x @ lora_b_stacked - - Args: - y (torch.Tensor): Output tensor. - x (torch.Tensor): Input tensor. - lora_b_stacked (torch.Tensor): lora_b's weights. - add_inputs (bool): Default to True. - """ - - # Embedding layer only needs the expand op - return self.expand(y, x, lora_b_stacked, add_inputs) - - def add_lora_linear( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: tuple[torch.Tensor, ...] | None = None, - **kwargs, - ) -> torch.Tensor: - """ - Applicable to linear-related lora. - - Semantics: - for i in range(len(lora_a_stacked)): - y[i] += ( - x[i].unsqueeze(0) - @ lora_a_stacked[indices[i], layer_idx, :, :] - @ lora_b_stacked[indices[i], layer_idx, :, :] - * scale - ).squeeze(0) - - Args: - y (torch.Tensor): Output tensor. Will not be changed in-place. - x (torch.Tensor): Input tensor (T, E) - lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. - lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - scale (float): Scaling factor. - output_slices (tuple[int, ...]): Every slice's size. - buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. - """ - - assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - - if buffer is None: - r = lora_b_stacked[0].size(-1) - T = x.size(0) - buffer = torch.zeros( - (len(output_slices), T, r), - dtype=x.dtype, - device=x.device, - ) - buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) - return self.add_expand( - y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs - ) - - def add_lora_logits( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: torch.Tensor | None = None, - **kwargs, - ) -> torch.Tensor: - """ - Applies lora specifically for LogitsProcessorWithLoRA. - - Semantics: - buffer = (x @ lora_a_stacked) * scale - y += buffer @ lora_b_stacked - - Args: - y (torch.Tensor): Output tensor. - x (torch.Tensor): Input tensor. - lora_a_stacked (torch.Tensor): lora_a's weights. - lora_b_stacked (torch.Tensor):lora_b's weights. - scale (float): Scaling factor. - buffer (Optional[torch.Tensor]):Default to None. - """ - y_org = y - y = y.view(-1, y.shape[-1]) - x = x.view(-1, x.shape[-1]) - - sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0)) - buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale) - y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True) - return y.view_as(y_org) - - # This performs the same tensor ops as the base method, except it does them - # on the CPU then transfers the results to the TPU - def _update_base_metadata( - self, - mapping: "LoRAMapping", - lora_index_to_id: list[int | None], - max_loras: int, - vocab_size: int, - ): - # Make sure we don't accidentally collect outside operations - torch_xla.sync() - - # Pad the prompt mapping to avoid running into recompiles on the TPU - # TODO: Should this happen inside mapping internally? If so how can we - # avoid having backend specific LoRAMapping classes? - mapping.prompt_mapping = self._pad_prompt_mapping(mapping.prompt_mapping) - - ( - base_indices, - sampler_indices, - sampler_indices_padded, - embeddings_indices, - indices_len, - ) = convert_mapping( - mapping, - lora_index_to_id, - max_loras, - vocab_size, - 0, # extra_vocab_size - "cpu", - ) - self._token_lora_indices = self._pad_to_shape( - base_indices, self._token_lora_indices.shape, dims=1 - ).to(self.device) - self._sampler_indices = self._pad_to_shape( - sampler_indices, self._sampler_indices.shape, dims=1 - ).to(self.device) - self._sampler_indices_padded = self._pad_to_shape( - sampler_indices_padded, self._sampler_indices_padded.shape, dims=1 - ).to(self.device) - self._embeddings_indices = self._pad_to_shape( - embeddings_indices, self._embeddings_indices.shape, dims=2 - ).to(self.device) - self.indices_len[:] = indices_len - - def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None: - self.batch_size = 1 - self._lora_indices_per_batch[: self.batch_size] = token_lora_tensor[ - : self.batch_size - ] - - def _pad_prompt_mapping(self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]: - num_reqs = len(prompt_mapping) - - # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular - # import - MIN_NUM_SEQS = 8 - - padded_num_reqs = max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS) - pad_len = padded_num_reqs - num_reqs - - padding = [-1] * pad_len - return tuple(list(prompt_mapping) + padding) - - def _pad_to_shape(self, src, target_shape, dims=1): - if dims == 1: - pad_len = target_shape[0] - src.shape[0] - return F.pad(src, (0, pad_len), value=0).to(torch.int32) - else: - pad_rows = target_shape[0] - src.shape[0] - pad_cols = target_shape[1] - src.shape[1] - return F.pad(src, (0, pad_cols, 0, pad_rows), value=0).to(torch.int32) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 2e7267d56d838..d6226da76eaed 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -71,11 +71,6 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: rocm_aiter_grouped_topk, ) -if current_platform.is_tpu(): - from .moe_pallas import fused_moe as fused_moe_pallas -else: - fused_moe_pallas = None # type: ignore - from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py deleted file mode 100644 index 66c00cf89873a..0000000000000 --- a/vllm/model_executor/layers/fused_moe/moe_pallas.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch -import torch.nn.functional as F - - -def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor: - """ - Compute the histogram of an int32 tensor. The bin edges are defined by the - min and max values, with step = 1. - """ - assert input.dtype == torch.int32, "input must be of torch.int32 dtype." - assert min <= max, "min must be less than or equal to max." - - def searchsorted( - sorted_sequence: torch.Tensor, values_to_search: torch.Tensor - ) -> torch.Tensor: - return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1) - - bin_edges = torch.linspace(min, max, max - min + 1, dtype=input.dtype).to( - input.device - ) - return searchsorted(bin_edges, input).to(torch.int32) - - -def fused_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - global_num_experts: int, - expert_map: torch.Tensor = None, - renormalize: bool = False, -) -> torch.Tensor: - """ - Args: - hidden_states: [*, hidden_size] - w1: [num_experts, intermediate_size * 2, hidden_size] - w2: [num_experts, hidden_size, intermediate_size] - gating_output: [*, num_experts] - """ - assert expert_map is None, "expert_map is not supported for pallas MoE." - import torch_xla.experimental.custom_kernel # noqa: F401 - - orig_shape = hidden_states.shape - hidden_size = hidden_states.shape[-1] - num_tokens = hidden_states.shape[:-1].numel() - num_experts = w1.shape[0] - intermediate_size = w2.shape[-1] - device = hidden_states.device - dtype = hidden_states.dtype - assert (num_tokens * topk) % 16 == 0, ( - "The Pallas GMM kernel requires num_tokens * topk to be a multiple of " - f"16 but got {num_tokens * topk}" - ) - - hidden_states = hidden_states.view(num_tokens, hidden_size) - gating_output = gating_output.view(num_tokens, num_experts) - topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) - topk_weights, topk_indices = topk_weights.topk(topk, dim=-1) - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - topk_weights = topk_weights.to(dtype) - - topk_indices = topk_indices.flatten() - topk_argsort_indices = topk_indices.argsort() - topk_argsort_revert_indices = topk_argsort_indices.argsort() - token_indices = torch.arange(num_tokens, device=device).repeat_interleave(topk) - token_indices = token_indices[topk_argsort_indices] - group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1) - - x = hidden_states[token_indices] - x = torch.ops.xla.gmm(x, w1, group_sizes, transpose_rhs=True) - x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:] - x = torch.ops.xla.gmm(x, w2, group_sizes, transpose_rhs=True) - x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size) - - x = x * topk_weights.unsqueeze(dim=-1) - x = x.sum(dim=-2) - x = x.reshape(orig_shape) - return x diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 82dbccf3fa9da..4c03cff2e8131 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -38,10 +38,6 @@ if current_platform.is_cuda_alike(): else: TritonExperts = None # type: ignore -if current_platform.is_tpu(): - from .moe_pallas import fused_moe as fused_moe_pallas -else: - fused_moe_pallas = None # type: ignore logger = init_logger(__name__) @@ -403,53 +399,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function=layer.custom_routing_function, ) - def forward_tpu( - self, - layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 - x: torch.Tensor, - router_logits: torch.Tensor, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert not layer.use_grouped_topk - assert layer.num_expert_group is None - assert layer.topk_group is None - assert layer.custom_routing_function is None - assert layer.apply_router_weight_on_input is False - if layer.scoring_func != "softmax": - raise NotImplementedError( - "Only softmax scoring function is supported for TPU." - ) - if layer.e_score_correction_bias is not None: - raise NotImplementedError( - "Expert score correction bias is not supported for TPU." - ) - assert layer.activation == "silu", ( - f"{layer.activation} is not supported for TPU." - ) - assert layer.routed_scaling_factor == 1.0, ( - f"routed_scaling_factor {layer.routed_scaling_factor} is " - "not supported for TPU." - ) - if ( - layer.enable_eplb is not False - or layer.expert_load_view is not None - or layer.logical_to_physical_map is not None - or layer.logical_replica_count is not None - ): - raise NotImplementedError("Expert load balancing is not supported for TPU.") - return fused_moe_pallas( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk=layer.top_k, - gating_output=router_logits, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - renormalize=layer.renormalize, - ) - - if current_platform.is_tpu(): - forward_native = forward_tpu - elif current_platform.is_cpu(): + if current_platform.is_cpu(): forward_native = forward_cpu elif current_platform.is_xpu(): forward_native = forward_xpu diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 18aaae394f935..1a4378f5df3db 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -130,12 +130,10 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .ptpc_fp8 import PTPCFp8Config from .rtn import RTNConfig from .torchao import TorchAOConfig - from .tpu_int8 import Int8TpuConfig method_to_config: dict[str, type[QuantizationConfig]] = { "awq": AWQConfig, "deepspeedfp": DeepSpeedFPConfig, - "tpu_int8": Int8TpuConfig, "fp8": Fp8Config, "fbgemm_fp8": FBGEMMFp8Config, "fp_quant": FPQuantConfig, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 20d050d387d49..4ccc4182367a6 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -19,9 +19,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( TritonScaledMMLinearKernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( - XLAScaledMMLinearKernel, -) from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) @@ -29,7 +26,6 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], - PlatformEnum.TPU: [XLAScaledMMLinearKernel], } diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py deleted file mode 100644 index 0be858c51993d..0000000000000 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ /dev/null @@ -1,106 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import warnings - -import torch -from functorch.experimental.control_flow import cond # noqa: F401 - -from vllm.model_executor.layers.quantization.utils import replace_parameter -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise, -) -from vllm.platforms import current_platform - -from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig - - -class XLAScaledMMLinearKernel(ScaledMMLinearKernel): - @classmethod - def is_supported( - cls, compute_capability: int | None = None - ) -> tuple[bool, str | None]: - if not current_platform.is_tpu(): - return False, "Requires TPU." - return True, None - - @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - if not current_platform.is_tpu(): - return False, "ScaledMMXLA requires running on TPU." - - if c.is_static_input_scheme: - return False, "ScaledMMXLA requires dynamic activation scales." - - if not c.input_symmetric: - return False, "ScaledMMXLA requires symmetric activation scales." - - if not c.is_channelwise: - return False, "ScaledMMXLA requires channelwise weight scales" - - return True, None - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # WEIGHT - # [out, in] (different than cutlass_scaled_mm) - weight = getattr(layer, self.w_q_name) - replace_parameter( - layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False) - ) - - # WEIGHT SCALE - # XLA kernels support only per-tensor and per-channel. - # If we have a fused module (QKV, MLP) with per tensor scales (thus N - # scales being passed to the kernel), convert to the per-channel case. - is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) - if is_fused_module and not self.config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) - - # [out_channel,] (different than cutlass_scaled_mm) - weight_scale = weight_scale.squeeze(-1) - replace_parameter( - layer, - self.w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False), - ) - - # Only support symmetric dynamic activation quantization. - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - setattr(layer, self.azp_adj_name, None) - - # Filter warning for cond usage in apply_weights. It is okay - # to specialize the graph since bias is not dynamic. - warnings.filterwarnings( - "ignore", - message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.", # noqa: E501 - ) - - def no_add_bias(self, x: torch.Tensor, bias: torch.Tensor | None): - return x - - def add_bias(self, x: torch.Tensor, bias: torch.Tensor | None): - return x + bias - - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - w_q, w_s, _, _, _ = self._get_weight_params(layer) - - # Required to register custom ops. - import torch_xla.experimental.custom_kernel # noqa: F401 - - out = torch.ops.xla.quantized_matmul_int8( - x, - w_q, - w_s, - quantize_activation=True, - ) - - # Explicitly capture control flow to make dynamo happy. - # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501 - return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias]) diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py deleted file mode 100644 index 64bfa8fb80eb2..0000000000000 --- a/vllm/model_executor/layers/quantization/tpu_int8.py +++ /dev/null @@ -1,139 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Optional - -import torch -from torch.nn import Module -from torch.nn.parameter import Parameter - -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import ( - QuantizationConfig, - QuantizationMethods, -) -from vllm.model_executor.parameter import ModelWeightParameter - -ACTIVATION_SCHEMES = ["none", "dynamic"] - - -class Int8TpuConfig(QuantizationConfig): - """Int8 Quantization Config class for TPU Backend.""" - - def __init__( - self, - activation_scheme: str = "none", - ) -> None: - super().__init__() - if activation_scheme not in ACTIVATION_SCHEMES: - raise ValueError(f"Unsupported activation scheme {activation_scheme}") - self.activation_scheme = activation_scheme - - def get_name(self) -> QuantizationMethods: - return "tpu_int8" - - def get_supported_act_dtypes(self) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_min_capability(cls) -> int: - raise NotImplementedError("This function should not be called with TPU Backend") - - @staticmethod - def get_config_filenames() -> list[str]: - return [] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig": - activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) - return cls(activation_scheme=activation_scheme) - - def get_quant_method( - self, layer: Module, prefix: str - ) -> Optional["TPUInt8LinearMethod"]: - if isinstance(layer, LinearBase): - return TPUInt8LinearMethod(self) - return None - - -class TPUInt8LinearMethod(LinearMethodBase): - """Int8 Linear method for TPU Quant.""" - - def __init__(self, quant_config: Int8TpuConfig): - self.quant_config = quant_config - self.quantize_activation = False - if self.quant_config.activation_scheme == "dynamic": - self.quantize_activation = True - - def create_weights( - self, - layer: Module, - input_size_per_partition: int, - output_partition_sizes: list[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - weight_loader = extra_weight_attrs.get("weight_loader") - weight = ModelWeightParameter( - data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight", weight) - - def _quantize_weight( - self, weight: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - weight_dtype = weight.dtype - weight = weight.cpu().to(torch.float32) - n_bit = 8 - eps = 1e-5 - max_int = 2 ** (n_bit - 1) - 1 - min_int = -(2 ** (n_bit - 1)) - max_val = weight.abs().amax(dim=-1, keepdim=True) - max_val = max_val.clamp(min=eps) - qscale = max_val / max_int - qweight = torch.clamp( - torch.round(weight * (1.0 / qscale)), min_int, max_int - ).to(torch.int8) - qscale = qscale.squeeze().to(weight_dtype) - return qweight, qscale - - def process_weights_after_loading(self, layer: Module) -> None: - layer.weight = Parameter(layer.weight.data, requires_grad=False) - device = layer.weight.device - qweight, qscale = self._quantize_weight(layer.weight) - qweight = qweight.to(device) - qscale = qscale.to(device) - layer.weight = Parameter(qweight, requires_grad=False) - layer.scale = Parameter(qscale, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - try: - import torch_xla.experimental.custom_kernel # noqa: F401 - except ImportError as err: - raise ImportError( - "Please install torch_xla by following the instructions at " - "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501 - "to run vLLM on TPU." - ) from err - weight = layer.weight - scale = layer.scale - out = torch.ops.xla.quantized_matmul_int8( - x, weight, scale, quantize_activation=self.quantize_activation - ) - if bias is not None: - out = out + bias - return out diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py deleted file mode 100644 index fc142f1f07fae..0000000000000 --- a/vllm/model_executor/model_loader/tpu.py +++ /dev/null @@ -1,118 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time - -import torch -import torch.nn as nn -import torch_xla.core.xla_model as xm -import torch_xla.distributed.spmd as xs - -from vllm.config import ModelConfig, VllmConfig -from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model -from vllm.logger import init_logger -from vllm.model_executor.model_loader.default_loader import DefaultModelLoader -from vllm.model_executor.model_loader.utils import ( - initialize_model, - process_weights_after_loading, -) -from vllm.utils.torch_utils import set_default_torch_dtype - -logger = init_logger(__name__) - - -class TPUModelLoader(DefaultModelLoader): - """ - A TPU model loader for model loading under SPMD mode. - """ - - def load_model( - self, - vllm_config: VllmConfig, - model_config: ModelConfig, - mesh: xs.Mesh | None = None, - ) -> nn.Module: - # Initialize model and load weights on CPU. Then, during SPMD partition, - # weights are sharded and transferred to TPUs. - self.counter_before_loading_weights = time.perf_counter() - model_config = vllm_config.model_config - assert model_config.quantization is None, "Quantization not supported" - target_device = torch.device("cpu") - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = initialize_model(vllm_config=vllm_config) - - load_format = vllm_config.load_config.load_format - if load_format != "dummy": - weights_to_load = {name for name, _ in model.named_parameters()} - all_weights = self.get_all_weights(model_config, model) - loaded_weights = model.load_weights(all_weights) - self.counter_after_loading_weights = time.perf_counter() - logger.info( - "Loading weights took %.2f seconds", - self.counter_after_loading_weights - - self.counter_before_loading_weights, - ) - # We only enable strict check for non-quantized models - # that have loaded weights tracking currently. - if model_config.quantization is None and loaded_weights is not None: - weights_not_loaded = weights_to_load - loaded_weights - if weights_not_loaded: - raise ValueError( - "Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}" - ) - else: - logger.info("Use dummy weight during weight loading.") - - process_weights_after_loading(model, model_config, target_device) - - counter_before_partition = time.perf_counter() - model = model.eval() - model = model.to("xla") - shard_model(model, mesh) - counter_after_partition = time.perf_counter() - logger.info( - "Partition model took %.2f seconds", - counter_after_partition - counter_before_partition, - ) - - # Ensure the model is properly loaded. - self._check_model_is_loaded(mesh, model) - - # Need to torch compile after model sharding are done. Because the - # compiler hints ('xs.mark_sharding') are torch ops. - if not model_config.is_multimodal_model: - model.model = torch.compile(model.model, backend="openxla") - else: - model.language_model.model = torch.compile( - model.language_model.model, backend="openxla" - ) - return model - - def _check_model_is_loaded(self, mesh: xs.Mesh | None, model: nn.Module) -> None: - """ - Ensure the model is properly loaded. - 1. All model parameters and buffers are on XLA device. - 2. Non-SPMD friendly layers are replaced as expected. - """ - device = xm.xla_device() - device_type = str(device.type) - - # Check parameters - for name, param in model.named_parameters(): - assert param.device.type == device_type, ( - f"Parameter {name} is on {param.device.type} instead of {device_type}" - ) - - # Check buffers - for name, buffer in model.named_buffers(): - assert buffer.device.type == device_type, ( - f"Buffer {name} is on {buffer.device.type} instead of {device_type}" - ) - - for module in model.modules(): - if (mesh is not None) and (get_fqn(module) == "QKVParallelLinear"): - raise AssertionError( - "QKVParallelLinear should be replaced by \ - XlaQKVParallelLinear under SPMD mode." - ) diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 69226763aafe6..b0886bba8a22a 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -186,20 +186,6 @@ class UsageMessage: except Exception: return False - def _report_torch_xla_usage(self) -> bool: - try: - import torch_xla - - self.gpu_count = torch_xla.runtime.world_size() - self.gpu_type = torch_xla.tpu.get_tpu_type() - self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[ - "bytes_limit" - ] - self.cuda_runtime = "torch_xla" - return True - except Exception: - return False - def _report_usage_once( self, model_architecture: str, @@ -217,9 +203,7 @@ class UsageMessage: if current_platform.is_cuda(): self.cuda_runtime = torch.version.cuda if current_platform.is_tpu(): # noqa: SIM102 - if (not self._report_tpu_inference_usage()) and ( - not self._report_torch_xla_usage() - ): + if not self._report_tpu_inference_usage(): logger.exception("Failed to collect TPU information") self.provider = _detect_cloud_provider() self.architecture = platform.machine() diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 5f7f8e81a24c4..e5a0cf7420497 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -35,7 +35,7 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = { import tpu_inference # noqa: F401 - +# Note(weiyulin): some static functions are still used by tpu-inference class PallasAttentionBackend(AttentionBackend): @staticmethod def get_name() -> str: @@ -314,60 +314,3 @@ def write_to_kv_cache( ) # NOTE: the in-place copy will be optimized away by XLA compiler. kv_cache.copy_(new_kv_cache) - - -# We can move this function to a common utils file if it's also useful for other -# hardware. -def dtype_bits(dtype: torch.dtype): - if dtype.is_floating_point: - try: - return torch.finfo(dtype).bits - except TypeError: - pass - elif dtype.is_complex: - if dtype is torch.complex32: - return 32 - elif dtype is torch.complex64: - return 64 - elif dtype is torch.complex128: - return 128 - else: - try: - return torch.iinfo(dtype).bits - # torch.iinfo cannot support int4, int2, bits8... - except TypeError: - pass - str_dtype = str(dtype) - # support torch.int4, torch.int5, torch.uint5... - if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"): - return int(str_dtype[-1]) - raise TypeError(f"Getting the bit width of {dtype} is not supported") - - -def get_dtype_packing(dtype): - bits = dtype_bits(dtype) - if 32 % bits != 0: - raise ValueError( - f"The bit width must be divisible by 32, but got bits={bits}, " - "dtype={dtype}" - ) - return 32 // bits - - -def get_page_size_bytes( - block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype -) -> int: - """Returns the size in bytes of one page of the KV cache.""" - padded_head_size = ( - cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT - ) - num_combined_kv_heads = num_kv_heads * 2 - - # NOTE: for the implicit padding in XLA - packing = get_dtype_packing(kv_cache_dtype) - num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing - - kv_cache_dtype_bits = dtype_bits(kv_cache_dtype) - return ( - block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bits // 8 - ) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py deleted file mode 100644 index c7404c4642d7e..0000000000000 --- a/vllm/v1/worker/tpu_model_runner.py +++ /dev/null @@ -1,2191 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import bisect -import gc -import time -from typing import TYPE_CHECKING, Any, cast -from unittest.mock import patch - -import numpy as np -import torch -import torch.nn as nn - -# TPU XLA related -import torch_xla -import torch_xla.core.xla_model as xm -import torch_xla.distributed.spmd as xs -import torch_xla.runtime as xr - -import vllm.envs as envs -from vllm.attention.backends.abstract import AttentionType -from vllm.attention.layer import Attention, MLAAttention -from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention -from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper -from vllm.config import ( - ParallelConfig, - VllmConfig, - get_layers_from_vllm_config, - update_config, -) -from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group -from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks -from vllm.forward_context import set_forward_context -from vllm.logger import init_logger -from vllm.lora.layers import BaseLayerWithLoRA -from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.model_loader import get_model_loader -from vllm.model_executor.model_loader.tpu import TPUModelLoader -from vllm.model_executor.models.interfaces import ( - SupportsMultiModal, - supports_transcription, -) -from vllm.model_executor.models.interfaces_base import ( - is_pooling_model, - is_text_generation_model, -) -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import ( - BatchedTensorInputs, - MultiModalKwargsItem, - PlaceholderRange, -) -from vllm.multimodal.utils import group_mm_kwargs_by_modality -from vllm.sequence import IntermediateTensors -from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.utils.math_utils import cdiv, prev_power_of_2 -from vllm.utils.platform_utils import is_pin_memory_available -from vllm.v1.attention.backends.pallas import ( - TPU_STR_DTYPE_TO_TORCH_DTYPE, - PallasAttentionBackend, - PallasMetadata, - get_page_size_bytes, -) -from vllm.v1.kv_cache_interface import ( - AttentionSpec, - FullAttentionSpec, - KVCacheConfig, - KVCacheSpec, - MLAAttentionSpec, - SlidingWindowSpec, -) -from vllm.v1.outputs import ( - EMPTY_MODEL_RUNNER_OUTPUT, - LogprobsLists, - LogprobsTensors, - ModelRunnerOutput, -) -from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata -from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler -from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorModelRunnerMixin, - KVConnectorOutput, -) -from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch - -from .utils import ( - MultiModalBudget, - add_kv_sharing_layers_to_kv_cache_groups, - bind_kv_cache, - sanity_check_mm_encoder_outputs, -) - -if TYPE_CHECKING: - from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput - -logger = init_logger(__name__) - -INVALID_TOKEN_ID = -1 -# Smallest output size -MIN_NUM_SEQS = 8 - - -######################################################### -# Ways to avoid recompilation -######################################################### -# -# The model executor has two primary components: -# 1. preparing the model and sampler inputs -# 2. executing the model and sampler. -# The core idea is to avoid any TPU computation during input preparation. For -# better compilation tracking and increased flexibility, the model execution and -# sampler are divided into several distinct components. -# -# Below are the detailed steps: -# -# Step 1 -# It is recommended to avoid TPU operations when preparing the model and sampler -# inputs. CPU tensors can be prepared and transferred to the XLA device using -# cpu_tensor.to(xla_device), which only triggers CPU to TPU transfers and avoids -# compilation. -# -# Step 2 -# The TPU execution should be decomposed into subgraphs (4 at the moment): -# 1. the main model -# 2. selecting hidden states for each request -# 3. sampler -# 4. encoder. -# Each subgraph should be decorated in a torch.compile. This is used to make -# sure that we have the same subgraph topology in both dummy_run and -# xecute_model. The results from these subgraphs should either be passed to -# other subgraphs, or transferred from TPU to CPU using xla_tensor.cpu() for -# subsequent processing on the CPU. -# -# Step 3 -# The dummy_run should be comprehensive, ensuring all potential input shapes and -# branch predictions are included as subgraph inputs to facilitate -# pre-compilation. -class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - def __init__( - self, - vllm_config: VllmConfig, - device: torch.device, - original_parallel_config: ParallelConfig | None = None, - ): - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.original_parallel_config = original_parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.speculative_config = vllm_config.speculative_config - self.observability_config = vllm_config.observability_config - self.device_config = vllm_config.device_config - - model_config = self.model_config - cache_config = self.cache_config - scheduler_config = self.scheduler_config - parallel_config = self.parallel_config - self.device = device - self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION - - # SPMD Related - self.use_spmd = envs.VLLM_XLA_USE_SPMD - if self.use_spmd: - num_devices = xr.global_runtime_device_count() - mesh_shape = (num_devices, 1) - device_ids = np.array(range(num_devices)) - self.mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y")) - - self.enforce_eager = model_config.enforce_eager - - self.num_xla_graphs = 0 - self._update_num_xla_graphs("init") - - self.pin_memory = is_pin_memory_available() - self.dtype = self.model_config.dtype - if cache_config.cache_dtype == "auto": - model_dtype = self.dtype - if isinstance(model_dtype, str): - self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype] - else: - self.kv_cache_dtype = model_dtype - else: - self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - self._hidden_states_dtype = self.dtype - - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.max_model_len = model_config.max_model_len - self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN - self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.num_blocks_per_most_len_req = ( - cdiv(self.most_model_len, self.block_size) - if self.most_model_len is not None - else None - ) - # InputBatch needs to work with sampling tensors greater than padding - # to avoid dynamic shapes. Also, avoid suboptimal alignment. - self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS) - self.num_tokens_paddings = _get_token_paddings( - min_token_size=16, - max_token_size=scheduler_config.max_num_batched_tokens, - padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP, - ) - # In case `max_num_tokens < max(num_tokens_paddings)` use the actual - # padded max value to pre-allocate data structures and pre-compile. - self.max_num_tokens = self.num_tokens_paddings[-1] - - # Model-related. - self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, "attention" - ) - self.num_query_heads = model_config.get_num_attention_heads(parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.head_size = model_config.get_head_size() - self.inputs_embeds_size = model_config.get_inputs_embeds_size() - self.vocab_size = model_config.get_vocab_size() - - # Multi-modal data support - self.mm_registry = MULTIMODAL_REGISTRY - self.uses_mrope = model_config.uses_mrope - self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( - model_config - ) - # TODO: Support M-RoPE (e.g, Qwen2-VL) - assert not self.uses_mrope, "TPU does not support M-RoPE yet." - - self._num_slices_per_kv_cache_update_block = ( - _get_num_slices_per_kv_cache_update_block( - get_page_size_bytes( - block_size=self.block_size, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - kv_cache_dtype=self.kv_cache_dtype, - ) - ) - ) - - # Lazy initialization - self.model: nn.Module # Set after load_model - self.kv_caches: list[torch.Tensor] = [] - # mm_hash -> encoder_output - self.encoder_cache: dict[str, torch.Tensor] = {} - - # Request states. - self.requests: dict[str, CachedRequestState] = {} - # NOTE(rob): num_prompt_logprobs only includes reqs - # that are currently in the prefill phase. - self.num_prompt_logprobs: dict[str, int] = {} - - # Initialize input batch early to avoid AttributeError in _update_states - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), - block_sizes=[self.block_size], - kernel_block_sizes=[self.cache_config.block_size], - ) - - # Cached torch/numpy tensor - # The pytorch tensor and numpy array share the same buffer. - # Sometimes the numpy op is faster so we create both. - self.input_ids_cpu = torch.zeros( - self.max_num_tokens, dtype=torch.int32, device="cpu" - ) - - self.positions_cpu = torch.zeros( - self.max_num_tokens, dtype=torch.int32, device="cpu" - ) - self.positions_np = self.positions_cpu.numpy() - self.block_table_cpu = torch.zeros( - (self.max_num_reqs, self.max_num_blocks_per_req), - dtype=torch.int32, - device="cpu", - ) - # adjust num_reqs to avoid SMEM OOM. - self.num_reqs_most_model_len = ( - min( - PallasAttentionBackend.get_max_num_seqs( - self.most_model_len, self.block_size - ), - self.max_num_reqs, - ) - if self.most_model_len is not None - else None - ) - self.num_reqs_max_model_len = min( - PallasAttentionBackend.get_max_num_seqs( - self.max_model_len, self.block_size - ), - self.max_num_reqs, - ) - self.query_start_loc_cpu = torch.zeros( - self.max_num_tokens + 1, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory, - ) - self.query_start_loc_np = self.query_start_loc_cpu.numpy() - - self.seq_lens_cpu = torch.zeros( - self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory, - ) - self.seq_lens_np = self.seq_lens_cpu.numpy() - - # Only relevant for multimodal models - if self.supports_mm_inputs: - self.is_mm_embed_cpu = torch.zeros( - self.max_num_tokens, - dtype=torch.bool, - device="cpu", - pin_memory=self.pin_memory, - ) - - # Range tensor with values [0 .. self.max_num_tokens - 1]. - # Used to initialize positions / context_lens / seq_lens - # Keep in int64 to avoid overflow with long context - self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64) - self.num_reqs_paddings = _get_req_paddings( - min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs - ) - - # Layer pairings for cross-layer KV sharing. - # If an Attention layer `layer_name` is in the keys of this dict, it - # means this layer will perform attention using the keys and values - # from the KV cache of `shared_kv_cache_layers[layer_name]`. - self.shared_kv_cache_layers: dict[str, str] = {} - - # tensors for structured decoding - self.grammar_bitmask_cpu = torch.zeros( - (self.max_num_reqs, cdiv(self.vocab_size, 32)), - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory, - ) - self.require_structured_out_cpu = torch.zeros( - (self.max_num_reqs, 1), - dtype=torch.bool, - device="cpu", - pin_memory=self.pin_memory, - ) - self.structured_decode_arange = torch.arange( - 0, 32, device="cpu", pin_memory=self.pin_memory - ) - - self.mm_budget = ( - MultiModalBudget( - self.model_config, - self.scheduler_config, - self.mm_registry, - ) - if self.supports_mm_inputs - else None - ) - - if not self.use_spmd: - self.sample_from_logits_func = torch.compile( - self.sample_from_logits, - backend="openxla", - fullgraph=True, - dynamic=False, - ) - else: - self.sample_from_logits_func = self.sample_from_logits - - # For passing scheduler_output between successive - # execute_model() and sample_tokens() calls. - self.scheduler_output: SchedulerOutput | None = None - self.mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None - - def reset_mm_cache(self) -> None: - if self.mm_budget: - self.mm_budget.reset_cache() - - def _update_num_xla_graphs(self, case_str): - check_comp = self.check_recompilation and not self.enforce_eager - if not check_comp: - return - - total_cached_graphs = xr.get_num_cached_compilation_graph() - new_compiled_graphs = total_cached_graphs - self.num_xla_graphs - if new_compiled_graphs == 0: - return - - logger.info( - "Add new %d compiled XLA graphs due to %s", new_compiled_graphs, case_str - ) - self.num_xla_graphs += new_compiled_graphs - - def _verify_num_xla_graphs(self, case_str): - check_comp = self.check_recompilation and not self.enforce_eager - if not check_comp: - return - - curr_cached_graph = xr.get_num_cached_compilation_graph() - assert self.num_xla_graphs == curr_cached_graph, ( - "Recompilation after warm up is detected during {}." - " num_xla_graphs = {} curr_cached_graph = {}".format( - case_str, self.num_xla_graphs, curr_cached_graph - ) - ) - - def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: - """Update the cached states and the persistent batch with the scheduler - output. - - The updated states are used by the `_prepare_inputs` function to create - the input GPU tensors for the model. - - Returns: - True if there is a new/resumed/paused/finished request. - If False, we can skip copying SamplingMetadata to the GPU. - """ - # Remove finished requests from the cached states. - for req_id in scheduler_output.finished_req_ids: - self.requests.pop(req_id, None) - self.num_prompt_logprobs.pop(req_id, None) - - # Remove the finished requests from the persistent batch. - # NOTE(woosuk): There could be an edge case where finished_req_ids and - # scheduled_req_ids overlap. This happens when a request is aborted and - # then resubmitted with the same ID. In this case, we treat them as two - # distinct requests - clearing the cached states for the first request - # and handling the second as a new request. - removed_req_indices: list[int] = [] - for req_id in scheduler_output.finished_req_ids: - req_index = self.input_batch.remove_request(req_id) - if req_index is not None: - removed_req_indices.append(req_index) - - # Free the cached encoder outputs. - for mm_hash in scheduler_output.free_encoder_mm_hashes: - self.encoder_cache.pop(mm_hash, None) - - # Remove the unscheduled requests from the persistent batch. - # NOTE(woosuk): The unscheduled requests are either preempted requests - # or running requests that are not scheduled in this step. We remove - # them from the persistent batch but keep their cached states since - # they will be scheduled again sometime in the future. - scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() - cached_req_ids = self.input_batch.req_id_to_index.keys() - unscheduled_req_ids = cached_req_ids - scheduled_req_ids - # NOTE(woosuk): The persistent batch optimization assumes that - # consecutive batches contain mostly the same requests. If batches - # have low request overlap (e.g., alternating between two distinct - # sets of requests), this optimization becomes very inefficient. - for req_id in unscheduled_req_ids: - req_index = self.input_batch.remove_request(req_id) - assert req_index is not None - removed_req_indices.append(req_index) - - req_ids_to_add: list[str] = [] - # Add new requests to the cached states. - for new_req_data in scheduler_output.scheduled_new_reqs: - assert new_req_data.sampling_params is not None, ( - "Pooling is not supported in TPU yet" - ) - req_id = new_req_data.req_id - sampling_params = new_req_data.sampling_params - - self.requests[req_id] = CachedRequestState( - req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, - prompt_embeds=new_req_data.prompt_embeds, - mm_features=new_req_data.mm_features, - sampling_params=sampling_params, - pooling_params=None, - generator=None, - block_ids=new_req_data.block_ids, - num_computed_tokens=new_req_data.num_computed_tokens, - output_token_ids=[], - lora_request=new_req_data.lora_request, - ) - - if sampling_params and sampling_params.prompt_logprobs is not None: - self.num_prompt_logprobs[req_id] = ( - self.input_batch.vocab_size - if sampling_params.prompt_logprobs == -1 - else sampling_params.prompt_logprobs - ) - - req_ids_to_add.append(req_id) - - # Update the states of the running/resumed requests. - req_data = scheduler_output.scheduled_cached_reqs - for i, req_id in enumerate(req_data.req_ids): - req_state = self.requests[req_id] - num_computed_tokens = req_data.num_computed_tokens[i] - new_block_ids = req_data.new_block_ids[i] - resumed_from_preemption = req_id in req_data.resumed_req_ids - - # Update the cached states. - req_state.num_computed_tokens = num_computed_tokens - if not resumed_from_preemption: - if new_block_ids is not None: - # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): - block_ids.extend(new_ids) - else: - assert new_block_ids is not None - # The request is resumed from preemption. - # Replace the existing block IDs with the new ones. - req_state.block_ids = new_block_ids - - req_index = self.input_batch.req_id_to_index.get(req_id) - if req_index is None: - # The request is not in the persistent batch. - # The request was either preempted and resumed later, or was not - # scheduled in the previous step and needs to be added again. - req_ids_to_add.append(req_id) - continue - - # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens - if new_block_ids is not None: - self.input_batch.block_table.append_row(new_block_ids, req_index) - - # Add the new or resumed requests to the persistent batch. - # The smaller empty indices are filled first. - removed_req_indices = sorted(removed_req_indices, reverse=True) - for req_id in req_ids_to_add: - req_state = self.requests[req_id] - # Fill the empty index or append to the end - req_index = removed_req_indices.pop() if removed_req_indices else None - self.input_batch.add_request(req_state, req_index) - - # Condense the batched states if there are empty indices. - if removed_req_indices: - self.input_batch.condense(removed_req_indices) - - return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 - - def get_model(self) -> nn.Module: - return self.model - - def get_supported_generation_tasks(self) -> list[GenerationTask]: - model = self.get_model() - supported_tasks = list[GenerationTask]() - - if is_text_generation_model(model): - supported_tasks.append("generate") - - if supports_transcription(model): - if model.supports_transcription_only: - return ["transcription"] - - supported_tasks.append("transcription") - - return supported_tasks - - def get_supported_pooling_tasks(self) -> list[PoolingTask]: - model = self.get_model() - if not is_pooling_model(model): - return [] - - return list(model.pooler.get_supported_tasks()) - - def get_supported_tasks(self) -> tuple[SupportedTask, ...]: - tasks = list[SupportedTask]() - - if self.model_config.runner_type == "generate": - tasks.extend(self.get_supported_generation_tasks()) - if self.model_config.runner_type == "pooling": - tasks.extend(self.get_supported_pooling_tasks()) - - return tuple(tasks) - - def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: - """ - Generates the KVCacheSpec by parsing the kv cache format from each - Attention module in the static forward context. - Returns: - KVCacheSpec: A dictionary mapping layer names to their KV cache - format. Layers that do not need KV cache are not included. - """ - - layers = get_layers_from_vllm_config( - self.vllm_config, - AttentionLayerBase, # type: ignore[type-abstract] - ) - block_size = self.vllm_config.cache_config.block_size - cache_dtype_str = self.vllm_config.cache_config.cache_dtype - - kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in layers.items(): - # Classic Attention path - if isinstance(attn_module, Attention): - if ( - kv_tgt_layer := attn_module.kv_sharing_target_layer_name - ) is not None: - # The layer doesn't need its own KV cache and will use that of - # the target layer. We skip creating a KVCacheSpec for it, so - # that KV cache management logic will act as this layer does - # not exist, and doesn't allocate KV cache for the layer. This - # enables the memory saving of cross-layer kv sharing, allowing - # a given amount of memory to accommodate longer context lengths - # or enable more requests to be processed simultaneously. - self.shared_kv_cache_layers[layer_name] = kv_tgt_layer - continue - - if attn_module.attn_type == AttentionType.DECODER: - if isinstance(attn_module, ChunkedLocalAttention): - logger.warning_once( - "Using irope in Pallas is not supported yet, it " - "will fall back to global attention for long context." - ) - if attn_module.sliding_window is not None: - kv_cache_spec[layer_name] = SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - ) - else: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - ) - elif attn_module.attn_type in ( - AttentionType.ENCODER, - AttentionType.ENCODER_ONLY, - ): - # encoder-only attention does not need KV cache. - continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError - else: - raise ValueError(f"Unknown attention type: {attn_module.attn_type}") - # MLAAttention path - elif isinstance(attn_module, MLAAttention): - if layer_name in kv_cache_spec: - continue - kv_cache_spec[layer_name] = MLAAttentionSpec( - block_size=block_size, - num_kv_heads=1, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - cache_dtype_str=cache_dtype_str, - ) - else: - continue - - return kv_cache_spec - - def _get_slot_mapping_metadata( - self, num_reqs, num_scheduled_tokens_per_req - ) -> np.ndarray: - """ - Computes metadata for mapping slots to blocks in the key-value (KV) - cache for a batch of requests. - - This function determines, for each request in the batch, how the - scheduled tokens are distributed across memory blocks, and generates - metadata needed to map slices of tokens to their corresponding positions - in the KV cache. - - Args: - num_reqs (int): Number of requests in the current batch. - num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens - to be scheduled for each request. - - Returns: - np.ndarray: A 2D array of shape (total_block_len, 3), where each row - contains: - - kv_cache_start_index (int): The starting index in the KV cache - for the corresponding slice. - - new_kv_start_index (int): The starting index in the new KV - cache for the corresponding slice. - - slice_len (int): The length of the slice. - """ - slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs] - slices_end = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] - + num_scheduled_tokens_per_req - ) - local_block_start_idx = slices_start // self.block_size - local_block_end_idx = (slices_end - 1) // self.block_size - no_repeat_req_indices = self.arange_np[:num_reqs] - global_block_start_idx = ( - no_repeat_req_indices * self.max_num_blocks_per_req + local_block_start_idx - ) - block_lens = local_block_end_idx - local_block_start_idx + 1 - global_block_start_idx = np.repeat(global_block_start_idx, block_lens) - slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens]) - global_block_indices = global_block_start_idx + slice_arange - block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[global_block_indices].numpy() - total_block_len = np.sum(block_lens) - slot_mapping_slices = np.repeat( - np.array([[0, self.block_size]], dtype=np.int32), total_block_len, axis=0 - ) - cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32) - np.cumsum(block_lens, out=cu_block_lens[1:]) - for req_idx in range(num_reqs): - slot_mapping_slices[cu_block_lens[req_idx]][0] = ( - slices_start[req_idx] % self.block_size - ) - slot_mapping_slices[cu_block_lens[req_idx + 1] - 1][1] = ( - slices_end[req_idx] - 1 - ) % self.block_size + 1 - slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0] - cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32) - np.cumsum(slice_lens, out=cu_slices_lens[1:]) - kv_cache_start_indices = slot_mapping_slices[:, 0] + ( - block_numbers * self.block_size - ) - new_kv_start_indices = cu_slices_lens[:-1] - slot_mapping_metadata = np.stack( - [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1 - ) - return slot_mapping_metadata - - def _prepare_inputs(self, scheduler_output: "SchedulerOutput", start_index: int): - assert scheduler_output.total_num_scheduled_tokens > 0 - num_reqs = self.input_batch.num_reqs - assert num_reqs > 0 - assert start_index < num_reqs - - # Get the number of scheduled tokens for each request. - use_max_model_len = self.most_model_len is None - num_scheduled_tokens_per_req = [] - max_num_scheduled_tokens_all_reqs = 0 - end_index = start_index - - # Use either most_model_len or max_model_len depending on request size. - for i in range(start_index, num_reqs): - req_id = self.input_batch.req_ids[i] - assert req_id is not None - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - if ( - not use_max_model_len - and self.most_model_len is not None - and num_tokens > self.most_model_len - ): - use_max_model_len = True - num_scheduled_tokens_per_req.append(num_tokens) - if use_max_model_len: - if len(num_scheduled_tokens_per_req) > self.num_reqs_max_model_len: - num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[ - : self.num_reqs_max_model_len - ] - end_index = start_index + self.num_reqs_max_model_len - else: - end_index = num_reqs - else: - assert self.num_reqs_most_model_len is not None - if len(num_scheduled_tokens_per_req) > self.num_reqs_most_model_len: - num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[ - : self.num_reqs_most_model_len - ] - end_index = start_index + self.num_reqs_most_model_len - else: - end_index = num_reqs - max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_per_req) - num_scheduled_tokens_per_req = np.array( - num_scheduled_tokens_per_req, dtype=np.int32 - ) - total_num_scheduled_tokens = sum(num_scheduled_tokens_per_req) - assert max_num_scheduled_tokens_all_reqs > 0 - - num_reqs = len(num_scheduled_tokens_per_req) - - # Get request indices. - # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - # For each scheduled token, what are the corresponding req index. - req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens_per_req) - - # Get batched arange. - # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # For each scheduled token, what is its position in corresponding req. - arange = np.concatenate( - [self.arange_np[:n] for n in num_scheduled_tokens_per_req] - ) - - # Get positions. - positions_np = self.positions_np[:total_num_scheduled_tokens] - np.add( - self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np, - ) - - # Get token indices. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] - # where M is the max_model_len. - token_indices = ( - positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] - ) - - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - torch.index_select( - self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens], - ) - - # Prepare the attention metadata. - self.query_start_loc_np[0] = 0 - np.cumsum( - num_scheduled_tokens_per_req, out=self.query_start_loc_np[1 : num_reqs + 1] - ) - self.query_start_loc_np[num_reqs + 1 :] = 1 - - self.seq_lens_np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] - + num_scheduled_tokens_per_req - ) - - # Do the padding and copy the tensors to the TPU. - padded_total_num_scheduled_tokens = _get_padded_token_len( - self.num_tokens_paddings, total_num_scheduled_tokens - ) - # Zero out to avoid spurious values from prev iteration (last cp chunk) - self.input_ids_cpu[ - total_num_scheduled_tokens:padded_total_num_scheduled_tokens - ] = 0 - self.input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens].to( - self.device - ) - self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to( - self.device - ) - if use_max_model_len: - block_tables = self.block_table_cpu[ - : self.num_reqs_max_model_len, : self.max_num_blocks_per_req - ] - block_tables[:num_reqs, : self.max_num_blocks_per_req] = ( - self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs] - ) - query_start_loc = self.query_start_loc_cpu[ - : self.num_reqs_max_model_len + 1 - ].to(self.device) - seq_lens = self.seq_lens_cpu[: self.num_reqs_max_model_len].to(self.device) - else: - assert self.num_reqs_most_model_len is not None - block_tables = self.block_table_cpu[ - : self.num_reqs_most_model_len, : self.num_blocks_per_most_len_req - ] - block_tables[:num_reqs, : self.num_blocks_per_most_len_req] = ( - self.input_batch.block_table[0].get_cpu_tensor()[ - :num_reqs, : self.num_blocks_per_most_len_req - ] - ) - query_start_loc = self.query_start_loc_cpu[ - : self.num_reqs_most_model_len + 1 - ].to(self.device) - seq_lens = self.seq_lens_cpu[: self.num_reqs_most_model_len].to(self.device) - block_tables = block_tables.to(self.device) - - # Calculate the slot mapping - slot_mapping_metadata = self._get_slot_mapping_metadata( - num_reqs, num_scheduled_tokens_per_req - ) - num_kv_update_slices = slot_mapping_metadata.shape[0] - padded_num_slices = _get_padded_num_kv_cache_update_slices( - padded_total_num_scheduled_tokens, self.max_num_reqs, self.block_size - ) - slot_mapping_metadata = np.pad( - slot_mapping_metadata, - [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]], - constant_values=0, - ) - slot_mapping_metadata = np.transpose(slot_mapping_metadata) - slot_mapping_metadata = torch.tensor(slot_mapping_metadata, device=self.device) - - if self.lora_config is not None: - # We need to respect padding when activating LoRA adapters - padded_num_scheduled_tokens_per_req = np.copy( - num_scheduled_tokens_per_req - ) # Copying to avoid accidental state corruption bugs - padded_num_scheduled_tokens_per_req[-1] += ( - padded_total_num_scheduled_tokens - total_num_scheduled_tokens - ) - - self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req) - - attn_metadata = PallasMetadata( - slot_mapping=slot_mapping_metadata, - block_tables=block_tables, - context_lens=seq_lens, - query_start_loc=query_start_loc, - num_seqs=torch.tensor([num_reqs], dtype=torch.int32, device=self.device), - num_kv_update_slices=torch.tensor( - [num_kv_update_slices], dtype=torch.int32, device=self.device - ), - num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block, - ) - # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial - # request in the batch. While we should not sample any token from this - # partial request, we do so for simplicity. We will ignore the sampled - # token from the partial request. - # TODO: Support prompt logprobs. - padded_num_reqs = _get_padded_num_reqs_with_upper_limit( - num_reqs, self.max_num_reqs - ) - # Indices at which we sample (positions of last token in the sequence). - # Padded to avoid recompiling when `num_reqs` varies. - logits_indices = self.query_start_loc_cpu[1 : padded_num_reqs + 1] - 1 - logits_indices = logits_indices.to(self.device) - - if self.lora_config is not None: - # We need to respect padding when activating LoRA adapters - padded_num_scheduled_tokens_per_req = np.copy( - num_scheduled_tokens_per_req - ) # Copying to avoid accidental state corruption bugs - padded_num_scheduled_tokens_per_req[-1] += ( - padded_total_num_scheduled_tokens - total_num_scheduled_tokens - ) - - self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req) - - layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys() - per_layer_attn_metadata = { - layer_name: attn_metadata for layer_name in layer_names - } - return ( - per_layer_attn_metadata, - logits_indices, - padded_num_reqs, - num_reqs, - end_index, - ) - - def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): - scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs - if not scheduled_encoder_inputs: - return - - # Batch the multi-modal inputs. - mm_kwargs = list[MultiModalKwargsItem]() - # List of tuple (mm_hash, pos_info) - mm_hashes_pos = list[tuple[str, PlaceholderRange]]() - for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): - req_state = self.requests[req_id] - - for mm_input_id in encoder_input_ids: - mm_feature = req_state.mm_features[mm_input_id] - if mm_feature.data is None: - continue - mm_hash = mm_feature.identifier - mm_kwargs.append(mm_feature.data) - mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) - - # Batch mm inputs as much as we can: if a request in the batch has - # multiple modalities or a different modality than the previous one, - # we process it separately to preserve item order. - # FIXME(ywang96): This is a hacky way to deal with multiple modalities - # in the same batch while still being able to benefit from batching - # multimodal inputs. The proper solution should be reordering the - # encoder outputs. - model = cast(SupportsMultiModal, self.model) - encoder_outputs = [] - for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - ): - # Run the encoder. - # `curr_group_outputs` is either of the following: - # 1. A tensor of shape (num_items, feature_size, hidden_size) - # in case feature_size is fixed across all multimodal items. - # 2. A list or tuple (length: num_items) of tensors, each of shape - # (feature_size, hidden_size) in case the feature size is dynamic - # depending on the input multimodal items. - torch_xla.sync(wait=False) - curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) - torch_xla.sync(wait=False) - - sanity_check_mm_encoder_outputs( - curr_group_outputs, - expected_num_items=num_items, - ) - - if isinstance(curr_group_outputs, torch.Tensor): - encoder_outputs.append(curr_group_outputs) - else: - assert isinstance(curr_group_outputs, (list, tuple)) - for output in curr_group_outputs: - encoder_outputs.append(output) - - # Cache the encoder outputs. - # NOTE (NickLucche) here we diverge from logic in other runners, as we - # assume to only have whole mm items to process. Hence we avoid the - # intrinsic dynamism that `scatter_mm_placeholders` introduces. - for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): - assert pos_info.is_embed is None, ( - "Expected all positions to be contiguous and embeddings." - ) - self.encoder_cache[mm_hash] = output - - def _gather_mm_embeddings( - self, - scheduler_output: "SchedulerOutput", - ) -> tuple[list[torch.Tensor], torch.Tensor]: - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - padded_total_num_scheduled_tokens = _get_padded_token_len( - self.num_tokens_paddings, total_num_scheduled_tokens - ) - - is_mm_embed = self.is_mm_embed_cpu - is_mm_embed[:padded_total_num_scheduled_tokens] = False - mm_embeds = list[torch.Tensor]() - req_start_idx = 0 - - for req_id in self.input_batch.req_ids: - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] - req_state = self.requests[req_id] - num_computed_tokens = req_state.num_computed_tokens - - # TODO unroll loop and assume/enforce --disable_chunked_mm_input - # NOTE (NickLucche) here we diverge from logic in other runners, as - # we assume to only have whole mm items to process. Hence we avoid - # the intrinsic dynamism that `gather_mm_placeholders` introduces. - for mm_feature in req_state.mm_features: - pos_info = mm_feature.mm_position - start_pos = pos_info.offset - num_encoder_tokens = pos_info.length - - # The encoder output is needed if the two ranges overlap: - # [num_computed_tokens, - # num_computed_tokens + num_scheduled_tokens) and - # [start_pos, start_pos + num_encoder_tokens) - if start_pos >= num_computed_tokens + num_scheduled_tokens: - # The encoder output is not needed in this step. - break - if start_pos + num_encoder_tokens <= num_computed_tokens: - # The encoder output is already processed and stored - # in the decoder's KV cache. - continue - - start_idx = max(num_computed_tokens - start_pos, 0) - end_idx = min( - num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens, - ) - assert start_idx < end_idx - - mm_hash = mm_feature.identifier - encoder_output = self.encoder_cache.get(mm_hash, None) - assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." - - assert pos_info.is_embed is None, ( - "Expected all positions to be contiguous and embeddings." - ) - - req_start_pos = req_start_idx + start_pos - num_computed_tokens - is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = True - - # Only whole mm items are processed - mm_embeds.append(encoder_output) - - req_start_idx += num_scheduled_tokens - - is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens].to(self.device) - - return mm_embeds, is_mm_embed - - def _get_model_inputs( - self, - input_ids: torch.Tensor, - mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None, - ): - if self.supports_mm_inputs: - mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) - - # NOTE(woosuk): To unify token ids and soft tokens (vision - # embeddings), we always use embeddings (rather than token ids) - # as input to the multimodal model, even when the input is text. - inputs_embeds = self.model.embed_input_ids( - input_ids, - multimodal_embeddings=mm_embeds, - is_multimodal=is_mm_embed, - ) - - return None, inputs_embeds - else: - # For text-only models, we use token ids as input. - # While it is possible to use embeddings as input just like the - # multimodal models, it is not desirable for performance since - # then the embedding layer is not included in the CUDA graph. - return input_ids, None - - @torch.no_grad() - def execute_model( - self, - scheduler_output: "SchedulerOutput", - intermediate_tensors: IntermediateTensors | None = None, - ) -> ModelRunnerOutput | None: - if self.scheduler_output is not None: - raise RuntimeError( - "State error: sample_tokens() must be called " - "after execute_model() returns None." - ) - # Update cached state - self._update_states(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: - if not has_kv_transfer_group(): - # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT - - return self.kv_connector_no_forward(scheduler_output, self.vllm_config) - - mm_embed_inputs = None - if self.supports_mm_inputs: - # Run the multimodal encoder if any. - self._execute_mm_encoder(scheduler_output) - mm_embed_inputs = self._gather_mm_embeddings(scheduler_output) - - torch_xla.sync(wait=False) - - self.scheduler_output = scheduler_output - self.mm_embed_inputs = mm_embed_inputs - return None - - @torch.no_grad() - def sample_tokens( - self, grammar_output: "GrammarOutput | None" - ) -> ModelRunnerOutput: - if self.scheduler_output is None: - # Nothing to do (PP non-final rank case), output isn't used. - return None # type: ignore[return-value] - scheduler_output = self.scheduler_output - mm_embed_inputs = self.mm_embed_inputs - self.scheduler_output = None - self.mm_embed_inputs = None - - # Prepare inputs, the requests might be split into multiple - # executions, combine the result of each execution. - start_index = 0 - combined_selected_tokens: list[torch.Tensor] = [] - combined_logprobs: list[LogprobsLists] = [] - - # NOTE: setup current batch's metadata for kv connector. - # Currently, only verified with NixlConnector - with set_forward_context(None, self.vllm_config): - self.maybe_setup_kv_connector(scheduler_output) - - while start_index < self.input_batch.num_reqs: - attn_metadata, logits_indices, padded_num_reqs, num_reqs, end_index = ( - self._prepare_inputs(scheduler_output, start_index) - ) - input_ids, inputs_embeds = self._get_model_inputs( - self.input_ids, mm_embed_inputs - ) - torch_xla.sync(wait=False) - # Run the decoder - with set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=scheduler_output.total_num_scheduled_tokens, - ): - hidden_states = self.model( - input_ids=input_ids, - positions=self.position_ids, - inputs_embeds=inputs_embeds, - ) - hidden_states = self.select_hidden_states(hidden_states, logits_indices) - logits = self.compute_logits(hidden_states) - tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch( - self.input_batch, padded_num_reqs, self.device - ) - if grammar_output is not None: - require_struct_decoding, grammar_bitmask_padded, arange = ( - self.prepare_structured_decoding_input(logits, grammar_output) - ) - logits = self.structured_decode( - require_struct_decoding, grammar_bitmask_padded, logits, arange - ) - selected_token_ids = self.sample_from_logits_func( - logits, tpu_sampling_metadata - ) - # NOTE (NickLucche) Use the original logits (before any penalties or - # temperature scaling) for the top-k logprobs. We can't enforce it - # due to recompilations outside torch.compiled code, so just make - # sure `sample_from_logits` does not modify the logits in-place. - logprobs = ( - self.gather_logprobs(logits, selected_token_ids) - if tpu_sampling_metadata.logprobs - else None - ) - - # Remove padding on cpu and keep dynamic op outside of xla graph. - selected_token_ids = selected_token_ids.cpu()[:num_reqs] - - combined_selected_tokens.append(selected_token_ids) - if tpu_sampling_metadata.logprobs: - combined_logprobs.append(logprobs.tolists()) - - start_index = end_index - - # NOTE: current kv load and save get h2d/d2h copies involved. - # Those copies are blocking. Once they become async., kv_save - # should be called right after each single forward pass, - # instead of the forwards of the entire input batch. - self.maybe_wait_for_kv_save() - finished_sending, finished_recving = self.get_finished_kv_transfers( - scheduler_output - ) - - selected_token_ids = torch.cat(combined_selected_tokens, dim=0) - if tpu_sampling_metadata.logprobs: - - def concat_lists(input_lists): - result = [] - for input_list in input_lists: - result.extend(input_list) - return result - - logprobs_lists = LogprobsLists( - logprob_token_ids=concat_lists( - [lp.logprob_token_ids for lp in combined_logprobs] - ), - logprobs=concat_lists([lp.logprobs for lp in combined_logprobs]), - sampled_token_ranks=concat_lists( - [lp.sampled_token_ranks for lp in combined_logprobs] - ), - ) - else: - logprobs_lists = None - - # Update the cache state concurrently. Code above will not block until - # we use `selected_token_ids`. Add mark_step if post-processing changes - request_seq_lens: list[tuple[int, CachedRequestState, int]] = [] - discard_sampled_tokens_req_indices = [] - num_reqs = self.input_batch.num_reqs - for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): - assert req_id is not None - req_state = self.requests[req_id] - seq_len = ( - req_state.num_computed_tokens - + scheduler_output.num_scheduled_tokens[req_id] - ) - if seq_len >= req_state.num_tokens: - request_seq_lens.append((i, req_state, seq_len)) - else: - # Ignore the sampled token from the partial request. - # Rewind the generator state as if the token was not sampled. - generator = self.input_batch.generators.get(i) - if generator is not None: - # This relies on cuda-specific torch-internal impl details - generator.set_offset(generator.get_offset() - 4) - - # Record the index of the request that should not be sampled, - # so that we could clear the sampled tokens before returning. - discard_sampled_tokens_req_indices.append(i) - - assert all( - req_id is not None for req_id in self.input_batch.req_ids[:num_reqs] - ), "req_ids contains None" - req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) - - prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {} - for req_id in self.input_batch.req_ids[:num_reqs]: - prompt_logprobs_dict[req_id] = None - - max_gen_len = selected_token_ids.shape[-1] - if max_gen_len == 1: - valid_sampled_token_ids = selected_token_ids.tolist() - - # Mask out the sampled tokens that should not be sampled. - # TODO: Keep in sync with gpu_model_runner.py, in particular - # the "else" case here - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() - - # Append sampled tokens - for i, req_state, seq_len in request_seq_lens: - token_id = valid_sampled_token_ids[i][0] - self.input_batch.token_ids_cpu[i, seq_len] = token_id - req_state.output_token_ids.append(token_id) - self.input_batch.num_tokens_no_spec[i] += 1 - - else: - valid_mask = selected_token_ids != INVALID_TOKEN_ID - gen_lens = valid_mask.sum(dim=1).tolist() - valid_sampled_token_ids = [ - seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens) - ] - self.input_batch.num_tokens_no_spec[:num_reqs] += gen_lens - for i, req_state, seq_len in request_seq_lens: - target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) - self.input_batch.token_ids_cpu[i, target_slice] = ( - valid_sampled_token_ids[i] - ) - req_state.output_token_ids.extend(valid_sampled_token_ids[i]) - - kv_connector_output = ( - None - if (finished_sending is None and finished_recving is None) - else KVConnectorOutput( - finished_sending=finished_sending, - finished_recving=finished_recving, - ) - ) - - model_runner_output = ModelRunnerOutput( - req_ids=req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - kv_connector_output=kv_connector_output, - ) - - # Check there are no new graphs compiled - all the graphs should be - # captured and compiled during warm up. - self._verify_num_xla_graphs("execute_model") - - return model_runner_output - - def update_config(self, overrides: dict[str, Any]) -> None: - # TODO: TPU config may need extra validation - # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754 - allowed_config_names = {"load_config", "model_config"} - for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, ( - f"Config `{config_name}` not supported. " - f"Allowed configs: {allowed_config_names}" - ) - config = getattr(self, config_name) - new_config = update_config(config, config_overrides) - setattr(self, config_name, new_config) - - def load_model(self) -> None: - self.device = self.device_config.device - - # NOTE(woosuk): While the executor assigns the TP ranks to the worker - # process, the ranks can be different from the ranks internally assigned - # by the xm runtime. Therefore, there is a mismatch in the rank - # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. - # This is not a problem in linear layers because all-reduce is - # rank-agnostic. However, it matters for all-gather as the ranks - # determine the order of concatenating the output tensors. - # As a workaround, we use the xm's rank assignment only when loading - # the embedding weights. - xm_tp_rank = xr.global_ordinal() - with patch( - "vllm.model_executor.layers.vocab_parallel_embedding." - "get_tensor_model_parallel_rank", - return_value=xm_tp_rank, - ): - try: - if self.use_spmd: - tpu_loader = TPUModelLoader( - load_config=self.vllm_config.load_config - ) - model = tpu_loader.load_model( - vllm_config=self.vllm_config, - model_config=self.vllm_config.model_config, - mesh=self.mesh, - ) - else: - model_loader = get_model_loader(self.load_config) - logger.info("Loading model from scratch...") - model = model_loader.load_model( - vllm_config=self.vllm_config, model_config=self.model_config - ) - except RuntimeError as e: - raise RuntimeError( - f"Unable to load model, a likely reason is the model is " - "too large for the current device's HBM memory. " - "Consider switching to a smaller model " - "or sharding the weights on more chips. " - f"See the detailed error: {e}" - ) from e - if self.lora_config is not None: - model = self.load_lora_model(model, self.vllm_config, self.device) - replace_set_lora(model) - - # Sync all pending XLA execution during model initialization and weight - # loading. - torch_xla.sync(wait=False) - xm.wait_device_ops() - if not hasattr(self, "model"): - self.model = model - self.sampler = TPUSampler() - - def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, ( - "Cannot reload weights before model is loaded." - ) - model_loader = get_model_loader(self.load_config) - logger.info("Reloading weights inplace...") - model_loader.load_weights(self.model, model_config=self.model_config) - - @torch.no_grad() - def _dummy_run(self, num_tokens: int, num_reqs: int, num_blocks: int) -> None: - if self.supports_mm_inputs: - input_ids = None - inputs_embeds = torch.zeros( - (num_tokens, self.inputs_embeds_size), - dtype=self.dtype, - device=self.device, - ) - else: - input_ids = torch.zeros((num_tokens), dtype=torch.int32).to(self.device) - inputs_embeds = None - actual_num_reqs = min(num_tokens, num_reqs) - position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(self.device) - padded_num_slices = _get_padded_num_kv_cache_update_slices( - num_tokens, self.max_num_reqs, self.block_size - ) - num_kv_update_slices = torch.tensor([padded_num_slices], dtype=torch.int32).to( - self.device - ) - slot_mapping = torch.zeros((3, padded_num_slices), dtype=torch.int32).to( - self.device - ) - block_tables = torch.zeros((num_reqs, num_blocks), dtype=torch.int32).to( - self.device - ) - query_lens = [1] * num_reqs - query_start_loc = torch.cumsum( - torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32 - ).to(self.device) - context_lens = torch.ones((num_reqs,), dtype=torch.int32).to(self.device) - num_seqs = torch.tensor([actual_num_reqs], dtype=torch.int32).to(self.device) - attn_metadata = PallasMetadata( - slot_mapping=slot_mapping, - block_tables=block_tables, - context_lens=context_lens, - query_start_loc=query_start_loc, - num_seqs=num_seqs, - num_kv_update_slices=num_kv_update_slices, - num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block, - ) - - if self.supports_mm_inputs: - torch._dynamo.mark_dynamic(inputs_embeds, 0) - else: - torch._dynamo.mark_dynamic(input_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - torch._dynamo.mark_dynamic(attn_metadata.block_tables, (0, 1)) - torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) - torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0) - - layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys() - per_layer_attn_metadata = { - layer_name: attn_metadata for layer_name in layer_names - } - - with ( - self.maybe_select_dummy_loras( - self.lora_config, np.array([num_tokens], dtype=np.int32) - ), - set_forward_context(per_layer_attn_metadata, self.vllm_config, 0), - ): - out = self.model( - input_ids=input_ids, positions=position_ids, inputs_embeds=inputs_embeds - ) - self._hidden_states_dtype = out.dtype - - def _set_active_loras( - self, prompt_lora_mapping, token_lora_mapping, lora_requests - ) -> None: - torch_xla.sync(wait=False) # Captures input updates - super()._set_active_loras( - prompt_lora_mapping, token_lora_mapping, lora_requests - ) - torch_xla.sync(wait=False) # Captures metadata updates - - def _precompile_mm_encoder(self) -> None: - if not self.supports_mm_inputs: - return - - # Pre-compile MM encoder for all supported data modalities. - hf_config = self.vllm_config.model_config.hf_config - - mm_budget = self.mm_budget - assert mm_budget is not None - - max_items_per_seq_by_modality = mm_budget.max_items_per_batch_by_modality # noqa: E501 - - for mode, max_items_per_seq in max_items_per_seq_by_modality.items(): - logger.info( - "Compiling Multimodal %s Encoder with different input shapes.", mode - ) - start = time.perf_counter() - # No padding for MM encoder just yet. - for num_items in range(1, max_items_per_seq + 1): - logger.info(" -- mode: %s items: %d", mode, num_items) - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - mode, - num_items, - ) - # Run multimodal encoder. - torch_xla.sync(wait=False) - mm_embeds = self.model.embed_multimodal(**batched_dummy_mm_inputs) - torch_xla.sync(wait=False) - num_patches = mm_embeds[0].shape[0] - items_size = num_patches * num_items - - # NOTE (NickLucche) pre-compile `embed_input_ids` when mm - # embeddings are present. We assume `--disable-mm-chunked`, - # hence only whole items can be scheduled. This implies we just - # need to compile when `num_items` fit the (padded) `input_ids` - for num_tokens in self.num_tokens_paddings: - if num_tokens >= items_size: - # XLA Workaround: if torch.zeros(..device) is used, XLA - # compiles a scalar+expansion op, which won't match - # the graph generated at runtime. CPU->TPU must be used - placeholders_ids = torch.zeros( - num_tokens, dtype=torch.int32, device="cpu" - ) - # Align placeholders and actual num mm_embeddings. - placeholders_ids[:items_size] = hf_config.image_token_index - - placeholders_ids = placeholders_ids.to(self.device) - - mm_mask = torch.tensor([False] * num_tokens) - mm_mask[:items_size] = True - mm_mask = mm_mask.to(self.device) - # Assign outputs or the graph will be cut short. - a, b = self._get_model_inputs( - placeholders_ids, - mm_embed_inputs=([mm_embeds], mm_mask), - ) - assert a is None - torch_xla.sync(wait=False) - - # Pre-compile `embed_input_ids` when mm_embeddings are not - # present. Chunk is only made of text, no mm_placeholders. - for num_tokens in self.num_tokens_paddings: - placeholders_ids = torch.zeros( - num_tokens, dtype=torch.int32, device="cpu" - ) - placeholders_ids = placeholders_ids.to(self.device) - a, b = self._get_model_inputs( - placeholders_ids, - mm_embed_inputs=None, - ) - assert a is None - torch_xla.sync(wait=False) - - xm.wait_device_ops() - end = time.perf_counter() - logger.info( - "Multimodal %s Encoder compilation finished in in %.2f [secs].", - mode, - end - start, - ) - - def _precompile_backbone(self) -> None: - logger.info("Compiling the model with different input shapes.") - start = time.perf_counter() - for num_tokens in self.num_tokens_paddings: - logger.info(" -- num_tokens: %d", num_tokens) - self._dummy_run( - num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req - ) - if self.most_model_len is not None: - self._dummy_run( - num_tokens, - self.num_reqs_most_model_len, - self.num_blocks_per_most_len_req, - ) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("model backbone") - - def _precompile_select_hidden_states(self) -> None: - # Compile hidden state selection function for bucketed - # n_tokens x max_num_reqs. Graph is really small so this is fine. - logger.info("Compiling select_hidden_states with different input shapes.") - start = time.perf_counter() - hsize = self.model_config.get_hidden_size() - for num_tokens in self.num_tokens_paddings: - dummy_hidden = torch.zeros( - (num_tokens, hsize), device=self.device, dtype=self._hidden_states_dtype - ) - torch._dynamo.mark_dynamic(dummy_hidden, 0) - for num_reqs in self.num_reqs_paddings: - indices = torch.zeros(num_reqs, dtype=torch.int32, device=self.device) - torch._dynamo.mark_dynamic(indices, 0) - self.select_hidden_states(dummy_hidden, indices) - logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, num_reqs) - # Requests can't be more than tokens. But do compile for the - # next bigger value in case num_tokens uses bucketed padding. - if num_reqs >= min(num_tokens, self.max_num_reqs): - break - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("select_hidden_states") - - def _precompile_compute_logits(self) -> None: - logger.info("Compiling compute_logits with different input shapes.") - start = time.perf_counter() - hsize = self.model_config.get_hidden_size() - for num_reqs in self.num_reqs_paddings: - dummy_hidden = torch.zeros( - (num_reqs, hsize), device=self.device, dtype=self._hidden_states_dtype - ) - torch._dynamo.mark_dynamic(dummy_hidden, 0) - self.compute_logits(dummy_hidden) - logger.info(" -- num_seqs: %d", num_reqs) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("compute_logits") - - def _precompile_structured_decoding(self) -> None: - logger.info("Compiling structured_decoding with different input shapes.") - start = time.perf_counter() - for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros( - (num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype, - ) - dummy_require_struct_decoding = self.require_structured_out_cpu[ - :num_reqs - ].to(self.device) - dummy_grammar_bitmask = self.grammar_bitmask_cpu[:num_reqs].to(self.device) - # The first dimension of the above 3 dummy tensors cannot be - # mark_dynamic because some operations in structured_decode require - # them to be static. - arange = self.structured_decode_arange.to(self.device) - self.structured_decode( - dummy_require_struct_decoding, - dummy_grammar_bitmask, - dummy_logits, - arange, - ) - logger.info(" -- num_seqs: %d", num_reqs) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("structured_decoding") - - def _precompile_sample_from_logits(self) -> None: - logger.info("Compiling sample_from_logits with different input shapes.") - start = time.perf_counter() - for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros( - (num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype, - ) - # The first dimension of dummy_logits cannot be mark_dynamic - # because some operations in the sampler require it to be static. - for all_greedy in [False, True]: - generate_params_if_all_greedy = not all_greedy - sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch( - self.input_batch, - num_reqs, - self.device, - generate_params_if_all_greedy, - ) - sampling_metadata.all_greedy = all_greedy - with self.maybe_select_dummy_loras( - self.lora_config, np.array([num_reqs], dtype=np.int32) - ): - self.sample_from_logits_func(dummy_logits, sampling_metadata) - logger.info(" -- num_seqs: %d", num_reqs) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("sample_from_logits") - - def _precompile_gather_logprobs(self) -> None: - logger.info("Compiling gather_logprobs with different input shapes.") - start = time.perf_counter() - for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros( - (num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype, - ) - dummy_tokens = torch.zeros((num_reqs, 1), dtype=torch.int64).to(self.device) - with self.maybe_select_dummy_loras( - self.lora_config, np.array([num_reqs], dtype=np.int32) - ): - self.gather_logprobs(dummy_logits, dummy_tokens) - logger.info(" -- num_seqs: %d", num_reqs) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("gather_logprobs") - - def capture_model(self) -> None: - """ - Precompile all the subgraphs with possible input shapes. - """ - with self.maybe_setup_dummy_loras(self.lora_config): - self._precompile_mm_encoder() - self._precompile_backbone() - self._precompile_select_hidden_states() - self._precompile_compute_logits() - self._precompile_structured_decoding() - self._precompile_sample_from_logits() - self._precompile_gather_logprobs() - - def profile_run( - self, - num_tokens: int, - ) -> None: - # Profile with multimodal encoder & encoder cache. - if self.supports_mm_inputs: - mm_config = self.model_config.multimodal_config - if mm_config is not None and mm_config.skip_mm_profiling: - logger.info( - "Skipping memory profiling for multimodal encoder and " - "encoder cache." - ) - else: - mm_budget = self.mm_budget - assert mm_budget is not None - - # TODO: handle encoder-decoder models once we support them. - if (encoder_budget := mm_budget.get_encoder_budget()) > 0: - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ - dummy_modality - ] - - logger.info( - "Encoder cache will be initialized with a budget of " - "%s tokens, and profiled with %s %s items of the " - "maximum feature size.", - encoder_budget, - max_mm_items_per_batch, - dummy_modality, - ) - - # Create dummy batch of multimodal inputs. - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_modality, - max_mm_items_per_batch, - ) - - # Run multimodal encoder. - # Isolate encoder graph from post-processing to minimize - # impact of recompilation until it's fixed. - start = time.perf_counter() - torch_xla.sync(wait=False) - dummy_encoder_outputs = self.model.embed_multimodal( - **batched_dummy_mm_inputs - ) - torch_xla.sync(wait=False) - xm.wait_device_ops() - end = time.perf_counter() - logger.info( - "Multimodal Encoder profiling finished in %.2f [secs].", - end - start, - ) - - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=max_mm_items_per_batch, - ) - - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) - - # Trigger compilation for general shape. - self._dummy_run( - num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req - ) - if self.most_model_len is not None: - self._dummy_run( - num_tokens, - self.num_reqs_most_model_len, - self.num_blocks_per_most_len_req, - ) - - torch_xla.sync(wait=False) - xm.wait_device_ops() - self.encoder_cache.clear() - gc.collect() - - def maybe_setup_cross_layer_kv_sharing( - self, - kv_caches: dict[str, torch.Tensor], - kv_cache_config: KVCacheConfig, - ) -> None: - """ - Add layers that re-use KV cache to KV cache group of its target layer. - Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` - """ - if not self.shared_kv_cache_layers: - # No cross-layer KV sharing, return - return - - add_kv_sharing_layers_to_kv_cache_groups( - self.shared_kv_cache_layers, - kv_cache_config.kv_cache_groups, - ) - - for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): - logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) - kv_caches[layer_name] = kv_caches[target_layer_name] - - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: - """ - Initialize KV cache based on `kv_cache_config`. - Args: - kv_cache_config: Configuration for the KV cache, including the KV - cache size of each layer - """ - if len(kv_cache_config.kv_cache_groups) > 1: - raise NotImplementedError( - "Hybrid models with more than one KV cache type are not supported yet." - ) - - if ( - kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size - != self.block_size - ): - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), - block_sizes=[ - kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size - ], - kernel_block_sizes=[ - kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size - ], - ) - # Verify dtype compatibility between block_table_cpu and input_batch - assert ( - self.block_table_cpu.dtype - == self.input_batch.block_table[0].get_cpu_tensor().dtype - ) - - kv_cache_sizes = {} - for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - assert len(kv_cache_tensor.shared_by) == 1, ( - "KV cache tensor shared by multiple layers is not supported in TPU." - ) - kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size - - kv_caches: dict[str, torch.Tensor] = {} - for kv_cache_group in kv_cache_config.kv_cache_groups: - kv_cache_spec = kv_cache_group.kv_cache_spec - for layer_name in kv_cache_group.layer_names: - tensor_size = kv_cache_sizes[layer_name] - assert tensor_size % kv_cache_spec.page_size_bytes == 0 - num_blocks = tensor_size // kv_cache_spec.page_size_bytes # noqa - if isinstance(kv_cache_spec, AttentionSpec): - if self.use_spmd: - num_kv_heads = kv_cache_spec.num_kv_heads - assert self.original_parallel_config is not None - tp_size = self.original_parallel_config.tensor_parallel_size - # TODO: Handle kv cache duplication under SPMD mode. - assert num_kv_heads % tp_size == 0, ( - f"num_kv_heads {num_kv_heads} must be divisible by " - f"tp_size {tp_size} under SPMD mode" - ) - kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - ) - dtype = kv_cache_spec.dtype - - tpu_kv_cache = torch.zeros(kv_cache_shape, dtype=dtype).to( - self.device - ) - - kv_caches[layer_name] = tpu_kv_cache - else: - raise NotImplementedError - - # Set up cross-layer KV cache sharing if needed - self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config) - - bind_kv_cache( - kv_caches, - self.vllm_config.compilation_config.static_forward_context, - self.kv_caches, - ) - - if self.use_spmd: - # Shard KV Cache - for cache in self.kv_caches: - xs.mark_sharding(cache, self.mesh, (None, "x", None, None)) - - if has_kv_transfer_group(): - get_kv_transfer_group().register_kv_caches(kv_caches) - get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks) - - def reset_dynamo_cache(self): - # NOTE: We check `is_multimodal_model` instead of `supports_mm_inputs` - # since the compiled model object of the language backbone of a - # multimodal model needs to be extracted via `get_language_model`. - if self.model_config.is_multimodal_model: - compiled_model = self.model.get_language_model().model - else: - compiled_model = self.model.model - if isinstance(compiled_model, TorchCompileWithNoGuardsWrapper): - logger.info("Clear dynamo cache and cached dynamo bytecode.") - torch._dynamo.eval_frame.remove_from_cache( - compiled_model.original_code_object() - ) - # Reset the wrapper to re-initialize. - compiled_model.compiled = False - TorchCompileWithNoGuardsWrapper.__init__(compiled_model) - - @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def select_hidden_states(self, hidden_states, indices_do_sample): - return hidden_states[indices_do_sample] - - @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def compute_logits(self, sample_hidden_states: torch.Tensor) -> torch.Tensor: - return self.model.compute_logits(sample_hidden_states) - - # TODO: Under SPMD mode, sample_from_logits has correctness issue. - # Re-enable the torch.compile once the issue is fixed in torchxla. - # @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def sample_from_logits( - self, logits: torch.Tensor, sampling_metadata: TPUSupportedSamplingMetadata - ) -> torch.Tensor: - """ - Sample with xla-friendly function. This function is to be traced - separately from `forward` for lighter compilation overhead. - """ - if sampling_metadata.all_greedy: - out_tokens = torch.argmax(logits, dim=-1, keepdim=True) - else: - out_tokens = self.sampler(logits, sampling_metadata).sampled_token_ids - return out_tokens - - @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def gather_logprobs( - self, logits: torch.Tensor, sampled_tokens: torch.Tensor - ) -> LogprobsTensors: - """ - Gather the top_logprobs with corresponding tokens. Use a fixed number - of logprobs as an alternative to having multiple pre-compiled graphs. - Select the number of logprobs actually demanded by each request on CPU. - """ - logprobs = self.sampler.compute_logprobs(logits) - return self.sampler.gather_logprobs( - logprobs, - self.model_config.max_logprobs, - token_ids=sampled_tokens.squeeze(-1), - ) - - @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def structured_decode( - self, - require_struct_decoding: torch.Tensor, - grammar_bitmask: torch.Tensor, - logits: torch.Tensor, - arange: torch.Tensor, - ) -> torch.Tensor: - return torch.where( - require_struct_decoding, - self.apply_grammar_bitmask(logits, grammar_bitmask, arange), - logits, - ) - - def apply_grammar_bitmask( - self, logits: torch.Tensor, grammar_bitmask: torch.Tensor, arange: torch.Tensor - ): - assert logits.shape[0] == grammar_bitmask.shape[0] - logits_cloned = logits.clone() - for i in range(logits.shape[0]): - unpacked_bitmask = ( - torch.bitwise_right_shift(grammar_bitmask[i][:, None], arange[None, :]) - & 1 - ) == 0 - unpacked_bitmask = unpacked_bitmask.reshape(-1)[: self.vocab_size] - logits_cloned[i] = logits_cloned[i].masked_fill( - unpacked_bitmask, -float("inf") - ) - return logits_cloned - - def embed_multimodal(self, *args, **kwargs): - return self.model.embed_multimodal(*args, **kwargs) - - def embed_input_ids(self, *args, **kwargs): - return self.model.embed_input_ids(*args, **kwargs) - - def prepare_structured_decoding_input( - self, logits: torch.Tensor, grammar_output: "GrammarOutput" - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - grammar_bitmask = grammar_output.grammar_bitmask - num_reqs, _ = logits.shape - - # Reset pre-allocated tensors - self.grammar_bitmask_cpu.zero_() - self.require_structured_out_cpu.zero_() - - cumulative_mask_idx = 0 - for req_id in grammar_output.structured_output_request_ids: - if req_id not in self.input_batch.req_id_to_index: - continue - batch_index = self.input_batch.req_id_to_index[req_id] - self.grammar_bitmask_cpu[batch_index] = torch.from_numpy( - grammar_bitmask[cumulative_mask_idx] - ) - # It's not guaranteed that all requests in this batch require - # structured output, so create a bool tensor to represent - # the requests that need structured output. - self.require_structured_out_cpu[batch_index] = True - cumulative_mask_idx += 1 - - return ( - self.require_structured_out_cpu[:num_reqs].to(logits.device), - self.grammar_bitmask_cpu[:num_reqs].to(logits.device), - self.structured_decode_arange.to(logits.device), - ) - - def _get_mm_dummy_batch( - self, - modality: str, - max_items_per_batch: int, - ) -> BatchedTensorInputs: - """Dummy data for profiling and precompiling multimodal models.""" - assert self.mm_budget is not None - - dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( - model_config=self.model_config, - seq_len=self.max_model_len, - mm_counts={modality: 1}, - cache=self.mm_budget.cache, - ) - dummy_mm_data = dummy_decoder_data.multi_modal_data - - # Result in the maximum GPU consumption of the model - dummy_mm_item = dummy_mm_data[modality][0] - dummy_mm_items = [dummy_mm_item] * max_items_per_batch - - return next( - grouped_mm_kwargs - for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality( - dummy_mm_items, - device=self.device, - pin_memory=self.pin_memory, - ) - ) - - -def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: - logger.info("Preparing request paddings:") - # assert min_req_size is power of 2 - assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0 - paddings: list = [] - num = max(MIN_NUM_SEQS, min_req_size) - while num <= max_req_size and (len(paddings) == 0 or paddings[-1] != num): - paddings.append(num) - logger.info(" %d", num) - num = _get_padded_num_reqs_with_upper_limit(num + 1, max_req_size) - return paddings - - -def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int: - res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length() - return min(res, upper_limit) - - -def _get_token_paddings( - min_token_size: int, max_token_size: int, padding_gap: int -) -> list[int]: - """Generate a list of padding size, starting from min_token_size, - ending with a number that can cover max_token_size - - If padding_gap == 0 then: - increase 2X each time (exponential) - else: - first increase the size to twice, - then increase the padding size by padding_gap. - """ - # assert min_token_size is power of 2 - assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0 - paddings = [] - num = min_token_size - - if padding_gap == 0: - logger.info("Using exponential token paddings:") - while True: - logger.info(" %d", num) - paddings.append(num) - if num >= max_token_size: - break - num *= 2 - else: - logger.info("Using incremental token paddings:") - while num <= padding_gap: - logger.info(" %d", num) - paddings.append(num) - num *= 2 - num //= 2 - while num < max_token_size: - num += padding_gap - logger.info(" %d", num) - paddings.append(num) - - return paddings - - -def _get_padded_token_len(paddings: list[int], x: int) -> int: - """Return the first element in paddings list greater or equal to x.""" - index = bisect.bisect_left(paddings, x) - assert index < len(paddings) - return paddings[index] - - -def _get_padded_num_kv_cache_update_slices( - num_tokens: int, max_num_reqs: int, page_size: int -) -> int: - """Calculates the padded number of KV cache update slices to avoid - recompilation.""" - # NOTE(chengjiyao): let's say R_i is the token num for i-th request, - # so it occupies most 2 + R_i // page_size pages. The total maximum - # possible number of pages needed is sum(2 + R_i // page_size), which - # is <= 2 * max_num_reqs + sum(R_i) // page_size - # = 2 * max_num_reqs + num_tokens // page_size - padded_num_slices = 2 * max_num_reqs + num_tokens // page_size - padded_num_slices = min(padded_num_slices, num_tokens) - return padded_num_slices - - -def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int: - """Find the optimum number of slices to copy per Pallas program instance. - - Increasing the number of slices copied in one instance of the kernel program - will increase HBM bandwidth utilization via more in-flight DMAs. - - However, it will also use more VMEM, and experimentally, we observed - performance regression at 128 slices on v6e, likely due to running - out of scalar registers. Thus this function will limit the number of - slices to 64. - """ - # The default vmem_limit_bytes of a pallas kernel is 32MB. Here we - # calculate num_slices_per_block based on 16MB in case any register spills. - vmem_limit = 16 * 1024 * 1024 - num_slices_per_block = vmem_limit // page_size_bytes - assert num_slices_per_block > 0, "Number of slices should be positive" - num_slices_per_block = prev_power_of_2(num_slices_per_block) - if num_slices_per_block > 64: - num_slices_per_block = 64 - return num_slices_per_block - - -def replace_set_lora(model): - def _tpu_set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None, - ): - # TODO: The integer index leads to a recompilation, but converting it - # to a tensor doesn't seem to work anymore. This might be fixed with a - # later release of torch_xla. - self._original_set_lora(index, lora_a, lora_b, embeddings_tensor) - torch_xla.sync(wait=False) - - def _tpu_reset_lora(self, index: int): - self._original_reset_lora(index) - torch_xla.sync(wait=False) - - for _, module in model.named_modules(): - if isinstance(module, BaseLayerWithLoRA): - module._original_set_lora = module.set_lora - module._original_reset_lora = module.reset_lora - module.set_lora = _tpu_set_lora.__get__( # type: ignore[method-assign] - module, module.__class__ - ) - module.reset_lora = _tpu_reset_lora.__get__( # type: ignore[method-assign] - module, module.__class__ - )