[Attention] FlashAttention MLA cudagraph support (#23958)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
commit 620db1fc58 (parent 41183c1fe0)
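The diff below adds a FlashAttention MLA ("FLASH_ATTN_MLA") entry to the full-cudagraph test configs, threads a new num_decode_tokens argument through the MLA _build_decode builders, and wires FA3 scheduler metadata plus a num_splits upper bound so decode-only batches can be captured in CUDA graphs. As a rough usage sketch (not part of this commit; it assumes the LLM constructor accepts a compilation_config dict, as in recent vLLM releases):

import os

# Select the FlashAttention MLA backend added to the test matrix below
# (Hopper-only, per specific_gpu_arch=(9, 0)).
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN_MLA"

from vllm import LLM, SamplingParams

llm = LLM(
    # MLA model used by the tests in this commit.
    model="deepseek-ai/DeepSeek-V2-Lite",
    # Matches the cudagraph_mode exercised by the new backend config:
    # full CUDA graphs are captured for decode-only batches.
    compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
)
out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)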
@@ -61,6 +61,16 @@ backend_configs = {
                       "cudagraph_mode": "FULL_AND_PIECEWISE",
                   },
                   specific_gpu_arch=(9, 0)),
+    # FlashAttention MLA on Hopper
+    "FlashAttentionMLA":
+    BackendConfig(name="FlashAttentionMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_DECODE_ONLY",
+                  },
+                  specific_gpu_arch=(9, 0)),
     # Cutlass MLA on Blackwell
     "CutlassMLA":
     BackendConfig(
@@ -102,7 +112,7 @@ backend_configs = {
 test_params_full_cudagraph = []

 # deepseek-ai/DeepSeek-V2-Lite with MLA
-MLA_backends = ["FlashMLA", "CutlassMLA"]
+MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
 for mla_backend in MLA_backends:
     test_params_full_cudagraph.append(
         pytest.param(
@@ -73,7 +73,6 @@ def create_and_prepopulate_kv_cache(
         kv_c_contexts: list[torch.Tensor],
         k_pe_contexts: list[torch.Tensor],
         block_size: int,
-        num_kv_heads: int,
         head_size: int,
         dtype: torch.dtype,
         device: torch.device,
@@ -87,7 +86,6 @@ def create_and_prepopulate_kv_cache(
         k_pe_contexts: List of key positional embedding context tensors
             for each sequence
         block_size: Size of each block
-        num_kv_heads: Number of KV heads (should be 1 for MLA)
         head_size: Size of each head (latent dimension)
         dtype: Data type for the cache
         device: Device to create the cache on
@@ -285,8 +283,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     query_lens = batch_spec.query_lens
     num_q_heads = vllm_config.model_config.get_num_attention_heads(
         vllm_config.parallel_config)
-    num_kv_heads = vllm_config.model_config.get_num_kv_heads(
-        vllm_config.parallel_config)
     head_size = vllm_config.model_config.get_head_size()
     dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
     block_size = vllm_config.cache_config.block_size
@@ -476,7 +472,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
         kv_c_contexts=kv_c_contexts,
         k_pe_contexts=k_pe_contexts,
         block_size=block_size,
-        num_kv_heads=num_kv_heads,
         head_size=head_size,
         dtype=dtype,
         device=device,
@@ -62,6 +62,16 @@ backend_configs = {
                       "cudagraph_mode": "FULL_AND_PIECEWISE",
                   },
                   specific_gpu_arch=(9, 0)),
+    # FlashAttention MLA on Hopper
+    "FlashAttentionMLA":
+    BackendConfig(name="FlashAttentionMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_DECODE_ONLY",
+                  },
+                  specific_gpu_arch=(9, 0)),
     # FA2
     "FA2":
     BackendConfig(name="FA2",
@@ -443,11 +443,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         self.metadata_cls = metadata_cls \
             if metadata_cls is not None else MLACommonMetadata
         self.kv_cache_spec = kv_cache_spec
-        self.device = device
         scheduler_config = vllm_config.scheduler_config
         self.model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
         parallel_config = vllm_config.parallel_config
+        cache_config = vllm_config.cache_config
+        self.compilation_config = vllm_config.compilation_config
+        self.device = device

         self.num_heads = self.model_config.get_num_attention_heads(
             parallel_config)
         self.mla_dims = get_mla_dims(self.model_config)
@@ -608,10 +610,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             prefill.prefill_main = self._fi_prefill_main
             prefill.prefill_chunks = self._fi_prefill_chunks

-    def _build_decode(
-            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
-            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
-            query_start_loc_device: torch.Tensor) -> MLACommonDecodeMetadata:
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens_cpu: torch.Tensor,
+                      seq_lens_device: torch.Tensor,
+                      query_start_loc_cpu: torch.Tensor,
+                      query_start_loc_device: torch.Tensor,
+                      num_decode_tokens: int) -> MLACommonDecodeMetadata:
         return MLACommonDecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
@@ -624,11 +628,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         Currently, only decode is supported for full cudagraphs with MLA.
         """
         m = common_attn_metadata
-        assert m.num_reqs == m.num_actual_tokens, \
+        assert m.num_reqs <= (m.num_actual_tokens *
+                              self.reorder_batch_threshold), \
             "MLA only supports decode-only full CUDAGraph capture. " \
             "Make sure all cudagraph capture sizes <= max_num_seq."

-        assert m.max_query_len == 1  # decode-only
+        assert m.max_query_len <= self.reorder_batch_threshold  # decode only

         return self.build(0, m)

@@ -819,6 +824,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 seq_lens_device=seq_lens[:num_decodes],
                 query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
                 query_start_loc_device=query_start_loc[:num_decodes + 1],
+                num_decode_tokens=num_decode_tokens,
             )

         attn_metadata = self.metadata_cls(
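The capture-time check above is the behavioral core of this change: the builder previously insisted on exactly one query token per request, and it now accepts uniform batches of up to reorder_batch_threshold query tokens per request. A tiny standalone illustration with made-up numbers (not vLLM code; 512 is the class default the FlashAttention MLA builder declares further down):

# Standalone illustration of the old vs. new capture-time assertions.
reorder_batch_threshold = 512   # per-builder class attribute

num_reqs = 8            # requests in the dummy capture batch
num_actual_tokens = 16  # e.g. 2 query tokens per request (spec-decode style)
max_query_len = 2

old_ok = (num_reqs == num_actual_tokens and max_query_len == 1)
new_ok = (num_reqs <= num_actual_tokens * reorder_batch_threshold
          and max_query_len <= reorder_batch_threshold)

print(old_ok, new_ok)  # False True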
@@ -17,11 +17,16 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata

 logger = init_logger(__name__)

+# NOTE(matt): This is an arbitrary number, copied from
+# woosuk's implementation in standard FlashAttention backend
+_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
+

 class FlashAttnMLABackend(MLACommonBackend):

@@ -48,6 +53,7 @@ class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
     max_query_len: int
     max_seq_len: int
     scheduler_metadata: Optional[torch.Tensor] = None
+    max_num_splits: int = 0


 @dataclass
@@ -57,14 +63,41 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):

 class FlashAttnMLAMetadataBuilder(
         MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
+    cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.UNIFORM_BATCH
+
     reorder_batch_threshold: ClassVar[int] = 512

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                          FlashAttnMLAMetadata)
+        self.max_num_splits = 0  # No upper bound on the number of splits.
         self.fa_aot_schedule = (get_flash_attn_version() == 3)

+        self.use_full_cuda_graph = \
+            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
+
+        if self.use_full_cuda_graph and self.fa_aot_schedule:
+            self.max_cudagraph_size = self.compilation_config.max_capture_size
+
+            if self.max_cudagraph_size > 992:
+                # This condition derives from FA3's internal heuristic.
+                # TODO(woosuk): Support larger cudagraph sizes.
+                raise ValueError(
+                    "Capture size larger than 992 is not supported for "
+                    "full cuda graph.")
+
+            self.scheduler_metadata = torch.zeros(
+                vllm_config.scheduler_config.max_num_seqs + 1,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            # When using cuda graph, we need to set the upper bound of the
+            # number of splits so that large enough intermediate buffers are
+            # pre-allocated during capture.
+            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+
     def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                          max_seq_len, causal):
         if self.fa_aot_schedule:
@@ -81,14 +114,16 @@ class FlashAttnMLAMetadataBuilder(
                 page_size=self.page_size,
                 cu_seqlens_q=cu_query_lens,
                 causal=causal,
+                num_splits=self.max_num_splits,
             )
         return None

-    def _build_decode(
-            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
-            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
-            query_start_loc_device: torch.Tensor
-    ) -> FlashAttnMLADecodeMetadata:
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens_cpu: torch.Tensor,
+                      seq_lens_device: torch.Tensor,
+                      query_start_loc_cpu: torch.Tensor,
+                      query_start_loc_device: torch.Tensor,
+                      num_decode_tokens: int) -> FlashAttnMLADecodeMetadata:
         query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])
         max_query_len = query_lens_cpu.max().item()
         max_seq_len = seq_lens_cpu.max().item()
@@ -102,6 +137,29 @@ class FlashAttnMLAMetadataBuilder(
             causal=True,
         )

+        # For FA3 + full cudagraph
+        max_num_splits = 0
+        if self.use_full_cuda_graph and scheduler_metadata is not None:
+            n = scheduler_metadata.shape[0]
+            # Ensure the persistent buffer is large enough
+            assert n <= self.scheduler_metadata.shape[0], \
+                f"Scheduler metadata size {n} exceeds buffer size " + \
+                f"{self.scheduler_metadata.shape[0]}"
+            self.scheduler_metadata[:n] = scheduler_metadata
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee the correctness. Otherwise, some thread
+            # blocks may use the invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
+            if num_decode_tokens <= self.max_cudagraph_size:
+                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+                # usage, because the intermediate buffers of size [num_splits,
+                # num_heads, num_tokens, head_size] are allocated. Therefore,
+                # we only set num_splits when using cuda graphs.
+                max_num_splits = self.max_num_splits
+
         return FlashAttnMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
@@ -109,6 +167,7 @@ class FlashAttnMLAMetadataBuilder(
             max_query_len=max_query_len,
             max_seq_len=max_seq_len,
             scheduler_metadata=scheduler_metadata,
+            max_num_splits=max_num_splits,
         )


@@ -175,12 +234,17 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
         kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
         k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]

+        # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
+        # kernel uses this to calculate grid dimensions. Ensure it's at least 1
+        # to prevent invalid grid configuration during graph capture.
+        max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
+
         o = flash_attn_varlen_func(
             q=q_pe,
             k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
             v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
             q_v=q_nope,
-            max_seqlen_q=attn_metadata.decode.max_query_len,
+            max_seqlen_q=max_seqlen_q,
             cu_seqlens_q=attn_metadata.decode.query_start_loc,
             max_seqlen_k=attn_metadata.decode.max_seq_len,
             seqused_k=attn_metadata.decode.seq_lens,
@@ -189,6 +253,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
             causal=True,
             fa_version=3,  # only version 3 is supported
             scheduler_metadata=attn_metadata.decode.scheduler_metadata,
+            num_splits=attn_metadata.decode.max_num_splits,
         )

         return self._v_up_proj(o)
@@ -62,7 +62,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                          FlashMLAMetadata)

-        self.compilation_config = vllm_config.compilation_config
         self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
             vllm_config.parallel_config)

@@ -85,10 +84,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
                 device=self.device,
                 dtype=torch.int32)

-    def _build_decode(
-            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
-            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
-            query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata:
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens_cpu: torch.Tensor,
+                      seq_lens_device: torch.Tensor,
+                      query_start_loc_cpu: torch.Tensor,
+                      query_start_loc_device: torch.Tensor,
+                      num_decode_tokens: int) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = \
             get_mla_metadata(
                 seq_lens_device,
@@ -104,10 +104,12 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
                 dtype=torch.int32,
                 device=device)

-    def _build_decode(
-            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
-            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
-            query_start_loc_device: torch.Tensor) -> AiterMLADecodeMetadata:
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens_cpu: torch.Tensor,
+                      seq_lens_device: torch.Tensor,
+                      query_start_loc_cpu: torch.Tensor,
+                      query_start_loc_device: torch.Tensor,
+                      num_decode_tokens: int) -> AiterMLADecodeMetadata:
         page_size = self.kv_cache_spec.block_size
         block_table_bounds = (seq_lens_device + page_size - 1) // page_size
         device = self.device
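Every MLA metadata builder now receives num_decode_tokens in _build_decode, as the FlashMLA and AiterMLA hunks above show; the FlashAttention MLA builder uses it to decide whether the CUDA-graph num_splits bound applies. A hypothetical out-of-tree builder (illustrative class name only, no new behavior) would simply extend its signature and forward the argument:

import torch

from vllm.v1.attention.backends.mla.common import (MLACommonDecodeMetadata,
                                                    MLACommonMetadataBuilder)


class MyMLAMetadataBuilder(MLACommonMetadataBuilder):  # hypothetical name
    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens_cpu: torch.Tensor,
                      seq_lens_device: torch.Tensor,
                      query_start_loc_cpu: torch.Tensor,
                      query_start_loc_device: torch.Tensor,
                      num_decode_tokens: int) -> MLACommonDecodeMetadata:
        # num_decode_tokens is the argument added by this commit; the base
        # implementation builds the same metadata as before, so forwarding
        # everything preserves the old behavior.
        return super()._build_decode(block_table_tensor, seq_lens_cpu,
                                     seq_lens_device, query_start_loc_cpu,
                                     query_start_loc_device, num_decode_tokens)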