misc merge fixes

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-07-25 19:47:33 +00:00
parent ee70ce0e4e
commit b9ad5e4588
4 changed files with 24 additions and 18 deletions

View File

@@ -563,8 +563,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
scheduler_output,
decode_threshold=1)
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor):
def _build_decode(self,
block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor,
ubatch_id: Optional[int] = None):
return MLACommonDecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
@@ -597,7 +599,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> M:
fast_build: bool = False,
ubatch_id: Optional[int] = None) -> M:
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
@@ -720,7 +723,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens=seq_lens[:num_decodes],
)
ubatch_id=ubatch_id)
attn_metadata = self.metadata_cls(
num_reqs=common_attn_metadata.num_reqs,

View File

@@ -67,8 +67,11 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor,
ubatch_id: Optional[int] = None) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens,

View File

@@ -126,7 +126,7 @@ def _make_metadata_with_slice(
def split_attn_metadata(
ubatch_slices: list[UbatchSlice],
ubatch_slices: list[tuple[slice, slice]],
common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
"""
@@ -136,8 +136,9 @@ def split_attn_metadata(
"""
results = []
for ubatch_slice in ubatch_slices:
results.append(
_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
s = UbatchSlice(request_slice=ubatch_slice[0],
token_slice=ubatch_slice[1])
results.append(_make_metadata_with_slice(s, common_attn_metadata))
return results

View File

@@ -52,7 +52,7 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
make_local_attention_virtual_batches)
make_local_attention_virtual_batches, split_attn_metadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec,
ChunkedLocalAttentionSpec,
@@ -878,17 +878,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
.slot_mapping.fill_(-1)
if ubatch_slices is not None:
for ubid, (req_slice, token_slice) in enumerate(ubatch_slices):
assert token_slice.stop > token_slice.start
common_attn_metadata_list = split_attn_metadata(
ubatch_slices, common_attn_metadata)
for ubid, common_attn_metadata in enumerate(
common_attn_metadata_list):
assert common_attn_metadata.max_query_len == 1
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].
build_slice(
req_slice=req_slice,
token_slice=token_slice,
max_query_len=max(tokens[req_slice]),
self.attn_metadata_builders[kv_cache_group_id].build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
))
ubatch_id=ubid))
for layer_name in kv_cache_group_spec.layer_names:
assert type(attn_metadata) is list
attn_metadata[ubid][layer_name] = attn_metadata_i