diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index efc6db6d35107..ebf27c3c251b3 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -563,8 +563,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             scheduler_output, decode_threshold=1)
 
-    def _build_decode(self, block_table_tensor: torch.Tensor,
-                      seq_lens: torch.Tensor):
+    def _build_decode(self,
+                      block_table_tensor: torch.Tensor,
+                      seq_lens: torch.Tensor,
+                      ubatch_id: Optional[int] = None):
         return MLACommonDecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens,
@@ -597,7 +599,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
-              fast_build: bool = False) -> M:
+              fast_build: bool = False,
+              ubatch_id: Optional[int] = None) -> M:
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
@@ -720,7 +723,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             decode_metadata = self._build_decode(
                 block_table_tensor=block_table_tensor[:num_decodes, ...],
                 seq_lens=seq_lens[:num_decodes],
-            )
+                ubatch_id=ubatch_id)
 
         attn_metadata = self.metadata_cls(
             num_reqs=common_attn_metadata.num_reqs,
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index d3e5300dbbd6b..6274552c62c6c 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -67,8 +67,11 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
             self.cg_buf_tile_scheduler_metadata = None
             self.cg_buf_num_splits = None
 
-    def _build_decode(self, block_table_tensor: torch.Tensor,
-                      seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
+    def _build_decode(
+            self,
+            block_table_tensor: torch.Tensor,
+            seq_lens: torch.Tensor,
+            ubatch_id: Optional[int] = None) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = \
             get_mla_metadata(
                 seq_lens,
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 16b99aab842b8..32d804e815c45 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -126,7 +126,7 @@ def _make_metadata_with_slice(
 
 def split_attn_metadata(
-    ubatch_slices: list[UbatchSlice],
+    ubatch_slices: list[tuple[slice, slice]],
     common_attn_metadata: CommonAttentionMetadata,
 ) -> list[CommonAttentionMetadata]:
     """
@@ -136,8 +136,9 @@ def split_attn_metadata(
     """
     results = []
     for ubatch_slice in ubatch_slices:
-        results.append(
-            _make_metadata_with_slice(ubatch_slice, common_attn_metadata))
+        s = UbatchSlice(request_slice=ubatch_slice[0],
+                        token_slice=ubatch_slice[1])
+        results.append(_make_metadata_with_slice(s, common_attn_metadata))
     return results
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 20cb8e0a9d5f0..e125a1217e192 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -52,7 +52,7 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata,
-    make_local_attention_virtual_batches)
+    make_local_attention_virtual_batches, split_attn_metadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         ChunkedLocalAttentionSpec,
@@ -878,17 +878,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     .slot_mapping.fill_(-1)
 
             if ubatch_slices is not None:
-                for ubid, (req_slice, token_slice) in enumerate(ubatch_slices):
-                    assert token_slice.stop > token_slice.start
+                common_attn_metadata_list = split_attn_metadata(
+                    ubatch_slices, common_attn_metadata)
+                for ubid, common_attn_metadata in enumerate(
+                        common_attn_metadata_list):
+                    assert common_attn_metadata.max_query_len == 1
                     attn_metadata_i = (
-                        self.attn_metadata_builders[kv_cache_group_id].
-                        build_slice(
-                            req_slice=req_slice,
-                            token_slice=token_slice,
-                            max_query_len=max(tokens[req_slice]),
+                        self.attn_metadata_builders[kv_cache_group_id].build(
                             common_prefix_len=common_prefix_len,
                             common_attn_metadata=common_attn_metadata,
-                        ))
+                            ubatch_id=ubid))
                     for layer_name in kv_cache_group_spec.layer_names:
                         assert type(attn_metadata) is list
                         attn_metadata[ubid][layer_name] = attn_metadata_i
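
Note on the utils.py hunk: split_attn_metadata now accepts raw (request_slice, token_slice) tuples and rewraps each one into an UbatchSlice before delegating to _make_metadata_with_slice, so the slicing helper itself is unchanged. Below is a minimal sketch of that rewrap; the UbatchSlice dataclass is a stand-in (the diff only confirms its two field names), and rewrap is an invented name for illustration.

    # Illustrative sketch only: this UbatchSlice is a stand-in dataclass,
    # not vLLM's definition; the diff confirms just its two field names.
    from dataclasses import dataclass

    @dataclass
    class UbatchSlice:
        request_slice: slice
        token_slice: slice

    def rewrap(ubatch_slices: list[tuple[slice, slice]]) -> list[UbatchSlice]:
        # Rewrap each (request_slice, token_slice) tuple so downstream
        # code keeps consuming UbatchSlice objects unchanged.
        return [UbatchSlice(request_slice=req, token_slice=tok)
                for req, tok in ubatch_slices]

    # Two pure-decode micro-batches (one token per request, so request
    # and token slices coincide): requests 0-3 and 4-7.
    print(rewrap([(slice(0, 4), slice(0, 4)), (slice(4, 8), slice(4, 8))]))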
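On the ubatch_id parameter threaded through build() and _build_decode(): one plausible motivation, suggested by the cg_buf_tile_scheduler_metadata / cg_buf_num_splits members visible in FlashMLAMetadataBuilder, is letting a builder keep one set of persistent (e.g. CUDA-graph) buffers per micro-batch instead of sharing a single buffer across both ubatches. A hypothetical sketch under that assumption; ToyDecodeBuilder and its buffer cache are invented for illustration, not vLLM code.

    # Hypothetical: per-ubatch persistent buffers keyed by ubatch_id.
    from typing import Optional

    import torch

    class ToyDecodeBuilder:
        def __init__(self) -> None:
            # One cached buffer per ubatch_id; the None key covers the
            # non-microbatched path, where no ubatch_id is passed.
            self._num_splits_bufs: dict[Optional[int], torch.Tensor] = {}

        def _build_decode(self,
                          seq_lens: torch.Tensor,
                          ubatch_id: Optional[int] = None) -> torch.Tensor:
            buf = self._num_splits_bufs.get(ubatch_id)
            if buf is None or buf.numel() < seq_lens.numel():
                # Allocate (or grow) this micro-batch's buffer once and
                # reuse it afterwards, so a captured CUDA graph always
                # replays against a stable address.
                buf = torch.empty_like(seq_lens)
                self._num_splits_bufs[ubatch_id] = buf
            buf[:seq_lens.numel()].copy_(seq_lens)
            return buf[:seq_lens.numel()]

With two ubatches the builder ends up with keys 0 and 1, and each micro-batch's decode metadata keeps pointing at its own buffer across steps.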