From d4b502a73af9a45b791274e5c7e11688910b1204 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 2 Jun 2025 19:14:19 +0000 Subject: [PATCH] mla format Signed-off-by: Sage Moore --- vllm/v1/attention/backends/mla/common.py | 35 ++++++++++++------------ 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index ad0d0b319fcc4..ac231db7d8b1c 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -206,12 +206,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, UnquantizedLinearMethod) from vllm.platforms import current_platform from vllm.utils import cdiv, round_down -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, + slice_query_start_locs) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable -from vllm.v1.attention.backends.utils import slice_query_start_locs - try: from vllm.vllm_flash_attn import flash_attn_varlen_func is_vllm_fa = True @@ -442,8 +441,10 @@ class MLACommonMetadataBuilder(Generic[M]): block_table=block_table_tensor, seq_lens=seq_lens, ) - - def _split_decodes_and_prefills(self, max_query_len: int, num_reqs: int, num_tokens: int, query_start_loc: torch.Tensor): + + def _split_decodes_and_prefills(self, max_query_len: int, num_reqs: int, + num_tokens: int, + query_start_loc: torch.Tensor): """ return - num_decodes: number of decode requests @@ -462,16 +463,17 @@ class MLACommonMetadataBuilder(Generic[M]): num_prefills = num_reqs - num_decodes num_decode_tokens = first_prefill num_prefill_tokens = num_tokens - query_start_loc[first_prefill] - return ( - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens - ) - - def build_slice(self, req_slice: slice, - token_slice: slice, - max_query_len: int, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - ) -> M: + return (num_decodes, num_prefills, num_decode_tokens, + num_prefill_tokens) + + def build_slice( + self, + req_slice: slice, + token_slice: slice, + max_query_len: int, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + ) -> M: num_reqs = req_slice.stop - req_slice.start num_tokens = token_slice.stop - token_slice.start @@ -481,14 +483,13 @@ class MLACommonMetadataBuilder(Generic[M]): device = self.runner.device block_table = self.block_table block_table_tensor = block_table.get_device_tensor()[req_slice] - # print(f"num_reqs: {num_reqs} bloc_table_shape: {block_table_tensor.shape}") slot_mapping = block_table.slot_mapping_cpu[token_slice].to( device, non_blocking=True).long() query_start_loc = slice_query_start_locs( common_attn_metadata.query_start_loc, req_slice) seq_lens = common_attn_metadata.seq_lens[req_slice] - + num_computed_tokens = self.runner.input_batch.\ num_computed_tokens_cpu_tensor[req_slice]