diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py deleted file mode 100644 index e5a0cf7420497..0000000000000 --- a/vllm/v1/attention/backends/pallas.py +++ /dev/null @@ -1,316 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass - -import torch - -from vllm.attention.backends.abstract import ( - AttentionBackend, - AttentionImpl, - AttentionLayer, - AttentionType, -) -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.utils.math_utils import cdiv, next_power_of_2 - -logger = init_logger(__name__) - -# TPU requires the head size to be a multiple of 128. -TPU_HEAD_SIZE_ALIGNMENT = 128 - -# Note: TPU can fp8 as storage dtype but doesn't support converting from uint8 -# from to fp32 directly. That's why it has a dtype mapping different from GPU -TPU_STR_DTYPE_TO_TORCH_DTYPE = { - "half": torch.half, - "bfloat16": torch.bfloat16, - "float": torch.float, - "fp8": torch.float8_e4m3fn, - "fp8_e4m3": torch.float8_e4m3fn, - "fp8_e5m2": torch.float8_e5m2, - "int8": torch.int8, - "uint8": torch.uint8, -} - -import tpu_inference # noqa: F401 - -# Note(weiyulin): some static functions are still used by tpu-inference -class PallasAttentionBackend(AttentionBackend): - @staticmethod - def get_name() -> str: - return "PALLAS" - - @staticmethod - def get_impl_cls() -> type["PallasAttentionBackendImpl"]: - return PallasAttentionBackendImpl - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - padded_head_size = ( - cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT - ) - return (num_blocks, block_size, num_kv_heads * 2, padded_head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - raise RuntimeError("swap_blocks is not used for the TPU backend.") - - # In recent TPU generations, up to v6e, the SMEM size is 1MB. The - # block_tables within the PallasMetadata constitute almost the entire SMEM - # requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here - # we simply make sure that the size is smaller than half of SMEM capacity. - @staticmethod - def get_min_page_size(vllm_config: VllmConfig) -> int: - max_num_page_per_req = ( - 1024 * 1024 // 2 // vllm_config.scheduler_config.max_num_seqs // 4 - ) - min_page_size = cdiv( - vllm_config.model_config.max_model_len, max_num_page_per_req - ) - min_page_size = 1 << (min_page_size - 1).bit_length() - return min_page_size - - @staticmethod - def get_max_num_seqs(model_len: int, page_size: int) -> int: - num_page_per_req = cdiv(model_len, page_size) - return 1024 * 1024 // 2 // num_page_per_req // 4 - - # TPU has limited SREGs (scalar registers), if page_size is too small, we - # can spill SREGs easily which leads to bad performance. The strategy we - # apply here is trying to split max-model-len to 16 pages which make the - # spill less likely. Meanwhile we make sure the page size is in [16, 256]. - @staticmethod - def get_page_size(vllm_config: VllmConfig) -> int: - # TODO: This is a temporary fix for vmem OOM. - # For long model length, we use 16 page-size to avoid too much - # VMEM spill. A more robust solution should be implemented to - # handle VREG spills. - if vllm_config.model_config.max_model_len > 8192: - return 16 - page_size = next_power_of_2(vllm_config.model_config.max_model_len) // 16 - if page_size <= 16: - return 16 - if page_size >= 256: - return 256 - return page_size - - -@dataclass -class PallasMetadata: - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # Used in the PallasAttentionBackendImpl - slot_mapping: torch.Tensor - block_tables: torch.Tensor - context_lens: torch.Tensor - query_start_loc: torch.Tensor - num_seqs: torch.Tensor - num_kv_update_slices: torch.Tensor - num_slices_per_kv_cache_update_block: int - - -class PallasAttentionBackendImpl(AttentionImpl): - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: list[float] | None, - sliding_window: int | None, - kv_cache_dtype: str, - logits_soft_cap: float | None = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: int | None = None, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - self.sliding_window = sliding_window - self.logits_soft_cap = logits_soft_cap - self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if alibi_slopes is not None: - raise NotImplementedError("Alibi slopes is not supported.") - - if attn_type != AttentionType.DECODER: - raise NotImplementedError( - "Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "PallasAttentionBackendImpl" - ) - - self.kv_cache_quantized_dtype = None - if kv_cache_dtype != "auto": - self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get( - kv_cache_dtype.lower().strip() - ) - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: PallasMetadata, - output: torch.Tensor | None = None, - output_scale: torch.Tensor | None = None, - output_block_scale: torch.Tensor | None = None, - ) -> torch.Tensor: - """Forward pass with Pallas attention. - - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache: shape = - [num_blocks, block_size, num_kv_heads * 2, head_size] - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for PallasAttentionBackendImpl" - ) - - # For determine_available_memory case. - if kv_cache.numel() == 0: - if output is None: - output = torch.ones_like(query) - return output - - num_tokens, hidden_size = query.shape - query = query.view(num_tokens, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: - padded_head_size = ( - cdiv(self.head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT - ) - query = torch.nn.functional.pad( - query, (0, padded_head_size - self.head_size), value=0.0 - ) - key = torch.nn.functional.pad( - key, (0, padded_head_size - self.head_size), value=0.0 - ) - value = torch.nn.functional.pad( - value, (0, padded_head_size - self.head_size), value=0.0 - ) - - if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0: - # Write input keys and values to the KV cache. - # Skip this if sharing KV cache with an earlier attention layer. - slot_mapping = attn_metadata.slot_mapping - write_to_kv_cache( - key, - value, - kv_cache, - slot_mapping, - attn_metadata.num_slices_per_kv_cache_update_block, - attn_metadata.num_kv_update_slices, - self.kv_cache_quantized_dtype, - layer._k_scale_float, - layer._v_scale_float, - ) - - if self.kv_cache_quantized_dtype is not None and ( - layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0 - ): - raise ValueError("k_scale_float and v_scale_float must be non-zero") - output = torch.ops.xla.ragged_paged_attention( - query, - kv_cache, - attn_metadata.context_lens, - attn_metadata.block_tables, - attn_metadata.query_start_loc, - attn_metadata.num_seqs, - # By default, the system utilizes optimized block size and - # vmem_limit_bytes parameters from the kernel repository. However, - # these can be manually adjusted for debugging if necessary. - num_kv_pages_per_block=None, - num_queries_per_block=None, - vmem_limit_bytes=None, - use_kernel=True, - sm_scale=self.scale, - sliding_window=self.sliding_window, - soft_cap=self.logits_soft_cap, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - ) - - if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: - output = output[:, :, : self.head_size] - - return output.reshape(num_tokens, hidden_size) - - -def write_to_kv_cache( - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, - num_slices_per_kv_cache_update_block: int, - num_kv_update_slices: torch.Tensor, - kv_cache_quantized_dtype: torch.dtype | None = None, - k_scale: float = 1.0, - v_scale: float = 1.0, -) -> None: - """Write the key and values to the KV cache. - - Args: - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size] - num_slices_per_kv_cache_update_block: int - """ - _, page_size, num_combined_kv_heads, head_size = kv_cache.shape - head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT - - if kv_cache_quantized_dtype is not None: - dtype_info = torch.finfo(kv_cache_quantized_dtype) - key = key.to(torch.float32) / k_scale - # NOTE: clamp is added here to avoid out of range of quantized dtype - key = torch.clamp(key, dtype_info.min, dtype_info.max) - key = key.to(kv_cache_quantized_dtype) - value = value.to(torch.float32) / v_scale - value = torch.clamp(value, dtype_info.min, dtype_info.max) - value = value.to(kv_cache_quantized_dtype) - - kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size) - - torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True) - - kv_cache = kv_cache.flatten(0, 1) - new_kv_cache = torch.ops.xla.kv_cache_update_op( - kv, - slot_mapping, - kv_cache, - num_kv_update_slices, - page_size, - num_slices_per_kv_cache_update_block, - ) - # NOTE: the in-place copy will be optimized away by XLA compiler. - kv_cache.copy_(new_kv_cache)