mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 01:07:03 +08:00
misc merge fixes
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
ee70ce0e4e
commit
b9ad5e4588
@ -563,8 +563,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
scheduler_output,
|
||||
decode_threshold=1)
|
||||
|
||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor):
|
||||
def _build_decode(self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
ubatch_id: Optional[int] = None):
|
||||
return MLACommonDecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens,
|
||||
@ -597,7 +599,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> M:
|
||||
fast_build: bool = False,
|
||||
ubatch_id: Optional[int] = None) -> M:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
@ -720,7 +723,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
decode_metadata = self._build_decode(
|
||||
block_table_tensor=block_table_tensor[:num_decodes, ...],
|
||||
seq_lens=seq_lens[:num_decodes],
|
||||
)
|
||||
ubatch_id=ubatch_id)
|
||||
|
||||
attn_metadata = self.metadata_cls(
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
|
||||
@ -67,8 +67,11 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
self.cg_buf_tile_scheduler_metadata = None
|
||||
self.cg_buf_num_splits = None
|
||||
|
||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
ubatch_id: Optional[int] = None) -> FlashMLADecodeMetadata:
|
||||
tile_scheduler_metadata, num_splits = \
|
||||
get_mla_metadata(
|
||||
seq_lens,
|
||||
|
||||
@ -126,7 +126,7 @@ def _make_metadata_with_slice(
|
||||
|
||||
|
||||
def split_attn_metadata(
|
||||
ubatch_slices: list[UbatchSlice],
|
||||
ubatch_slices: list[tuple[slice, slice]],
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> list[CommonAttentionMetadata]:
|
||||
"""
|
||||
@ -136,8 +136,9 @@ def split_attn_metadata(
|
||||
"""
|
||||
results = []
|
||||
for ubatch_slice in ubatch_slices:
|
||||
results.append(
|
||||
_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
|
||||
s = UbatchSlice(request_slice=ubatch_slice[0],
|
||||
token_slice=ubatch_slice[1])
|
||||
results.append(_make_metadata_with_slice(s, common_attn_metadata))
|
||||
return results
|
||||
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
make_local_attention_virtual_batches)
|
||||
make_local_attention_virtual_batches, split_attn_metadata)
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||
ChunkedLocalAttentionSpec,
|
||||
@ -878,17 +878,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
.slot_mapping.fill_(-1)
|
||||
|
||||
if ubatch_slices is not None:
|
||||
for ubid, (req_slice, token_slice) in enumerate(ubatch_slices):
|
||||
assert token_slice.stop > token_slice.start
|
||||
common_attn_metadata_list = split_attn_metadata(
|
||||
ubatch_slices, common_attn_metadata)
|
||||
for ubid, common_attn_metadata in enumerate(
|
||||
common_attn_metadata_list):
|
||||
assert common_attn_metadata.max_query_len == 1
|
||||
attn_metadata_i = (
|
||||
self.attn_metadata_builders[kv_cache_group_id].
|
||||
build_slice(
|
||||
req_slice=req_slice,
|
||||
token_slice=token_slice,
|
||||
max_query_len=max(tokens[req_slice]),
|
||||
self.attn_metadata_builders[kv_cache_group_id].build(
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
))
|
||||
ubatch_id=ubid))
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
assert type(attn_metadata) is list
|
||||
attn_metadata[ubid][layer_name] = attn_metadata_i
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user