diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py
index 0bbad17d7ebc7..c9a80e9f7317d 100644
--- a/vllm/model_executor/layers/mamba/short_conv.py
+++ b/vllm/model_executor/layers/mamba/short_conv.py
@@ -118,6 +118,7 @@ class ShortConv(MambaBase, CustomOp):
         conv_state = self_kv_cache[0].transpose(-1, -2)
         state_indices_tensor = attn_metadata.state_indices_tensor
         has_initial_states_p = attn_metadata.has_initial_states_p
+        query_start_loc_p = attn_metadata.query_start_loc_p

         BCx, _ = self.in_proj(hidden_states)

@@ -165,11 +166,6 @@ class ShortConv(MambaBase, CustomOp):
             [num_decodes, num_prefills],
             dim=0,
         )
-        query_start_loc_p = (
-            attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes
-            if has_prefill
-            else None
-        )

         conv_output_list = []

diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py
index fcda6134016ba..47dd44601377b 100644
--- a/vllm/v1/attention/backends/mamba1_attn.py
+++ b/vllm/v1/attention/backends/mamba1_attn.py
@@ -3,17 +3,11 @@

 from dataclasses import dataclass

-import torch
-
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.config import VllmConfig
-from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
-from vllm.v1.attention.backends.utils import (
-    CommonAttentionMetadata,
-    split_decodes_and_prefills,
+from vllm.v1.attention.backends.mamba_attn import (
+    BaseMambaAttentionMetadata,
+    BaseMambaAttentionMetadataBuilder,
 )
-from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec


 class Mamba1AttentionBackend(AttentionBackend):
@@ -23,137 +17,12 @@ class Mamba1AttentionBackend(AttentionBackend):


 @dataclass
-class Mamba1AttentionMetadata:
-    query_start_loc_p: torch.Tensor
-    state_indices_tensor: torch.Tensor
-    has_initial_states_p: torch.Tensor | None
-    num_prefills: int
-    num_prefill_tokens: int
-    num_decodes: int
-    num_decode_tokens: int
-
-    block_idx_last_scheduled_token: torch.Tensor  # shape: [batch,]
-    block_idx_first_scheduled_token_p: torch.Tensor  # shape: [batch,]
-    block_idx_last_computed_token: torch.Tensor  # shape: [batch,]
-    num_computed_tokens_p: torch.Tensor  # shape: [batch,]
+class Mamba1AttentionMetadata(BaseMambaAttentionMetadata):
+    pass


 class Mamba1AttentionMetadataBuilder(
     BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
 ):
-    def __init__(
-        self,
-        kv_cache_spec: AttentionSpec,
-        layer_names: list[str],
-        vllm_config: VllmConfig,
-        device: torch.device,
-    ):
-        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
-        assert isinstance(kv_cache_spec, MambaSpec)
-
-    def build(
-        self,
-        common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata,
-        fast_build: bool = False,
-    ) -> Mamba1AttentionMetadata:
-        num_reqs = common_attn_metadata.num_reqs
-
-        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
-            split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
-            )
-        )
-
-        has_initial_states_p = None
-        query_start_loc_p = None
-        num_computed_tokens, num_computed_tokens_p = None, None
-        block_idx_first_scheduled_token = None
-        block_idx_first_scheduled_token_p = None
-
-        # TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here.
-        # We should consolidate this code
-        if self.vllm_config.cache_config.enable_prefix_caching:
-            # Return a tensor of shape (#requests, #max blocks)
-            state_indices_tensor = common_attn_metadata.block_table_tensor
-            mamba_block_size = self.kv_cache_spec.block_size
-            num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
-                self.device
-            )
-            (
-                block_idx_last_computed_token,
-                block_idx_first_scheduled_token,
-                block_idx_last_scheduled_token,
-            ) = self._compute_prefix_caching_block_indices(
-                common_attn_metadata, mamba_block_size
-            )
-        else:
-            # Always return just a single block per each request:
-            state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
-            block_idx_last_scheduled_token = None
-            block_idx_last_computed_token = None
-
-        if num_prefills > 0:
-            query_start_loc_p = (
-                common_attn_metadata.query_start_loc[-num_prefills - 1 :]
-                - num_decode_tokens
-            )
-            has_initial_states_cpu = (
-                common_attn_metadata.num_computed_tokens_cpu[
-                    num_reqs - num_prefills : num_reqs
-                ]
-                > 0
-            )
-            has_initial_states_p = has_initial_states_cpu.to(
-                common_attn_metadata.query_start_loc.device
-            )
-
-            if self.vllm_config.cache_config.enable_prefix_caching:
-                assert num_computed_tokens is not None
-                num_computed_tokens_p = num_computed_tokens[
-                    num_reqs - num_prefills : num_reqs
-                ]
-                assert block_idx_first_scheduled_token is not None
-                block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
-                    num_reqs - num_prefills : num_reqs
-                ]
-
-        elif (
-            num_decodes > 0
-            and num_decodes <= self.decode_cudagraph_max_bs
-            and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
-        ):
-            self.state_indices_tensor[:num_decodes].copy_(
-                state_indices_tensor, non_blocking=True
-            )
-            state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
-            state_indices_tensor[num_decodes:] = PAD_SLOT_ID
-
-            if self.vllm_config.cache_config.enable_prefix_caching:
-                self.block_idx_last_scheduled_token[:num_decodes].copy_(
-                    block_idx_last_scheduled_token, non_blocking=True
-                )
-                block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
-                    :num_decode_tokens
-                ]
-
-                self.block_idx_last_computed_token[:num_decodes].copy_(
-                    block_idx_last_computed_token, non_blocking=True
-                )
-                block_idx_last_computed_token = self.block_idx_last_computed_token[
-                    :num_decode_tokens
-                ]
-
-        return Mamba1AttentionMetadata(
-            query_start_loc_p=query_start_loc_p,
-            has_initial_states_p=has_initial_states_p,
-            state_indices_tensor=state_indices_tensor,
-            num_prefills=num_prefills,
-            num_prefill_tokens=num_prefill_tokens,
-            num_decodes=num_decodes,
-            num_decode_tokens=num_decode_tokens,
-            block_idx_last_scheduled_token=block_idx_last_scheduled_token,
-            block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
-            block_idx_last_computed_token=block_idx_last_computed_token,
-            num_computed_tokens_p=num_computed_tokens_p,
-        )
+    metadata_cls = Mamba1AttentionMetadata
+    supports_update_block_table: bool = False
diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py
index f923371283aa0..b526f0a329972 100644
--- a/vllm/v1/attention/backends/mamba2_attn.py
+++ b/vllm/v1/attention/backends/mamba2_attn.py
@@ -1,19 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import copy
 import itertools
-from dataclasses import dataclass
+from dataclasses import dataclass, replace

 import torch

 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
 from vllm.utils.math_utils import cdiv
-from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
+from vllm.v1.attention.backends.mamba_attn import (
+    BaseMambaAttentionMetadata,
+    BaseMambaAttentionMetadataBuilder,
+)
 from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
-    compute_causal_conv1d_metadata,
-    split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -94,48 +94,26 @@ class Mamba2AttentionBackend(AttentionBackend):


 @dataclass
-class Mamba2AttentionMetadata:
-    num_prefills: int
-    num_prefill_tokens: int
-    num_decodes: int
-    num_decode_tokens: int
-    query_start_loc_p: torch.Tensor
-    seq_lens: torch.Tensor
-
-    prep_initial_states: bool
-    chunk_size: int
-
-    # The following tensors only contain prefill requests and will be None if
-    # the batch has no prefill request.
-    has_initial_states_p: torch.Tensor | None
-    seq_idx_p: torch.Tensor | None
+class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
+    prep_initial_states: bool = False
+    chunk_size: int = 0
+    # Chunk-related metadata (only for prefill)
+    seq_idx_p: torch.Tensor | None = None
     # cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
     # each chunk, its offests into the varlen sequence dimension. It is defined
     # such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
     # cu_chunk_seqlen_p[i+1].
-    cu_chunk_seqlen_p: torch.Tensor | None
-
+    cu_chunk_seqlen_p: torch.Tensor | None = None
     # last_chunk_indices_p is a tensor of shape (batch,) that contains the
     # index of the last chunk for every sequence in the (prefill) batch.
-    last_chunk_indices_p: torch.Tensor | None
-
-    state_indices_tensor: torch.Tensor  # shape: [batch,]
-    block_idx_last_scheduled_token: torch.Tensor  # shape: [batch,]
-    block_idx_first_scheduled_token_p: torch.Tensor  # shape: [batch,]
-    block_idx_last_computed_token: torch.Tensor  # shape: [batch,]
-    num_computed_tokens_p: torch.Tensor  # shape: [batch,]
-
-    # The following attributes are for triton implementation of causal_conv1d
-    nums_dict: dict | None = None
-    batch_ptr: torch.Tensor | None = None
-    token_chunk_offset_ptr: torch.Tensor | None = None
+    last_chunk_indices_p: torch.Tensor | None = None


 class Mamba2AttentionMetadataBuilder(
     BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
 ):
-    supports_update_block_table: bool = True
+    metadata_cls = Mamba2AttentionMetadata

     def __init__(
         self,
@@ -150,87 +128,93 @@ class Mamba2AttentionMetadataBuilder(
                 "chunk_size needs to be set in the model config for Mamba2 models"
             )

+    def _compute_chunk_metadata(
+        self,
+        num_prefills: int,
+        num_computed_tokens_p_cpu: torch.Tensor,
+        query_start_loc_p_cpu: torch.Tensor,
+    ) -> tuple[list[int], list[int], list[int]]:
+        """
+        Compute chunk-specific metadata for Mamba2.
+
+        The code below carefully constructs the chunks such that:
+        1. Chunks contain tokens from a *single* sequence only.
+        2. For every sequence, we are guaranteed that we can
+           retrieve the mamba state *every* chunk_size tokens.
+        Constraint (1) dramatically simplifies the mamba2 kernels.
+        Constraint (2) dramatically simplifies the implementation
+        of prefix caching for mamba2 (wip). We need to take care
+        of the interaction with chunked prefill in order to
+        satisfy constraint (2).
+        """
+        # TODO (tdoublep): This code could probably be optimized.
+        cu_chunk_seqlen = []
+        seq_idx = []
+        last_chunk_indices = []
+        seqlen_pos = 0
+
+        for req_idx in range(num_prefills):
+            this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
+            this_new_tokens = (
+                query_start_loc_p_cpu[req_idx + 1].item()
+                - query_start_loc_p_cpu[req_idx].item()
+            )
+
+            # if computed tokens are not chunk-aligned, use the first
+            # chunk to finish it off
+            if this_num_computed % self.chunk_size != 0:
+                seq_idx.append(req_idx)
+                cu_chunk_seqlen.append(seqlen_pos)
+                # how many tokens to finish the chunk?
+                chunk_len = (
+                    cdiv(this_num_computed, self.chunk_size) * self.chunk_size
+                    - this_num_computed
+                )
+                # we can only use at most this_new_tokens
+                chunk_len = min(chunk_len, this_new_tokens)
+                seqlen_pos += chunk_len
+                this_new_tokens -= chunk_len
+
+            n_chunks = cdiv(this_new_tokens, self.chunk_size)
+            for chunk in range(n_chunks):
+                seq_idx.append(req_idx)
+                cu_chunk_seqlen.append(seqlen_pos)
+                chunk_len = min(self.chunk_size, this_new_tokens)
+                seqlen_pos += chunk_len
+                this_new_tokens -= chunk_len
+
+            assert this_new_tokens == 0
+            last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
+
+        cu_chunk_seqlen.append(seqlen_pos)
+
+        return cu_chunk_seqlen, seq_idx, last_chunk_indices
+
     def build(
         self,
         common_prefix_len: int,
         common_attn_metadata: CommonAttentionMetadata,
         fast_build: bool = False,
     ) -> Mamba2AttentionMetadata:
-        num_reqs = common_attn_metadata.num_reqs
-        seq_lens = common_attn_metadata.seq_lens
+        common = self._compute_common_metadata(common_attn_metadata)

-        query_start_loc_p = None
         seq_idx_p = None
         cu_chunk_seqlen_p = None
         last_chunk_indices_p = None
-
-        # Need flags to indicate if there are initial states
-        has_initial_states_p = None
         prep_initial_states = False

-        # for causal_conv1d
-        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
-
-        num_computed_tokens, num_computed_tokens_p = None, None
-        block_idx_first_scheduled_token = None
-        block_idx_first_scheduled_token_p = None
-
-        if self.vllm_config.cache_config.enable_prefix_caching:
-            # Return a tensor of shape (#requests, #max blocks)
-            state_indices_tensor = common_attn_metadata.block_table_tensor
-            # Additional cache-related varaiables:
-            mamba_block_size = self.kv_cache_spec.block_size
-            num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
-                self.device
-            )
-            (
-                block_idx_last_computed_token,
-                block_idx_first_scheduled_token,
-                block_idx_last_scheduled_token,
-            ) = self._compute_prefix_caching_block_indices(
-                common_attn_metadata, mamba_block_size
-            )
-        else:
-            # Always return just a single block per each request:
-            state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
-            # Additional cache-related varaiables:
-            block_idx_last_scheduled_token = None
-            block_idx_last_computed_token = None
-
-        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
-            split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
-            )
-        )
-        # Compute seq_idx for prefill only
-        if num_prefills > 0:
-            # [batch,]
-            has_initial_states_cpu = (
-                common_attn_metadata.num_computed_tokens_cpu[
-                    num_reqs - num_prefills : num_reqs
-                ]
-                > 0
-            )
-            prep_initial_states = torch.any(has_initial_states_cpu).item()
-            has_initial_states_p = has_initial_states_cpu.to(
-                common_attn_metadata.query_start_loc.device
+        if common.num_prefills > 0:
+            prep_initial_states = (
+                torch.any(common.has_initial_states_p).item()
+                if common.has_initial_states_p is not None
+                else False
             )
-            query_start_loc_p = (
-                common_attn_metadata.query_start_loc[-num_prefills - 1 :]
-                - num_decode_tokens
-            )
+            num_reqs = common.num_reqs
+            num_prefills = common.num_prefills
+            num_decode_tokens = common.num_decode_tokens

-            if self.vllm_config.cache_config.enable_prefix_caching:
-                assert num_computed_tokens is not None
-                num_computed_tokens_p = num_computed_tokens[
-                    num_reqs - num_prefills : num_reqs
-                ]
-                assert block_idx_first_scheduled_token is not None
-                block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
-                    num_reqs - num_prefills : num_reqs
-                ]
             num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
                 num_reqs - num_prefills : num_reqs
             ]
@@ -239,137 +223,33 @@ class Mamba2AttentionMetadataBuilder(
                 - num_decode_tokens
             )

-            # The code below carefully constructs the chunks such that:
-            # 1. Chunks contain tokens from a *single* sequence only.
-            # 2. For every sequence, we are guaranteed that we can
-            #    retrieve the mamba state *every* chunk_size tokens.
-            # Constraint (1) dramatically simplifies the mamba2 kernels.
-            # Constraint (2) dramatically simplifies the implementation
-            # of prefix caching for mamba2 (wip). We need to take care
-            # of the interaction with chunked prefill in order to
-            # satisfy constraint (2).
-            # TODO (tdoublep): This code could probably be optimized.
-            cu_chunk_seqlen = []
-            seq_idx = []
-            last_chunk_indices = []
-            seqlen_pos = 0
-            for req_idx in range(num_prefills):
-                this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
-                this_new_tokens = (
-                    query_start_loc_p_cpu[req_idx + 1].item()
-                    - query_start_loc_p_cpu[req_idx].item()
-                )
-
-                # if computed tokens are not chunk-aligned, use the first
-                # chunk to finish it off
-                if this_num_computed % self.chunk_size != 0:
-                    seq_idx.append(req_idx)
-                    cu_chunk_seqlen.append(seqlen_pos)
-                    # how many tokens to finish the chunk?
-                    chunk_len = (
-                        cdiv(this_num_computed, self.chunk_size) * self.chunk_size
-                        - this_num_computed
-                    )
-                    # we can only use at most this_new_tokens
-                    chunk_len = min(chunk_len, this_new_tokens)
-                    seqlen_pos += chunk_len
-                    this_new_tokens -= chunk_len
-
-                n_chunks = cdiv(this_new_tokens, self.chunk_size)
-                for chunk in range(n_chunks):
-                    seq_idx.append(req_idx)
-                    cu_chunk_seqlen.append(seqlen_pos)
-                    chunk_len = min(self.chunk_size, this_new_tokens)
-                    seqlen_pos += chunk_len
-                    this_new_tokens -= chunk_len
-
-                assert this_new_tokens == 0
-                last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
-
-            cu_chunk_seqlen.append(seqlen_pos)
+            cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
+                num_prefills,
+                num_computed_tokens_p_cpu,
+                query_start_loc_p_cpu,
+            )

             seq_idx_p = torch.as_tensor(
-                seq_idx, device=query_start_loc_p.device, dtype=torch.int32
+                seq_idx,
+                device=common_attn_metadata.query_start_loc.device,
+                dtype=torch.int32,
             )
             cu_chunk_seqlen_p = torch.as_tensor(
-                cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32
+                cu_chunk_seqlen,
+                device=common_attn_metadata.query_start_loc.device,
+                dtype=torch.int32,
             )
             last_chunk_indices_p = torch.as_tensor(
-                last_chunk_indices, device=query_start_loc_p.device, dtype=torch.int32
+                last_chunk_indices,
+                device=common_attn_metadata.query_start_loc.device,
+                dtype=torch.int32,
             )

-            nums_dict, batch_ptr, token_chunk_offset_ptr = (
-                compute_causal_conv1d_metadata(query_start_loc_p)
-            )
-
-        elif (
-            num_decodes <= self.decode_cudagraph_max_bs
-            and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
-        ):
-            self.state_indices_tensor[:num_decodes].copy_(
-                state_indices_tensor, non_blocking=True
-            )
-            state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
-
-            if self.vllm_config.cache_config.enable_prefix_caching:
-                self.block_idx_last_scheduled_token[:num_decodes].copy_(
-                    block_idx_last_scheduled_token, non_blocking=True
-                )
-                block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
-                    :num_decode_tokens
-                ]
-
-                self.block_idx_last_computed_token[:num_decodes].copy_(
-                    block_idx_last_computed_token, non_blocking=True
-                )
-                block_idx_last_computed_token = self.block_idx_last_computed_token[
-                    :num_decode_tokens
-                ]
-
-        attn_metadata = Mamba2AttentionMetadata(
-            num_prefills=num_prefills,
-            num_prefill_tokens=num_prefill_tokens,
-            num_decodes=num_decodes,
-            num_decode_tokens=num_decode_tokens,
-            query_start_loc_p=query_start_loc_p,
-            seq_lens=seq_lens,
+        return replace(
+            common,
             prep_initial_states=prep_initial_states,
             chunk_size=self.chunk_size,
-            has_initial_states_p=has_initial_states_p,
             seq_idx_p=seq_idx_p,
-            state_indices_tensor=state_indices_tensor,
             cu_chunk_seqlen_p=cu_chunk_seqlen_p,
             last_chunk_indices_p=last_chunk_indices_p,
-            nums_dict=nums_dict,
-            batch_ptr=batch_ptr,
-            token_chunk_offset_ptr=token_chunk_offset_ptr,
-            block_idx_last_scheduled_token=block_idx_last_scheduled_token,
-            block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
-            block_idx_last_computed_token=block_idx_last_computed_token,
-            num_computed_tokens_p=num_computed_tokens_p,
         )
-        return attn_metadata
-
-    def update_block_table(
-        self,
-        metadata: Mamba2AttentionMetadata,
-        blk_table: torch.Tensor,
-        slot_mapping: torch.Tensor,
-    ) -> Mamba2AttentionMetadata:
-        new_metadata = copy.copy(metadata)
-        prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
-        state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
-        num_reqs = blk_table.shape[0]
-
-        # For CUDA graphs, copy to persistent buffer
-        if (
-            metadata.num_prefills == 0
-            and num_reqs <= self.decode_cudagraph_max_bs
-            and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
-        ):
-            persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
-            persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
-            state_indices_t = persistent_state_indices_t
-
-        new_metadata.state_indices_tensor = state_indices_t
-        return new_metadata
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index a9705db59f19d..4f876d66da147 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -2,6 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import abc
+import copy
+from dataclasses import dataclass
 from typing import ClassVar, TypeVar

 import torch
@@ -9,20 +11,52 @@ import torch
 from vllm.config import VllmConfig
 from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backends.utils import (
+    PAD_SLOT_ID,
     AttentionCGSupport,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
+    compute_causal_conv1d_metadata,
+    split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec

-M = TypeVar("M")
+M = TypeVar("M", bound="BaseMambaAttentionMetadata")
+
+
+@dataclass
+class BaseMambaAttentionMetadata:
+    num_prefills: int
+    num_prefill_tokens: int
+    num_decodes: int
+    num_decode_tokens: int
+    num_reqs: int
+
+    # The following tensors only contain prefill requests and will be None if
+    # the batch has no prefill request.
+    has_initial_states_p: torch.Tensor | None
+    query_start_loc_p: torch.Tensor | None
+    num_computed_tokens_p: torch.Tensor | None
+
+    state_indices_tensor: torch.Tensor
+
+    # The following tensors are only used for prefix caching and are None if disabled
+    block_idx_last_scheduled_token: torch.Tensor | None
+    block_idx_first_scheduled_token_p: torch.Tensor | None
+    block_idx_last_computed_token: torch.Tensor | None
+
+    # The following attributes are for triton implementation of causal_conv1d
+    nums_dict: dict | None = None
+    batch_ptr: torch.Tensor | None = None
+    token_chunk_offset_ptr: torch.Tensor | None = None


 class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
+    metadata_cls: type[M]
     reorder_batch_threshold: int = 1
     _cudagraph_support: ClassVar[AttentionCGSupport] = (
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     )
+    supports_update_block_table: bool = True

     def __init__(
         self,
@@ -87,6 +121,18 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):

         return self.build(0, m)

+    def build(
+        self,
+        common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+        fast_build: bool = False,
+    ) -> M:
+        """
+        Default build implementation for Mamba-like attention backends.
+        Subclasses (e.g., Mamba2) can override to add additional metadata.
+        """
+        return self._compute_common_metadata(common_attn_metadata)
+
     def _compute_prefix_caching_block_indices(
         self,
         common_attn_metadata: CommonAttentionMetadata,
@@ -115,3 +161,147 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
             block_idx_first_scheduled_token,
             block_idx_last_scheduled_token,
         )
+
+    def _compute_common_metadata(
+        self,
+        common_attn_metadata: CommonAttentionMetadata,
+    ) -> M:
+        """
+        Compute metadata common to both Mamba1 and Mamba2.
+ """ + num_reqs = common_attn_metadata.num_reqs + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) + + # Need flags to indicate if there are initial states + has_initial_states_p = None + query_start_loc_p = None + num_computed_tokens = None + num_computed_tokens_p = None + + # for prefix caching + block_idx_first_scheduled_token = None + block_idx_first_scheduled_token_p = None + block_idx_last_computed_token = None + block_idx_last_scheduled_token = None + + # for causal_conv1d + nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + + if self.vllm_config.cache_config.enable_prefix_caching: + # Return a tensor of shape (#requests, #max blocks) + state_indices_tensor = common_attn_metadata.block_table_tensor + # Additional cache-related varaiables: + mamba_block_size = self.kv_cache_spec.block_size + num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to( + self.device + ) + ( + block_idx_last_computed_token, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + ) = self._compute_prefix_caching_block_indices( + common_attn_metadata, mamba_block_size + ) + else: + # Always return just a single block per each request: + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + + if num_prefills > 0: + query_start_loc_p = ( + common_attn_metadata.query_start_loc[-num_prefills - 1 :] + - num_decode_tokens + ) + has_initial_states_cpu = ( + common_attn_metadata.num_computed_tokens_cpu[ + num_reqs - num_prefills : num_reqs + ] + > 0 + ) + has_initial_states_p = has_initial_states_cpu.to( + common_attn_metadata.query_start_loc.device + ) + + nums_dict, batch_ptr, token_chunk_offset_ptr = ( + compute_causal_conv1d_metadata(query_start_loc_p) + ) + + if self.vllm_config.cache_config.enable_prefix_caching: + assert num_computed_tokens is not None + num_computed_tokens_p = num_computed_tokens[ + num_reqs - num_prefills : num_reqs + ] + assert block_idx_first_scheduled_token is not None + block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[ + num_reqs - num_prefills : num_reqs + ] + elif ( + num_decodes <= self.decode_cudagraph_max_bs + and self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ): + self.state_indices_tensor[:num_decodes].copy_( + state_indices_tensor, non_blocking=True + ) + state_indices_tensor = self.state_indices_tensor[:num_decode_tokens] + state_indices_tensor[num_decodes:] = PAD_SLOT_ID + + if self.vllm_config.cache_config.enable_prefix_caching: + self.block_idx_last_scheduled_token[:num_decodes].copy_( + block_idx_last_scheduled_token, non_blocking=True + ) + block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[ + :num_decode_tokens + ] + + self.block_idx_last_computed_token[:num_decodes].copy_( + block_idx_last_computed_token, non_blocking=True + ) + block_idx_last_computed_token = self.block_idx_last_computed_token[ + :num_decode_tokens + ] + + return self.metadata_cls( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + query_start_loc_p=query_start_loc_p, + has_initial_states_p=has_initial_states_p, + state_indices_tensor=state_indices_tensor, + block_idx_last_scheduled_token=block_idx_last_scheduled_token, + block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p, + block_idx_last_computed_token=block_idx_last_computed_token, + 
num_computed_tokens_p=num_computed_tokens_p, + num_reqs=num_reqs, + nums_dict=nums_dict, + batch_ptr=batch_ptr, + token_chunk_offset_ptr=token_chunk_offset_ptr, + ) + + def update_block_table( + self, + metadata: M, + blk_table: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> M: + new_metadata = copy.copy(metadata) + prefix_caching = self.vllm_config.cache_config.enable_prefix_caching + state_indices_t = blk_table if prefix_caching else blk_table[:, 0] + num_reqs = blk_table.shape[0] + + # For CUDA graphs, copy to persistent buffer + if ( + metadata.num_prefills == 0 + and num_reqs <= self.decode_cudagraph_max_bs + and self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ): + persistent_state_indices_t = self.state_indices_tensor[:num_reqs] + persistent_state_indices_t.copy_(state_indices_t, non_blocking=True) + state_indices_t = persistent_state_indices_t + + new_metadata.state_indices_tensor = state_indices_t + return new_metadata diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index c8fe0faf71088..e2fae37f5619d 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -2,15 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -import torch - from vllm.attention.backends.abstract import AttentionBackend -from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder -from vllm.v1.attention.backends.utils import ( - PAD_SLOT_ID, - CommonAttentionMetadata, - compute_causal_conv1d_metadata, - split_decodes_and_prefills, +from vllm.v1.attention.backends.mamba_attn import ( + BaseMambaAttentionMetadata, + BaseMambaAttentionMetadataBuilder, ) @@ -21,84 +16,11 @@ class ShortConvAttentionBackend(AttentionBackend): @dataclass -class ShortConvAttentionMetadata: - num_prefills: int - num_prefill_tokens: int - num_decodes: int - num_decode_tokens: int - - query_start_loc: torch.Tensor - state_indices_tensor: torch.Tensor - has_initial_states_p: torch.Tensor | None - - # For causal_conv1d - nums_dict: dict | None = None - batch_ptr: torch.Tensor | None = None - token_chunk_offset_ptr: torch.Tensor | None = None +class ShortConvAttentionMetadata(BaseMambaAttentionMetadata): + pass class ShortConvAttentionMetadataBuilder( BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata] ): - def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False, - ) -> ShortConvAttentionMetadata: - num_reqs = common_attn_metadata.num_reqs - query_start_loc = common_attn_metadata.query_start_loc - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] - - # for causal_conv1d - nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None - - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills( - common_attn_metadata, decode_threshold=self.reorder_batch_threshold - ) - ) - - has_initial_states_p = None - if num_prefills > 0: - has_initial_states_cpu = ( - common_attn_metadata.num_computed_tokens_cpu[ - num_reqs - num_prefills : num_reqs - ] - > 0 - ) - has_initial_states_p = has_initial_states_cpu.to(query_start_loc.device) - - query_start_loc_p = ( - common_attn_metadata.query_start_loc[-num_prefills - 1 :] - - num_decode_tokens - ) - - nums_dict, batch_ptr, token_chunk_offset_ptr = ( - compute_causal_conv1d_metadata(query_start_loc_p) - ) - - elif ( - num_decodes > 0 - and num_decodes <= 
self.decode_cudagraph_max_bs - and self.compilation_config.cudagraph_mode.has_full_cudagraphs() - ): - self.state_indices_tensor[:num_decodes].copy_( - state_indices_tensor, non_blocking=True - ) - state_indices_tensor = self.state_indices_tensor[:num_decode_tokens] - state_indices_tensor[num_decodes:] = PAD_SLOT_ID - - attn_metadata = ShortConvAttentionMetadata( - query_start_loc=query_start_loc, - state_indices_tensor=state_indices_tensor, - has_initial_states_p=has_initial_states_p, - num_prefills=num_prefills, - num_prefill_tokens=num_prefill_tokens, - num_decodes=num_decodes, - num_decode_tokens=num_decode_tokens, - nums_dict=nums_dict, - batch_ptr=batch_ptr, - token_chunk_offset_ptr=token_chunk_offset_ptr, - ) - return attn_metadata + metadata_cls = ShortConvAttentionMetadata
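
Note on the chunking scheme moved into Mamba2AttentionMetadataBuilder._compute_chunk_metadata: the logic can be exercised outside vLLM. The sketch below is a minimal standalone re-implementation under stated assumptions (plain Python ints and lists instead of CPU tensors, math.ceil in place of vLLM's cdiv); the helper name and the example numbers are illustrative only and are not part of the patch.

from math import ceil


def compute_chunk_metadata(
    chunk_size: int,
    num_computed_tokens: list[int],
    query_start_loc: list[int],
) -> tuple[list[int], list[int], list[int]]:
    """Standalone sketch of the chunk construction in _compute_chunk_metadata.

    Chunks never cross a sequence boundary (constraint 1), and each sequence
    is re-aligned to a chunk_size boundary as early as possible so the mamba
    state can be checkpointed every chunk_size tokens (constraint 2).
    """
    cu_chunk_seqlen: list[int] = []
    seq_idx: list[int] = []
    last_chunk_indices: list[int] = []
    pos = 0

    for req_idx, computed in enumerate(num_computed_tokens):
        new_tokens = query_start_loc[req_idx + 1] - query_start_loc[req_idx]

        # If the already-computed prefix is not chunk-aligned, emit a short
        # first chunk that finishes off the partial chunk.
        if computed % chunk_size != 0:
            seq_idx.append(req_idx)
            cu_chunk_seqlen.append(pos)
            chunk_len = min(ceil(computed / chunk_size) * chunk_size - computed,
                            new_tokens)
            pos += chunk_len
            new_tokens -= chunk_len

        # Then emit full chunks (the last one may be partial).
        for _ in range(ceil(new_tokens / chunk_size)):
            seq_idx.append(req_idx)
            cu_chunk_seqlen.append(pos)
            chunk_len = min(chunk_size, new_tokens)
            pos += chunk_len
            new_tokens -= chunk_len

        last_chunk_indices.append(len(cu_chunk_seqlen) - 1)

    cu_chunk_seqlen.append(pos)
    return cu_chunk_seqlen, seq_idx, last_chunk_indices


# Two prefill requests with chunk_size=4: request 0 resumes at 6 computed
# tokens and adds 7 new tokens; request 1 starts fresh and adds 5 tokens.
print(compute_chunk_metadata(4, [6, 0], [0, 7, 12]))
# -> ([0, 2, 6, 7, 11, 12], [0, 0, 0, 1, 1], [2, 4])

In the example, request 0's first chunk has length 2, which brings it from 6 computed tokens up to the chunk boundary at 8 before any full chunks are emitted; that is exactly constraint (2) from the method's docstring interacting with chunked prefill.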
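
The Mamba2 build() also relies on _compute_common_metadata instantiating self.metadata_cls, so the object it returns is already a Mamba2AttentionMetadata and dataclasses.replace only fills in the subclass-specific fields; this is why those fields carry defaults such as chunk_size: int = 0. A minimal sketch of that pattern, using hypothetical class and field names rather than the actual vLLM types:

from dataclasses import dataclass, replace


@dataclass
class BaseMeta:
    num_prefills: int
    num_decodes: int


@dataclass
class Mamba2Meta(BaseMeta):
    # Subclass-only fields need defaults so the base builder can
    # construct the object before they are known.
    chunk_size: int = 0
    prep_initial_states: bool = False


def build_common(metadata_cls: type[BaseMeta]) -> BaseMeta:
    # Mirrors BaseMambaAttentionMetadataBuilder._compute_common_metadata:
    # it instantiates metadata_cls, so the result is already the subclass.
    return metadata_cls(num_prefills=1, num_decodes=3)


common = build_common(Mamba2Meta)
# Mirrors Mamba2AttentionMetadataBuilder.build(): fill the extra fields afterwards.
meta = replace(common, chunk_size=256, prep_initial_states=True)
print(type(meta).__name__, meta.chunk_size, meta.num_decodes)  # Mamba2Meta 256 3

Since replace() returns a new instance of type(common) with only the named fields overridden, backends that need nothing beyond the common fields (Mamba1, ShortConv) can simply set metadata_cls and inherit the default build().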