[ROCm][Perf] New design on ROCm AITER MHA backend Implementation (#25763)
Signed-off-by: ganyi <ygan@amd.com>
commit dc937175d4
parent 2f1cc8cef1
@@ -13,223 +13,204 @@ from vllm.attention.backends.abstract import (
     AttentionType,
     MultipleOf,
 )
+from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
+    split_decodes_prefills_and_extends,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec

 _PARTITION_SIZE_ROCM = 256
+_CP_TOKENS_PER_ITER_ROCM = 32 * 1024

 if current_platform.is_rocm():
     import aiter
+    from aiter.ops.triton.utils.device_info import get_num_sms

     from vllm.triton_utils import tl, triton
-    from vllm.utils.torch_utils import direct_register_custom_op
+
+    def block_size(x, head_dim):
+        return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
+
+    def num_programs(head_dim):
+        return min(head_dim, get_num_sms())

     @triton.jit
-    def _vllm_layout_trans_kernel(
-        k_buffer_ptr,
-        v_buffer_ptr,
-        k_values_ptr,
-        v_values_ptr,
-        b_query_lens_loc,
-        b_seq_lens_loc,
-        block_table,
-        block_table_stride_0,
-        k_scale,
-        v_scale,
-        output_dtype: tl.constexpr,
-        E_DIM: tl.constexpr,
-        BLOCK_SIZE: tl.constexpr,
-    ):
-        batch_idx = tl.program_id(0)
-        block_idx = tl.program_id(1)
-        batch_query_indexes = tl.load(b_query_lens_loc + batch_idx + tl.arange(0, 2))
-        batch_query_start, batch_query_end = tl.split(batch_query_indexes)
-        query_len = batch_query_end - batch_query_start
-        if query_len <= 1:
-            return
-        batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx + tl.arange(0, 2))
-        batch_token_start, batch_token_end = tl.split(batch_token_indexes)
-        seq_len = batch_token_end - batch_token_start
-
-        if block_idx * BLOCK_SIZE < seq_len:
-            block_mask = (
-                block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None]
-            ) < seq_len
-
-            kv_idx = tl.load(
-                block_table + batch_idx * block_table_stride_0 + block_idx
-            ).to(tl.int64)
-
-            kv_buffer_off = (
-                kv_idx * BLOCK_SIZE * E_DIM
-                + tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM
-                + tl.arange(0, E_DIM)[None, :]
-            )
-            k_vals = tl.load(k_buffer_ptr + kv_buffer_off, mask=block_mask, other=0.0)
-            if k_vals.dtype.is_fp8():
-                k_vals = (k_vals.to(tl.float32) * tl.load(k_scale)).to(output_dtype)
-            else:
-                k_vals = k_vals.to(output_dtype)
-
-            v_vals = tl.load(v_buffer_ptr + kv_buffer_off, mask=block_mask, other=0.0)
-            if v_vals.dtype.is_fp8():
-                v_vals = (v_vals.to(tl.float32) * tl.load(v_scale)).to(output_dtype)
-            else:
-                v_vals = v_vals.to(output_dtype)
-            kv_values_off = (
-                batch_token_start * E_DIM
-                + block_idx * BLOCK_SIZE * E_DIM
-                + tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM
-                + tl.arange(0, E_DIM)[None, :]
-            )
-            tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask)
-            tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)
-
-    def vllm_layout_trans(
-        b_query_lens_loc,
-        b_seq_lens_loc,
-        block_table,
-        k_cache,
-        v_cache,
-        max_seq_len,
-        k_scale,
-        v_scale,
-        output_dtype,
-        total_tokens,
-    ):
-        H_KV = v_cache.shape[2]
-        D = v_cache.shape[3]
-        BLOCK_SIZE = v_cache.shape[1]
-
-        k_values = torch.empty(
-            (total_tokens, H_KV, D),
-            dtype=output_dtype,
-            device=k_cache.device,
-        )
-        v_values = torch.empty(
-            (total_tokens, H_KV, D),
-            dtype=output_dtype,
-            device=v_cache.device,
-        )
-
-        grid = (block_table.shape[0], (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
-
-        if output_dtype == torch.float16:
-            output_dtype = tl.float16
-        elif output_dtype == torch.bfloat16:
-            output_dtype = tl.bfloat16
-        else:
-            raise ValueError(f"Unsupported output dtype: {output_dtype}")
-
-        _vllm_layout_trans_kernel[grid](
-            k_cache,
-            v_cache,
-            k_values,
-            v_values,
-            b_query_lens_loc,
-            b_seq_lens_loc,
-            block_table,
-            block_table.stride(0),
-            k_scale,
-            v_scale,
-            output_dtype=output_dtype,
-            E_DIM=H_KV * D,
-            BLOCK_SIZE=BLOCK_SIZE,
-        )
-
-        return k_values, v_values
-
-    def flash_attn_varlen_func_impl(
-        q: torch.Tensor,
-        k_cache: torch.Tensor,
-        v_cache: torch.Tensor,
-        out: torch.Tensor,
-        cu_seqlens_q: torch.Tensor,
-        cu_seqlens_k: torch.Tensor,
-        max_seqlen_q: int,
-        max_seqlen_k: int,
-        softmax_scale: float,
-        window_size: list[int] | None,  # -1 means infinite context window
-        alibi_slopes: list[float] | None,
-        block_table: torch.Tensor,
-        k_scale: torch.Tensor,
-        v_scale: torch.Tensor,
-        total_tokens: int = 0,
-    ) -> torch.Tensor:
-        if total_tokens == 0:
-            total_tokens = int(cu_seqlens_k[-1].item())
-        k, v = vllm_layout_trans(
-            cu_seqlens_q,
-            cu_seqlens_k,
-            block_table,
-            k_cache,
-            v_cache,
-            max_seqlen_k,
-            k_scale,
-            v_scale,
-            q.dtype,
-            total_tokens,
-        )
-
-        output = aiter.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v,
-            cu_seqlens_q=cu_seqlens_q,
-            max_seqlen_q=max_seqlen_q,
-            min_seqlen_q=1,
-            cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_k=max_seqlen_k,
-            softmax_scale=softmax_scale,
-            causal=True,
-            alibi_slopes=alibi_slopes,
-            window_size=window_size,
-            out=out,
-        )
-        return output
-
-    def flash_attn_varlen_func_fake(
-        q: torch.Tensor,
-        k_cache: torch.Tensor,
-        v_cache: torch.Tensor,
-        out: torch.Tensor,
-        cu_seqlens_q: torch.Tensor,
-        cu_seqlens_k: torch.Tensor,
-        max_seqlen_q: int,
-        max_seqlen_k: int,
-        softmax_scale: float,
-        window_size: list[int] | None,  # -1 means infinite context window
-        alibi_slopes: list[float] | None,
-        block_table: torch.Tensor,
-        k_scale: torch.Tensor,
-        v_scale: torch.Tensor,
-        total_tokens: int = 0,
-    ) -> torch.Tensor:
-        return torch.empty(
-            q.shape[0], q.shape[1], v_cache.shape[-2], dtype=q.dtype, device=q.device
-        )
-
-    direct_register_custom_op(
-        "flash_attn_varlen_func",
-        flash_attn_varlen_func_impl,
-        ["out"],
-        flash_attn_varlen_func_fake,
-        dispatch_key=current_platform.dispatch_key,
-    )
+    def cp_mha_gather_cache_kernel(
+        key_cache_ptr,  # [num_blocks, page_size, num_head, head_size]
+        value_cache_ptr,  # [num_blocks, page_size, num_head, head_size]
+        key_ptr,  # [num_tokens, num_heads, head_size]
+        value_ptr,  # [num_tokens, num_heads, head_size]
+        block_table_ptr,  # [num_batches, max_block_num]
+        cu_seqlens_kv_ptr,  # [num_batches + 1]
+        token_to_batch_ptr,  # [max_cum_tokens]
+        seq_start_ptr,  # [num_batches]
+        k_scale_ptr,
+        v_scale_ptr,
+        num_heads,
+        head_size,
+        x,
+        max_block_num,
+        num_tokens,
+        DEQUANT: tl.constexpr,
+        PAGE_SIZE: tl.constexpr,
+        CACHE_FORMAT: tl.constexpr,
+        BLOCK_SIZE: tl.constexpr,
+        NUM_PRGMS: tl.constexpr,
+    ):
+        bid = tl.program_id(0)
+        col_offsets = tl.arange(0, BLOCK_SIZE)
+        if DEQUANT:
+            k_scale = tl.load(k_scale_ptr)
+            v_scale = tl.load(v_scale_ptr)
+
+        for token_id in tl.range(bid, num_tokens, NUM_PRGMS):
+            key_ptr_offset = key_ptr + token_id * head_size * num_heads
+            value_ptr_offset = value_ptr + token_id * head_size * num_heads
+            batch_idx = tl.load(token_to_batch_ptr + token_id)
+            batch_start = tl.load(seq_start_ptr + batch_idx)
+            token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
+            batch_offset = token_id - token_start + batch_start
+            block_offset = batch_offset // PAGE_SIZE
+            block_id = tl.load(
+                block_table_ptr + max_block_num * batch_idx + block_offset
+            )
+            slot_id = batch_offset % PAGE_SIZE
+            if CACHE_FORMAT == "NHD":
+                # for kv cache layout as
+                # K: [num_blocks, page_size, num_head, head_dim]
+                # V: [num_blocks, page_size, num_head, head_dim]
+                key_cache_ptr_offset = (
+                    key_cache_ptr
+                    + block_id * num_heads * head_size * PAGE_SIZE
+                    + slot_id * num_heads * head_size
+                )
+                value_cache_ptr_offset = (
+                    value_cache_ptr
+                    + block_id * num_heads * head_size * PAGE_SIZE
+                    + slot_id * num_heads * head_size
+                )
+
+                for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
+                    mask = (col_offsets + i) < head_size * num_heads
+                    k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
+                    v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
+                    if DEQUANT:
+                        k_dtype = k_reg.dtype
+                        v_dtype = v_reg.dtype
+                        k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
+                        v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
+                    tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
+                    tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
+
+    def cp_mha_gather_cache(
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        block_tables: torch.Tensor,
+        k_scales: torch.Tensor,
+        v_scales: torch.Tensor,
+        cu_seqlens_kv: torch.Tensor,
+        token_to_batch: torch.Tensor,
+        seq_starts: torch.Tensor,
+        dequant: bool,
+        kv_cache_layout: str,
+        total_tokens: int,
+    ):
+        assert kv_cache_layout in ["v0", "NHD", "HND"], (
+            "kv_cache_layout only support v0, NHD, HND"
+        )
+        head_dim = key.shape[2]
+        x = 0
+        # assert dequant is True, "Currently, we only support "\
+        # "gather cache with dequant"
+        # For k cache layout: [num_blocks, num_heads, page_size, head_dim]
+        assert kv_cache_layout == "NHD", (
+            "ROCM_AITER_FA_BACKEND Only support NHD kv cache layout for now"
+        )
+        assert head_dim == key_cache.shape[3], (
+            "We assume your kv cache layout is [num_blocks, "
+            "page_size, num_heads, head_dim], but got otherwise"
+        )
+        page_size = key_cache.shape[1]
+        num_heads = key_cache.shape[2]
+
+        NUM_PRGMS = num_programs(total_tokens)
+        BLOCK_SIZE = block_size(key_cache, head_dim)
+        grid = lambda meta: (NUM_PRGMS,)
+        cp_mha_gather_cache_kernel[grid](
+            key_cache,
+            value_cache,
+            key,
+            value,
+            block_tables,
+            cu_seqlens_kv,
+            token_to_batch,
+            seq_starts,
+            k_scales,
+            v_scales,
+            num_heads,
+            head_dim,
+            x,
+            block_tables.size(1),
+            total_tokens,
+            DEQUANT=dequant,
+            PAGE_SIZE=page_size,
+            CACHE_FORMAT=kv_cache_layout,
+            BLOCK_SIZE=BLOCK_SIZE,
+            NUM_PRGMS=NUM_PRGMS,
+        )


 logger = init_logger(__name__)
+
+
+@dataclass
+class AiterFlashAttentionDecodeMetadata:
+    max_query_len: int
+    min_query_len: int
+    max_seq_len: int
+    query_start_loc: torch.Tensor
+
+
+@dataclass
+class AiterFlashAttentionPrefillMetadata:
+    max_query_len: int
+    min_query_len: int
+    max_seq_len: int
+    query_start_loc: torch.Tensor
+
+
+@dataclass
+class AiterChunkContextMetadata:
+    workspace: torch.Tensor
+    cu_seq_lens_chunk: torch.Tensor
+    chunk_starts: torch.Tensor
+    token_to_batch: torch.Tensor
+    seq_tot: list[int]
+    max_seq_lens: list[int]
+    seq_lens: torch.Tensor
+    num_chunks: int
+    total_token_per_batch: list[int]
+
+
+@dataclass
+class AiterFlashAttentionChunkPrefillMetadata:
+    max_query_len: int
+    min_query_len: int
+    max_seq_len: int
+    query_start_loc: torch.Tensor
+    chunk_context_metadata: AiterChunkContextMetadata


 @dataclass
 class AiterFlashAttentionMetadata:
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
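The new cp_mha_gather_cache path above replaces the old whole-batch layout transpose: instead of materializing K/V for every request, it gathers only the KV tokens of the current context chunk into a fixed-size workspace. A plain-PyTorch sketch of the per-token indexing that the Triton kernel performs is shown below; it is illustrative only (ref_gather_keys and its arguments are names chosen for the example), and it assumes the NHD cache layout [num_blocks, page_size, num_heads, head_size].

import torch


def ref_gather_keys(key_cache, block_table, cu_seqlens_kv, token_to_batch, seq_starts):
    # key_cache: [num_blocks, page_size, num_heads, head_size] (NHD layout)
    num_tokens = int(cu_seqlens_kv[-1])
    page_size = key_cache.shape[1]
    out = torch.empty(
        (num_tokens, key_cache.shape[2], key_cache.shape[3]), dtype=key_cache.dtype
    )
    for token_id in range(num_tokens):
        batch_idx = int(token_to_batch[token_id])
        # offset of this token inside its request's cached KV history,
        # shifted by the chunk start handled in this iteration
        batch_offset = (
            token_id - int(cu_seqlens_kv[batch_idx]) + int(seq_starts[batch_idx])
        )
        block_id = int(block_table[batch_idx, batch_offset // page_size])
        slot_id = batch_offset % page_size
        out[token_id] = key_cache[block_id, slot_id]
    return out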
@@ -248,7 +229,18 @@ class AiterFlashAttentionMetadata:
     seq_lens: torch.Tensor
     slot_mapping: torch.Tensor
     block_table: torch.Tensor
-    cu_seq_lens: torch.Tensor | None
+
+    # prefill and deocde split
+    num_decodes: int
+    num_decode_tokens: int
+    num_prefills: int
+    num_prefill_tokens: int
+    num_extends: int
+    num_extend_tokens: int
+
+    decode_metadata: AiterFlashAttentionDecodeMetadata | None
+    prefill_metadata: AiterFlashAttentionPrefillMetadata | None
+    extend_metadata: AiterFlashAttentionChunkPrefillMetadata | None

     # For cascade attention.
     use_cascade: bool
@@ -260,6 +252,7 @@ class AiterFlashAttentionMetadataBuilder(
     AttentionMetadataBuilder[AiterFlashAttentionMetadata]
 ):
     cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+    reorder_batch_threshold: int = 1

     def __init__(
         self,
@@ -285,6 +278,12 @@ class AiterFlashAttentionMetadataBuilder(
         self.aot_sliding_window: tuple[int, int] | None = None
         self.total_tokens: int = 0

+        self.extend_workspace = torch.empty(
+            [2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim],
+            dtype=self.model_config.dtype,
+            device=device,
+        )
+
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ):
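For scale, the persistent workspace allocated here holds 2 x _CP_TOKENS_PER_ITER_ROCM gathered KV tokens (one buffer for keys, one for values). A back-of-the-envelope size, assuming for illustration 8 KV heads, head_dim 128 and a 2-byte dtype (these example numbers are not from the commit):

num_heads_kv, headdim, elem_bytes = 8, 128, 2  # assumed example values
_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
workspace_bytes = 2 * _CP_TOKENS_PER_ITER_ROCM * num_heads_kv * headdim * elem_bytes
print(workspace_bytes / (1 << 20), "MiB")  # -> 128.0 MiB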
@@ -302,42 +301,139 @@ class AiterFlashAttentionMetadataBuilder(
         common_attn_metadata: CommonAttentionMetadata,
         fast_build: bool = False,
     ) -> "AiterFlashAttentionMetadata":
-        num_actual_tokens = common_attn_metadata.num_actual_tokens
-        max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = common_attn_metadata.max_seq_len
-        query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
-        block_table_tensor = common_attn_metadata.block_table_tensor
-        slot_mapping = common_attn_metadata.slot_mapping
-        if max_query_len > 1:
-            # We pre-compute cumulative seq len needed for prefill attention
-            # here to avoid recomputing it for every layer
-            cu_seq_lens = torch.zeros(
-                seq_lens.shape[0] + 1, dtype=torch.int32, device=seq_lens.device
-            )
-            torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:])
-            num_actual_kv_tokens = int(cu_seq_lens[-1].item())
-        else:
-            cu_seq_lens = None
-            num_actual_kv_tokens = 0
-
-        def schedule(
-            batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
-        ):
-            return None
+        split_ret = split_decodes_prefills_and_extends(
+            common_attn_metadata,
+            decode_threshold=self.reorder_batch_threshold,
+        )
+
+        (
+            num_decodes,
+            num_extends,
+            num_prefills,
+            num_decode_tokens,
+            num_extend_tokens,
+            num_prefill_tokens,
+        ) = split_ret
+
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
+        seq_lens = common_attn_metadata.seq_lens_cpu
+        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+
+        decode_metadata = None
+        if num_decodes > 0:
+            decode_metadata = AiterFlashAttentionDecodeMetadata(
+                max_query_len=query_lens_cpu[:num_decodes].max().item(),
+                min_query_len=query_lens_cpu[:num_decodes].min().item(),
+                max_seq_len=seq_lens[:num_decodes].max().item(),
+                query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1],
+            )
+
+        prefill_metadata = None
+        if num_prefills > 0:
+            query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :]
+            query_start_loc_device = common_attn_metadata.query_start_loc[
+                num_decodes + num_extends :
+            ]
+            prefill_metadata = AiterFlashAttentionPrefillMetadata(
+                max_query_len=query_lens_for_prefill.max().item(),
+                min_query_len=query_lens_for_prefill.min().item(),
+                max_seq_len=seq_lens[num_decodes + num_extends :].max().item(),
+                query_start_loc=query_start_loc_device - query_start_loc_device[0],
+            )
+
+        extend_metadata = None
+        if num_extends > 0:
+            num_extends_slice = slice(num_decodes, num_decodes + num_extends)
+            query_lens_for_extend = query_lens_cpu[num_extends_slice]
+            seq_lens_for_extend = common_attn_metadata.seq_lens_cpu[num_extends_slice]
+            computed_kv_lens = seq_lens_for_extend - query_lens_for_extend
+
+            # allocate the equal amount of workspace for
+            # each chunk prefill request
+            max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends
+            num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk)
+
+            chunk_starts = (
+                torch.arange(num_chunks, dtype=torch.int32)
+                .unsqueeze(1)
+                .expand(-1, num_extends)
+                * max_context_chunk
+            )
+            chunk_ends = torch.min(
+                computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk
+            )
+            chunk_seq_lens = (chunk_ends - chunk_starts).clamp(
+                min=0
+            )  # [num_chunks, num_extends]
+            cu_seq_lens_cpu = torch.zeros(
+                [num_chunks, num_extends + 1], dtype=torch.int32, pin_memory=True
+            )
+            torch.cumsum(
+                chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
+            )
+            max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item()
+
+            range_idx = torch.arange(max_cum_tokens, dtype=torch.int32)[None, None, :]
+            idx_to_batch_tensor = range_idx == cu_seq_lens_cpu[:, 1:][:, :, None]
+            idx_to_batch_tensor = idx_to_batch_tensor.sum(
+                dim=1
+            )  # [num_chunks, max_cum_tokens]
+            token_to_batch_tensor = torch.cumsum(idx_to_batch_tensor, dim=1)
+
+            chunk_context_metadata = AiterChunkContextMetadata(
+                workspace=self.extend_workspace,
+                cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, non_blocking=True),
+                chunk_starts=chunk_starts.to(self.device, non_blocking=True),
+                seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
+                max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                seq_lens=chunk_seq_lens,
+                token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True),
+                num_chunks=num_chunks,
+                total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(),
+            )
+
+            query_start_loc_device = common_attn_metadata.query_start_loc[
+                num_decodes : num_decodes + num_extends + 1
+            ]
+            seq_lens_device = common_attn_metadata.seq_lens[num_extends_slice]
+            cu_seq_lens = torch.zeros(
+                num_extends + 1, dtype=torch.int32, device=seq_lens_device.device
+            )
+            torch.cumsum(
+                seq_lens_device, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]
+            )
+            extend_metadata = AiterFlashAttentionChunkPrefillMetadata(
+                max_query_len=query_lens_for_extend.max().item(),
+                min_query_len=query_lens_for_extend.min().item(),
+                max_seq_len=seq_lens[num_extends_slice].max().item(),
+                query_start_loc=query_start_loc_device - query_start_loc_device[0],
+                chunk_context_metadata=chunk_context_metadata,
+            )
+
+        num_actual_kv_tokens = torch.sum(seq_lens).item()

         use_cascade = common_prefix_len > 0

         attn_metadata = AiterFlashAttentionMetadata(
-            num_actual_tokens=num_actual_tokens,
+            num_actual_tokens=common_attn_metadata.num_actual_tokens,
             num_actual_kv_tokens=num_actual_kv_tokens,
-            max_query_len=max_query_len,
-            query_start_loc=query_start_loc,
-            max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
-            block_table=block_table_tensor,
-            slot_mapping=slot_mapping,
-            cu_seq_lens=cu_seq_lens,
+            max_query_len=common_attn_metadata.max_query_len,
+            query_start_loc=common_attn_metadata.query_start_loc,
+            max_seq_len=common_attn_metadata.max_seq_len,
+            seq_lens=common_attn_metadata.seq_lens,
+            block_table=common_attn_metadata.block_table_tensor,
+            slot_mapping=common_attn_metadata.slot_mapping,
+            num_decodes=num_decodes,
+            num_decode_tokens=num_decode_tokens,
+            num_prefills=num_prefills,
+            num_prefill_tokens=num_prefill_tokens,
+            num_extends=num_extends,
+            num_extend_tokens=num_extend_tokens,
+            decode_metadata=decode_metadata,
+            prefill_metadata=prefill_metadata,
+            extend_metadata=extend_metadata,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             total_tokens=self.total_tokens,
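To make the chunking arithmetic in build() above concrete, here is a worked example with made-up numbers: two extend requests whose already-computed KV lengths are 5000 and 70000 tokens, under the 32K-token per-iteration budget.

import torch

_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
computed_kv_lens = torch.tensor([5000, 70000], dtype=torch.int32)
num_extends = computed_kv_lens.numel()

# every extend request gets an equal share of the per-iteration token budget
max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends          # 16384
max_kv = int(computed_kv_lens.max())
num_chunks = (max_kv + max_context_chunk - 1) // max_context_chunk   # cdiv -> 5

chunk_starts = (
    torch.arange(num_chunks, dtype=torch.int32).unsqueeze(1).expand(-1, num_extends)
    * max_context_chunk
)
chunk_ends = torch.min(computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
# request 0 (5000 cached tokens) is fully covered by chunk 0:
#   chunk_seq_lens[:, 0] == [5000, 0, 0, 0, 0]
# request 1 (70000 cached tokens) is spread over all 5 chunks:
#   chunk_seq_lens[:, 1] == [16384, 16384, 16384, 16384, 4464]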
@@ -401,6 +497,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
     ) -> tuple[int, ...]:
         if block_size % 16 != 0:
             raise ValueError("Block size must be a multiple of 16.")
+
         return (2, num_blocks, block_size, num_kv_heads, head_size)


@@ -449,6 +546,110 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 "FlashAttentionImpl"
             )

+    def extend_forward(
+        self,
+        attn_metadata: AiterFlashAttentionMetadata,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        output: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        min_seqlen_q: int,
+        block_table: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        k_scale: float,
+        v_scale: float,
+    ):
+        out, lse = aiter.flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_q,
+            min_seqlen_q=min_seqlen_q,
+            dropout_p=0.0,
+            softmax_scale=self.scale,
+            causal=True,
+            window_size=self.sliding_window,
+            alibi_slopes=self.alibi_slopes,
+            return_lse=True,
+        )
+        assert attn_metadata.extend_metadata is not None
+        chunk_context_metadata = attn_metadata.extend_metadata.chunk_context_metadata
+        num_chunks = chunk_context_metadata.num_chunks
+        workspace = chunk_context_metadata.workspace
+        cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk
+        max_seqlens = chunk_context_metadata.max_seq_lens
+        chunk_starts = chunk_context_metadata.chunk_starts
+        token_to_batch = chunk_context_metadata.token_to_batch
+        total_token_per_batch = chunk_context_metadata.total_token_per_batch
+        key_fetched, value_fetched = workspace[0], workspace[1]
+        chunked_output = None
+        chunked_lse = None
+        for chunk_idx in range(num_chunks):
+            cp_mha_gather_cache(
+                key_cache=key_cache,
+                value_cache=value_cache,
+                key=key_fetched,
+                value=value_fetched,
+                block_tables=block_table,
+                k_scales=k_scale,
+                v_scales=v_scale,
+                cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
+                token_to_batch=token_to_batch[chunk_idx],
+                seq_starts=chunk_starts[chunk_idx],
+                dequant=False,
+                kv_cache_layout="NHD",
+                total_tokens=total_token_per_batch[chunk_idx],
+            )

+            suf_out, suf_lse = aiter.flash_attn_varlen_func(
+                q=query,
+                k=key_fetched,
+                v=value_fetched,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_kv[chunk_idx],
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlens[chunk_idx],
+                min_seqlen_q=min_seqlen_q,
+                dropout_p=0.0,
+                softmax_scale=self.scale,
+                causal=False,
+                window_size=self.sliding_window,
+                alibi_slopes=self.alibi_slopes,
+                return_lse=True,
+            )
+            if chunked_output is None:
+                chunked_output = suf_out
+                chunked_lse = suf_lse
+            else:
+                tmp_output = torch.empty_like(out)
+                tmp_lse = torch.empty_like(lse)
+                merge_attn_states(
+                    output=tmp_output,
+                    output_lse=tmp_lse,
+                    prefix_output=chunked_output,
+                    prefix_lse=chunked_lse,
+                    suffix_output=suf_out,
+                    suffix_lse=suf_lse,
+                )
+                chunked_output = tmp_output
+                chunked_lse = tmp_lse

+        merge_attn_states(
+            output=output,
+            prefix_output=chunked_output,
+            prefix_lse=chunked_lse,
+            suffix_output=out,
+            suffix_lse=lse,
+        )

     def forward(
         self,
         layer: torch.nn.Module,
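extend_forward above combines the per-chunk partial attention results with merge_attn_states. The underlying log-sum-exp merge can be written in a few lines of plain PyTorch; this is a minimal sketch rather than the vLLM implementation, and merge_partial is an illustrative name.

import torch

def merge_partial(o1, lse1, o2, lse2):
    # o1/o2: partial attention outputs over two disjoint key sets,
    # lse1/lse2: log-sum-exp of the corresponding attention scores.
    # The merged output is a convex combination weighted by exp(lse - merged_lse).
    merged_lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - merged_lse)
    w2 = torch.exp(lse2 - merged_lse)
    return o1 * w1 + o2 * w2, merged_lse

Applying such a merge pairwise across chunks, as the loop above does, yields the same result as attending over the full context in one pass.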
@@ -488,24 +689,25 @@ class AiterFlashAttentionImpl(AttentionImpl):
             return output.fill_(0)

         # IMPORTANT!
-        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
-        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
-        # in this method. For example, `view` and `slice` (or `[:n]`) operations
-        # are surprisingly slow even in the case they do not invoke any GPU ops.
+        # NOTE(woosuk): With piece-wise CUDA graphs, this method is
+        # executed in eager-mode PyTorch. Thus, we need to be careful
+        # about any CPU overhead in this method. For example, `view`
+        # and `slice` (or `[:n]`) operations are surprisingly slow even
+        # in the case they do not invoke any GPU ops.
         # Minimize the PyTorch ops in this method as much as possible.
         # Whenever making a change in this method, please benchmark the
         # performance to make sure it does not introduce any overhead.

         num_actual_tokens = attn_metadata.num_actual_tokens
         key_cache, value_cache = kv_cache.unbind(0)
         if self.kv_sharing_target_layer_name is None:
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
-            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-            # not padded. However, we don't need to do key[:num_actual_tokens]
-            # and value[:num_actual_tokens] because the reshape_and_cache_flash
-            # op uses the slot_mapping's shape to determine the number of
-            # actual tokens.
+            # NOTE(woosuk): Here, key and value are padded while slot_mapping
+            # is not padded. However, we don't need to do
+            # key[:num_actual_tokens] and value[:num_actual_tokens] because
+            # the reshape_and_cache_flash op uses the slot_mapping's shape
+            # to determine the number of actual tokens.
             torch.ops._C_cache_ops.reshape_and_cache_flash(
                 key,
                 value,
@@ -521,69 +723,118 @@ class AiterFlashAttentionImpl(AttentionImpl):
             key_cache = key_cache.view(current_platform.fp8_dtype())
             value_cache = value_cache.view(current_platform.fp8_dtype())

-        if not attn_metadata.use_cascade:
-            cu_seqlens_q = attn_metadata.query_start_loc
-            seqused_k = attn_metadata.seq_lens
-            max_seqlen_q = attn_metadata.max_query_len
-            max_seqlen_k = attn_metadata.max_seq_len
-            block_table = attn_metadata.block_table
-
-            if max_seqlen_q > 1:
-                torch.ops.vllm.flash_attn_varlen_func(
-                    query[:num_actual_tokens],
-                    key_cache,
-                    value_cache,
-                    out=output[:num_actual_tokens],
-                    cu_seqlens_q=cu_seqlens_q,
-                    max_seqlen_q=max_seqlen_q,
-                    max_seqlen_k=max_seqlen_k,
-                    softmax_scale=self.scale,
-                    alibi_slopes=self.alibi_slopes,
-                    window_size=self.sliding_window,
-                    block_table=block_table,
-                    cu_seqlens_k=attn_metadata.cu_seq_lens,
-                    k_scale=layer._k_scale,
-                    v_scale=layer._v_scale,
-                    total_tokens=attn_metadata.num_actual_kv_tokens,
-                )
-
-            _, num_heads, head_size = query.shape
-            nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
-            num_seqs = seqused_k.shape[0]
-            max_num_partitions = (
-                max_seqlen_k + _PARTITION_SIZE_ROCM - 1
-            ) // _PARTITION_SIZE_ROCM
-
-            workspace_buffer = torch.empty(
-                (num_seqs * num_heads * max_num_partitions * head_size)
-                * nbytes_per_qo_elem
-                + 2 * (num_seqs * num_heads * max_num_partitions) * 4,
-                dtype=torch.uint8,
-                device=output.device,
-            )
-
-            torch.ops.aiter.paged_attention_v1(
-                output[:num_actual_tokens],
-                workspace_buffer,
-                query[:num_actual_tokens],
-                key_cache,
-                value_cache,
-                self.scale,
-                block_table,
-                cu_seqlens_q,
-                seqused_k,
-                max_seqlen_k,
-                self.alibi_slopes,
-                self.kv_cache_dtype,
-                "NHD",
-                self.logits_soft_cap,
-                layer._k_scale,
-                layer._v_scale,
-                None,
-                _PARTITION_SIZE_ROCM,
-            )
-            return output
+        # decode:extend:prefill
+        query = query[:num_actual_tokens]
+        key = key[:num_actual_tokens]
+        value = value[:num_actual_tokens]
+
+        output_actual_tokens = output[:num_actual_tokens]
+        num_decodes = attn_metadata.num_decodes
+        num_prefills = attn_metadata.num_prefills
+        num_extends = attn_metadata.num_extends
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        num_extend_tokens = attn_metadata.num_extend_tokens
+        if not attn_metadata.use_cascade:
+            # calculate for pure prefills
+            if num_prefills > 0:
+                assert attn_metadata.prefill_metadata is not None
+
+                prefill_query = query[num_decode_tokens + num_extend_tokens :]
+                prefill_key = key[num_decode_tokens + num_extend_tokens :]
+                prefill_value = value[num_decode_tokens + num_extend_tokens :]
+
+                aiter.flash_attn_varlen_func(
+                    q=prefill_query,
+                    k=prefill_key,
+                    v=prefill_value,
+                    cu_seqlens_q=attn_metadata.prefill_metadata.query_start_loc,
+                    cu_seqlens_k=attn_metadata.prefill_metadata.query_start_loc,
+                    max_seqlen_q=attn_metadata.prefill_metadata.max_query_len,
+                    max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len,
+                    min_seqlen_q=attn_metadata.prefill_metadata.min_query_len,
+                    dropout_p=0.0,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    window_size=self.sliding_window,
+                    alibi_slopes=self.alibi_slopes,
+                    out=output_actual_tokens[num_decode_tokens + num_extend_tokens :],
+                )
+
+            # calculate for extends
+            if num_extends > 0:
+                assert attn_metadata.extend_metadata is not None
+                extend_tokens_slice = slice(
+                    num_decode_tokens, num_decode_tokens + num_extend_tokens
+                )
+                extend_querys = query[extend_tokens_slice]
+                extend_keys = key[extend_tokens_slice]
+                extend_values = value[extend_tokens_slice]
+                extend_outputs = output[extend_tokens_slice]
+                self.extend_forward(
+                    attn_metadata=attn_metadata,
+                    query=extend_querys,
+                    key=extend_keys,
+                    value=extend_values,
+                    key_cache=key_cache,
+                    value_cache=value_cache,
+                    output=extend_outputs,
+                    cu_seqlens_q=attn_metadata.extend_metadata.query_start_loc,
+                    max_seqlen_q=attn_metadata.extend_metadata.max_query_len,
+                    max_seqlen_k=attn_metadata.extend_metadata.max_seq_len,
+                    min_seqlen_q=attn_metadata.extend_metadata.min_query_len,
+                    block_table=attn_metadata.block_table[
+                        num_decodes : num_decodes + num_extends
+                    ],
+                    slot_mapping=attn_metadata.slot_mapping[
+                        num_decodes : num_decodes + num_extends
+                    ],
+                    k_scale=layer._k_scale,
+                    v_scale=layer._v_scale,
+                )
+
+            # calculate for decodes
+            if num_decodes > 0:
+                assert attn_metadata.decode_metadata is not None
+                _, num_heads, head_size = query.shape
+                nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
+                num_seqs = attn_metadata.seq_lens.shape[0]
+                max_num_partitions = (
+                    attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
+                ) // _PARTITION_SIZE_ROCM
+
+                workspace_buffer = torch.empty(
+                    (num_seqs * num_heads * max_num_partitions * head_size)
+                    * nbytes_per_qo_elem
+                    + 2 * (num_seqs * num_heads * max_num_partitions) * 4,
+                    dtype=torch.uint8,
+                    device=output.device,
+                )
+
+                torch.ops.aiter.paged_attention_v1(
+                    output[:num_decode_tokens],
+                    workspace_buffer,
+                    query[:num_decode_tokens],
+                    key_cache,
+                    value_cache,
+                    self.scale,
+                    attn_metadata.block_table[:num_decodes],
+                    attn_metadata.query_start_loc[:num_decodes],
+                    attn_metadata.seq_lens[:num_decodes],
+                    attn_metadata.max_seq_len,
+                    self.alibi_slopes,
+                    self.kv_cache_dtype,
+                    "NHD",
+                    self.logits_soft_cap,
+                    layer._k_scale,
+                    layer._v_scale,
+                    None,
+                    _PARTITION_SIZE_ROCM,
+                )
         else:
             raise NotImplementedError(
                 "Cascade attention is not implemented for ROCM AITER"
             )
+
+        return output
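After the split, forward() above indexes the reordered token stream as decode tokens first, then extend tokens, then pure-prefill tokens. A small illustration of that slicing, with made-up counts:

# made-up token counts for one reordered batch
num_decode_tokens, num_extend_tokens, num_prefill_tokens = 4, 10, 18
num_actual_tokens = num_decode_tokens + num_extend_tokens + num_prefill_tokens

decode_slice = slice(0, num_decode_tokens)                        # tokens 0..3
extend_slice = slice(num_decode_tokens,
                     num_decode_tokens + num_extend_tokens)       # tokens 4..13
prefill_slice = slice(num_decode_tokens + num_extend_tokens,
                      num_actual_tokens)                          # tokens 14..31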
@@ -728,6 +728,73 @@ def subclass_attention_backend(
     )


+def split_decodes_prefills_and_extends(
+    common_attn_metadata: CommonAttentionMetadata,
+    decode_threshold: int = 1,
+) -> tuple[int, int, int, int, int, int]:
+    """
+    Assuming a reordered batch, finds the boundary between prefill and decode
+    requests.
+
+    Args:
+        common_attn_metadata: CommonAttentionMetadata object containing the
+            batch metadata.
+        decode_threshold: The maximum query length to be considered a decode.
+
+    Returns:
+        num_decodes: The number of decode requests.
+        num_extends: The number of extend requests.
+        num_prefills: The number of prefill requests.
+        num_decode_tokens: The number of tokens in the decode requests.
+        num_extend_tokens: The number of tokens in the extend requests.
+        num_prefill_tokens: The number of tokens in the prefill requests.
+    """
+    max_query_len = common_attn_metadata.max_query_len
+    num_reqs = common_attn_metadata.num_reqs
+    num_tokens = common_attn_metadata.num_actual_tokens
+    query_start_loc = common_attn_metadata.query_start_loc_cpu
+    seq_lens = common_attn_metadata.seq_lens_cpu
+
+    if max_query_len <= decode_threshold:
+        return num_reqs, 0, 0, num_tokens, 0, 0
+
+    query_lens = query_start_loc[1:] - query_start_loc[:-1]
+    is_prefill_or_extend = query_lens > decode_threshold
+    is_prefill = (seq_lens == query_lens) & is_prefill_or_extend
+    first_extend = is_prefill_or_extend.int().argmax(dim=-1).item()
+    first_prefill = is_prefill.int().argmax(dim=-1).item()
+    num_decodes = first_extend
+    num_decode_tokens = query_start_loc[first_extend].item()
+    if not torch.any(is_prefill_or_extend):
+        return (num_decodes, 0, 0, num_decode_tokens, 0, 0)
+
+    num_prefills_or_extends = num_reqs - num_decodes
+    num_prefill_or_extend_tokens = num_tokens - num_decode_tokens
+    if not torch.any(is_prefill):
+        return (
+            num_decodes,
+            num_prefills_or_extends,
+            0,
+            num_decode_tokens,
+            num_prefill_or_extend_tokens,
+            0,
+        )
+
+    num_extends = first_prefill - num_decodes
+    num_prefills = num_reqs - first_prefill
+
+    num_prefill_tokens = num_tokens - query_start_loc[first_prefill]
+    num_extend_tokens = num_prefill_or_extend_tokens - num_prefill_tokens
+    return (
+        num_decodes,
+        num_extends,
+        num_prefills,
+        num_decode_tokens,
+        num_extend_tokens,
+        num_prefill_tokens,
+    )
+
+
 def split_decodes_and_prefills(
     common_attn_metadata: CommonAttentionMetadata,
     decode_threshold: int = 1,
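A worked example of the classification done by split_decodes_prefills_and_extends, using made-up per-request lengths for a batch that has already been reordered into decode | extend | prefill order (decode_threshold = 1):

import torch

query_lens = torch.tensor([1, 1, 8, 6, 32, 40])    # new tokens per request
seq_lens = torch.tensor([9, 17, 30, 25, 32, 40])   # total KV length per request

is_prefill_or_extend = query_lens > 1                           # [F, F, T, T, T, T]
is_prefill = (seq_lens == query_lens) & is_prefill_or_extend    # [F, F, F, F, T, T]

num_decodes = int(is_prefill_or_extend.int().argmax())          # 2
num_extends = int(is_prefill.int().argmax()) - num_decodes      # 2
num_prefills = query_lens.numel() - num_decodes - num_extends   # 2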