[Hardware][TPU] Enable ragged paged attention kernel and resolve recompilation issue (#14310)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
parent 04222984f8
commit 0578e5a462
@@ -17,9 +17,9 @@ ray[default]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
@@ -12,7 +12,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.attention.backends.utils import CommonAttentionState

 # These are the 2 tunable parameters of the paged attention Pallas kernel.
-NUM_QUERIES_PER_BLOCK = 16
+NUM_QUERIES_PER_BLOCK = 32
 NUM_KV_PAGES_PER_BLOCK = 128


@@ -41,7 +41,7 @@ class PallasAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
     ) -> tuple[int, ...]:
-        return (num_kv_heads, num_blocks, block_size, head_size)
+        return (num_blocks, block_size, num_kv_heads, head_size)

     @staticmethod
     def swap_blocks(
@@ -115,6 +115,17 @@ class PallasAttentionBackendImpl(AttentionImpl):
                                   "are not implemented for "
                                   "PallasAttentionBackendImpl")

+        tpu_version = torch_xla.tpu.version()
+        if tpu_version < 4:
+            raise NotImplementedError("TPU version must be 4 or higher.")
+        # NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
+        # TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
+        # NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
+        if tpu_version == 4:
+            self.vmem_limit_bytes = 16 * 1024 * 1024
+        else:
+            self.vmem_limit_bytes = 64 * 1024 * 1024
+
     def forward(
         self,
         layer: AttentionLayer,
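Review note: the new constructor logic gates the kernel on TPU generation and picks a VMEM budget for it. Condensed into a standalone sketch (the helper name below is illustrative and not part of the diff; the 16MB figure comes from the NOTE above, the 64MB value from the else branch):

    def _select_vmem_limit_bytes(tpu_version: int) -> int:
        # Mirrors the diff: ragged paged attention requires TPU v4 or newer.
        if tpu_version < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")
        # TPU v4 has roughly 16MB of VMEM; the diff gives newer generations 64MB.
        return 16 * 1024 * 1024 if tpu_version == 4 else 64 * 1024 * 1024

    assert _select_vmem_limit_bytes(4) == 16 * 1024 * 1024
    assert _select_vmem_limit_bytes(5) == 64 * 1024 * 1024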
@@ -131,8 +142,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
-           kv_cache = ([num_kv_heads, num_blocks, block_size, head_size],
-                       [num_kv_heads, num_blocks, block_size, head_size])
+           kv_cache = ([num_blocks, block_size, num_kv_heads, head_size],
+                       [num_blocks, block_size, num_kv_heads, head_size])
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
@@ -154,10 +165,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
             slot_mapping = attn_metadata.slot_mapping
             write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

-        query = query * self.scale
-        # use_kernel switches between using kernel or reference implementation
-        # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
-        use_kernel = False
         output = torch.ops.xla.ragged_paged_attention(
             query,
             key_cache,
@@ -168,8 +175,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
             attn_metadata.num_seqs,
             num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
             num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            use_kernel=use_kernel,
-        )
+            vmem_limit_bytes=self.vmem_limit_bytes,
+            use_kernel=True,
+            sm_scale=self.scale)

         return output.reshape(num_tokens, hidden_size)

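Review note: two things change in the call above. The reference (use_kernel=False) path and the explicit `query = query * self.scale` are dropped, and the kernel instead receives sm_scale=self.scale along with the vmem_limit_bytes chosen in the constructor. Scaling the logits inside the kernel is mathematically the same as pre-scaling the query, which a quick CPU check with plain torch (illustrative shapes only) confirms:

    import torch

    torch.manual_seed(0)
    head_size = 8
    scale = head_size**-0.5
    q = torch.randn(4, head_size)   # a few query tokens
    k = torch.randn(10, head_size)  # cached keys

    # Old style: pre-scale the query before the matmul.
    attn_old = torch.softmax((q * scale) @ k.T, dim=-1)
    # New style: let the kernel scale the logits via sm_scale.
    attn_new = torch.softmax((q @ k.T) * scale, dim=-1)
    assert torch.allclose(attn_old, attn_new)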
@@ -186,16 +194,15 @@ def write_to_kv_cache(
     Args:
         key: shape = [num_tokens, num_kv_heads, head_size]
         value: shape = [num_tokens, num_kv_heads, head_size]
-        k_cache = [num_kv_heads, num_blocks, block_size, head_size]
-        v_cache = [num_kv_heads, num_blocks, block_size, head_size]
+        k_cache = [num_blocks, block_size, num_kv_heads, head_size]
+        v_cache = [num_blocks, block_size, num_kv_heads, head_size]

     """
     torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
     torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

     key = key.flatten(0, 1)
     value = value.flatten(0, 1)
-    key_cache = key_cache.flatten(0, 2)
-    value_cache = value_cache.flatten(0, 2)
+    key_cache = key_cache.flatten(0, 1)
+    value_cache = value_cache.flatten(0, 1)
     slot_mapping = slot_mapping.flatten()
     key_cache.index_copy_(0, slot_mapping, key)
     value_cache.index_copy_(0, slot_mapping, value)

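Review note: with the cache reordered to [num_blocks, block_size, num_kv_heads, head_size], flattening only the first two dimensions yields a [num_blocks * block_size, num_kv_heads, head_size] view, so each slot index addresses a whole token and index_copy_ scatters along dim 0 with no per-head offset arithmetic. A simplified CPU illustration (plain torch; it skips the buffer-donor call and the key/value reshapes, and the sizes are made up):

    import torch

    num_blocks, block_size, num_kv_heads, head_size = 4, 8, 2, 16
    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

    num_tokens = 3
    key = torch.randn(num_tokens, num_kv_heads, head_size)
    # slot = block_idx * block_size + offset_in_block, one entry per token.
    slot_mapping = torch.tensor([0, 9, 17])

    flat = key_cache.view(num_blocks * block_size, num_kv_heads, head_size)
    flat.index_copy_(0, slot_mapping, key)  # writes through to key_cache

    assert torch.equal(key_cache[1, 1], key[1])  # slot 9 -> block 1, offset 1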
@@ -14,7 +14,7 @@ import torch_xla.runtime as xr
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
-from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
@@ -416,8 +416,8 @@ class TPUModelRunner:
                                                num_scheduled_tokens_per_req)

         # Do the padding and copy the tensors to the TPU.
-        padded_total_num_scheduled_tokens = _get_padded_number(
-            total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK)
+        padded_total_num_scheduled_tokens = _get_padded_token_len(
+            total_num_scheduled_tokens)
         self.input_ids = self.input_ids_cpu[:
                                             padded_total_num_scheduled_tokens].to(
                                                 self.device)
@@ -428,23 +428,22 @@ class TPUModelRunner:
         slot_mapping = self.slot_mapping_cpu[:
                                              padded_total_num_scheduled_tokens].to(
                                                  self.device)
-        padded_block_table = self.block_table_cpu[:
-                                                  padded_total_num_scheduled_tokens]
-        padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = (
+        block_tables = self.block_table_cpu[:self.max_num_reqs]
+        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
             self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
-        padded_block_table = padded_block_table.to(self.device)
-        query_start_loc = self.query_start_loc_cpu[:
-                                                   padded_total_num_scheduled_tokens
-                                                   + 1].to(self.device)
-        seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(
+        block_tables = block_tables.to(self.device)
+        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
             self.device)
+        seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)

         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
-            block_tables=padded_block_table,
+            block_tables=block_tables,
             context_lens=seq_lens,
             query_start_loc=query_start_loc,
-            num_seqs=num_reqs,
+            num_seqs=torch.tensor([num_reqs],
+                                  dtype=torch.int32,
+                                  device=self.device),
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
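Review note: this is the recompilation fix. Request-level tensors (block_tables, query_start_loc, seq_lens) are now always sized by self.max_num_reqs, num_seqs becomes a fixed-shape device tensor instead of a Python int, and only the token dimension varies, in power-of-two buckets via _get_padded_token_len (defined at the end of this diff). A toy sketch of the resulting shape signatures (leading dimensions only; max_num_reqs=256 is an assumed config value, not from the diff):

    def metadata_shapes(total_tokens: int, max_num_reqs: int = 256) -> tuple:
        # Token-dependent tensors use power-of-two buckets with a floor of 16;
        # request-dependent tensors are always sized by max_num_reqs.
        padded = 16 if total_tokens <= 16 else 1 << (total_tokens - 1).bit_length()
        return (("input_ids", padded), ("slot_mapping", padded),
                ("block_tables", max_num_reqs),
                ("query_start_loc", max_num_reqs + 1), ("seq_lens", max_num_reqs))

    # Every batch composition from 1 to 2048 tokens hits one of just 8 signatures,
    # so XLA compiles a bounded number of graphs.
    assert len({metadata_shapes(t) for t in range(1, 2049)}) == 8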
@@ -693,29 +692,34 @@ class TPUModelRunner:
                                     dtype=torch.int32,
                                     device=self.device)
             inputs_embeds = None
+        actual_num_reqs = min(num_tokens, self.max_num_reqs)
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32,
                                    device=self.device)
         slot_mapping = torch.zeros(num_tokens,
                                    dtype=torch.int64,
                                    device=self.device)
-        block_tables = torch.zeros((num_tokens, self.block_table_cpu.shape[1]),
-                                   dtype=torch.int32,
-                                   device=self.device)
-        query_lens = [1] * num_tokens
+        block_tables = torch.zeros(
+            (self.max_num_reqs, self.block_table_cpu.shape[1]),
+            dtype=torch.int32,
+            device=self.device)
+        query_lens = [1] * self.max_num_reqs
         query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                     dtype=torch.int32),
                                        dim=0,
                                        dtype=torch.int32).to(self.device)
-        context_lens = torch.ones((num_tokens, ),
+        context_lens = torch.ones((self.max_num_reqs, ),
                                   dtype=torch.int32,
                                   device=self.device)
+        num_seqs = torch.tensor([actual_num_reqs],
+                                dtype=torch.int32,
+                                device=self.device)
         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
             block_tables=block_tables,
             context_lens=context_lens,
             query_start_loc=query_start_loc,
-            num_seqs=num_tokens,
+            num_seqs=num_seqs,
         )

         if self.is_multimodal_model:
@@ -724,9 +728,6 @@ class TPUModelRunner:
             torch._dynamo.mark_dynamic(input_ids, 0)
             torch._dynamo.mark_dynamic(position_ids, 0)
             torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
-            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
-            torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
-            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)

         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
@@ -817,28 +818,6 @@ class ModelWrapperV1(nn.Module):
             inputs_embeds: The input embeddings of shape [num_tokens,
                 hidden_size]. It is used for multimodal models.
         """
-        # Skip this in memory profiling at initialization.
-        if kv_caches[0][0].numel() > 0:
-            attn_metadata = get_forward_context().attn_metadata
-            # index_copy_(slot_mapping) only works when the inserted dimension
-            # is 0. However, the KV cache in the Pallas backend has the shape
-            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
-            # work, we need to flatten the first three dimensions and modify
-            # the slot_mapping accordingly.
-            # kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
-            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
-            slot_mapping = attn_metadata.slot_mapping
-            slot_mapping = slot_mapping.flatten()
-            head_indicies = torch.arange(0,
-                                         num_kv_heads,
-                                         device=slot_mapping.device,
-                                         dtype=slot_mapping.dtype)
-            head_indicies *= block_size * num_blocks
-            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
-                -1, num_kv_heads)
-            slot_mapping = slot_mapping + head_indicies.view(1, -1)
-            slot_mapping = slot_mapping.flatten()
-            attn_metadata.slot_mapping = slot_mapping

         assert self.model is not None
         hidden_states = self.model(
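Review note: this whole block existed only because the old [num_kv_heads, num_blocks, block_size, head_size] layout put heads ahead of the slot axis, so a flat index_copy_ needed an offset of num_blocks * block_size per head (hence the repeat_interleave and head_indicies arithmetic). With heads stored after the flattened slot axis, the raw slot index is already correct, as the write_to_kv_cache sketch above shows. A tiny recap of the old offset math it replaces (illustrative numbers):

    # Old layout: flat index = head * (num_blocks * block_size) + slot
    # New layout: flat index = slot (heads live in a trailing dimension)
    num_blocks, block_size, num_kv_heads, slot = 4, 8, 2, 9
    old_flat_indices = [h * num_blocks * block_size + slot
                        for h in range(num_kv_heads)]
    assert old_flat_indices == [9, 41]  # one scatter index per KV head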
@@ -866,3 +845,9 @@ class ModelWrapperV1(nn.Module):

 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
+
+
+def _get_padded_token_len(x: int) -> int:
+    if x <= 16:
+        return 16
+    return 1 << (x - 1).bit_length()
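Review note: for reference, the behavior of the new padding helper that replaces _get_padded_number(total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK) on the token dimension. It rounds up to the next power of two with a floor of 16, which keeps the number of distinct padded lengths, and thus compiled graphs, logarithmic in the maximum token count. A few sample values:

    def _get_padded_token_len(x: int) -> int:
        if x <= 16:
            return 16
        return 1 << (x - 1).bit_length()

    assert _get_padded_token_len(1) == 16
    assert _get_padded_token_len(16) == 16
    assert _get_padded_token_len(17) == 32
    assert _get_padded_token_len(100) == 128
    assert _get_padded_token_len(512) == 512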
||||