MLA: formatting cleanup

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Sage Moore 2025-06-02 19:14:19 +00:00
parent 44a595f6d6
commit d4b502a73a


@@ -206,12 +206,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                UnquantizedLinearMethod)
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
+                                              slice_query_start_locs)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
-from vllm.v1.attention.backends.utils import slice_query_start_locs
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
     is_vllm_fa = True
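The try block above is cut off by the hunk boundary. For context, a minimal sketch of the guarded-import pattern it belongs to; the except branch here is an assumption about the surrounding file, not part of this diff:

```python
# Optional-dependency import with a feature flag (fallback branch assumed,
# not shown in this hunk).
try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func
    is_vllm_fa = True
except ImportError:
    # Assumed fallback to the upstream flash-attn package.
    from flash_attn import flash_attn_varlen_func
    is_vllm_fa = False
```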
@@ -442,8 +441,10 @@ class MLACommonMetadataBuilder(Generic[M]):
             block_table=block_table_tensor,
             seq_lens=seq_lens,
         )
 
-    def _split_decodes_and_prefills(self, max_query_len: int, num_reqs: int, num_tokens: int, query_start_loc: torch.Tensor):
+    def _split_decodes_and_prefills(self, max_query_len: int, num_reqs: int,
+                                    num_tokens: int,
+                                    query_start_loc: torch.Tensor):
         """
         return
         - num_decodes: number of decode requests
@@ -462,16 +463,17 @@ class MLACommonMetadataBuilder(Generic[M]):
         num_prefills = num_reqs - num_decodes
         num_decode_tokens = first_prefill
         num_prefill_tokens = num_tokens - query_start_loc[first_prefill]
-        return (
-            num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens
-        )
+        return (num_decodes, num_prefills, num_decode_tokens,
+                num_prefill_tokens)
 
-    def build_slice(self, req_slice: slice,
-                    token_slice: slice,
-                    max_query_len: int,
-                    common_prefix_len: int,
-                    common_attn_metadata: CommonAttentionMetadata,
-                    ) -> M:
+    def build_slice(
+        self,
+        req_slice: slice,
+        token_slice: slice,
+        max_query_len: int,
+        common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+    ) -> M:
         num_reqs = req_slice.stop - req_slice.start
         num_tokens = token_slice.stop - token_slice.start
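To make the splitting logic concrete, here is a self-contained sketch of what `_split_decodes_and_prefills` computes, assuming decode requests (exactly one query token each) are ordered ahead of prefills in the batch, as `num_decode_tokens = first_prefill` implies. The function name and the argmax-based scan for `first_prefill` are illustrative, not the file's verbatim code:

```python
import torch

def split_decodes_and_prefills_sketch(num_reqs: int, num_tokens: int,
                                      query_start_loc: torch.Tensor):
    # query_start_loc holds num_reqs + 1 cumulative offsets; request i owns
    # tokens [query_start_loc[i], query_start_loc[i + 1]).
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_prefill = query_lens > 1  # decodes contribute exactly one token
    if not bool(is_prefill.any()):
        return num_reqs, 0, num_tokens, 0  # pure-decode batch
    first_prefill = int(is_prefill.int().argmax())
    num_decodes = first_prefill
    num_prefills = num_reqs - num_decodes
    # One token per decode, so decode token count equals decode request count.
    num_decode_tokens = first_prefill
    num_prefill_tokens = num_tokens - int(query_start_loc[first_prefill])
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens
```

For example, `query_start_loc = torch.tensor([0, 1, 2, 7])` with `num_reqs = 3` and `num_tokens = 7` yields `(2, 1, 2, 5)`: two single-token decodes followed by one five-token prefill.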
@@ -481,14 +483,13 @@ class MLACommonMetadataBuilder(Generic[M]):
         device = self.runner.device
         block_table = self.block_table
         block_table_tensor = block_table.get_device_tensor()[req_slice]
-        # print(f"num_reqs: {num_reqs} bloc_table_shape: {block_table_tensor.shape}")
         slot_mapping = block_table.slot_mapping_cpu[token_slice].to(
             device, non_blocking=True).long()
         query_start_loc = slice_query_start_locs(
             common_attn_metadata.query_start_loc, req_slice)
         seq_lens = common_attn_metadata.seq_lens[req_slice]
         num_computed_tokens = self.runner.input_batch.\
             num_computed_tokens_cpu_tensor[req_slice]
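`build_slice` re-bases every per-request and per-token tensor onto a contiguous window of the batch, and `slice_query_start_locs` handles the cumulative offsets. A plausible sketch of that helper, assuming it only shifts the sliced offsets so they start at zero (the real implementation lives in `vllm.v1.attention.backends.utils` and may differ):

```python
import torch

def slice_query_start_locs_sketch(query_start_loc: torch.Tensor,
                                  req_slice: slice) -> torch.Tensor:
    # Keep the num_sliced_reqs + 1 offsets covering the sliced requests,
    # then shift them so the first sliced request begins at token 0.
    sliced = query_start_loc[req_slice.start:req_slice.stop + 1]
    return sliced - sliced[0]
```

With `query_start_loc = torch.tensor([0, 1, 2, 7])` and `req_slice = slice(1, 3)`, this returns `tensor([0, 1, 6])`, i.e. the last two requests rebased to a standalone two-request batch.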