diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 0b1f90e27db82..e60a86075b8bc 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -20,7 +20,9 @@ pytestmark = pytest.mark.hybrid_model SSM_MODELS = [ "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", - "yujiepan/mamba2-codestral-v0.1-tiny-random", + # mamba2-codestral in transformers is broken pending: + # https://github.com/huggingface/transformers/pull/40861 + #"yujiepan/mamba2-codestral-v0.1-tiny-random", ] HYBRID_MODELS = [ @@ -31,18 +33,7 @@ HYBRID_MODELS = [ "ibm-granite/granite-4.0-tiny-preview", "tiiuae/Falcon-H1-0.5B-Base", "LiquidAI/LFM2-1.2B", -] - -V1_SUPPORTED_MODELS = [ - "state-spaces/mamba-130m-hf", - "ai21labs/Jamba-tiny-dev", - "pfnet/plamo-2-1b", - "yujiepan/mamba2-codestral-v0.1-tiny-random", - "Zyphra/Zamba2-1.2B-instruct", - "hmellor/tiny-random-BambaForCausalLM", - "ibm-granite/granite-4.0-tiny-preview", - "tiiuae/Falcon-H1-0.5B-Base", - "LiquidAI/LFM2-1.2B", + "tiny-random/qwen3-next-moe", ] FULL_CUDA_GRAPH_MODELS = [ @@ -51,10 +42,6 @@ FULL_CUDA_GRAPH_MODELS = [ "Zyphra/Zamba2-1.2B-instruct", ] -V0_UNSUPPORTED_MODELS = [ - "LiquidAI/LFM2-1.2B", -] - FP32_STATE_MODELS = [ "state-spaces/mamba-130m-hf", "Zyphra/Zamba2-1.2B-instruct", @@ -88,20 +75,16 @@ def test_models( hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) - if model in V1_SUPPORTED_MODELS: - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - else: - vllm_v1_outputs = None + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) - if model in V1_SUPPORTED_MODELS: - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_v1_outputs, - name_0="hf", - name_1="vllm-v1", - ) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -299,14 +282,14 @@ def test_full_cuda_graph( example_prompts, max_tokens, num_logprobs) with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) check_logprobs_close( outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_v1_outputs, + outputs_1_lst=vllm_outputs, name_0="hf", - name_1="vllm-v1", + name_1="vllm", ) @@ -340,12 +323,12 @@ def test_fp32_cache_state( with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, **{cache_dtype_param: "float32"}) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) check_logprobs_close( outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_v1_outputs, + outputs_1_lst=vllm_outputs, name_0="hf", - name_1="vllm-v1", + name_1="vllm", ) diff --git a/tests/models/registry.py b/tests/models/registry.py index 6047a7a3e98dc..8b62952ad5908 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -312,14 +312,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"), "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"), - "Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501 - trust_remote_code=True, - v0_only=True, - max_model_len=10240), "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", - trust_remote_code=True), + max_transformers_version="4.55.4", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 + trust_remote_code=True), "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", max_transformers_version="4.53", transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 @@ -330,7 +328,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), "Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", - min_transformers_version="4.56.2"), + extras={"tiny-random": "tiny-random/qwen3-next-moe"}, # noqa: E501 + min_transformers_version="4.56.3"), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501 trust_remote_code=True, @@ -644,7 +643,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { trust_remote_code=True, speculative_model="XiaomiMiMo/MiMo-7B-RL"), "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", - min_transformers_version="4.56.2"), + min_transformers_version="4.56.3"), } _TRANSFORMERS_BACKEND_MODELS = { diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index a524e13405807..6da62b5426bb6 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -20,10 +20,7 @@ class MambaBase(AttentionLayerBase): # Contains the KV cache (mamba state) for the layer # in the shape specified by `self.get_state_shape`. - # The outer list is for v0 PP virtual engine. Though this code path - # only runs for v1, we have to do this to unify with the interface - # of Attention + v0 PP. - kv_cache: list[Iterable[torch.Tensor]] + kv_cache: tuple[torch.Tensor, ...] @abstractmethod def get_state_shape(self) -> Iterable[tuple[int, ...]]: diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index 5fe37a6289e01..6a901b47b8b63 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -15,7 +15,6 @@ import torch.nn.functional as F from einops import rearrange from torch import nn -from vllm import envs from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed.communication_op import tensor_model_parallel_all_reduce @@ -42,8 +41,6 @@ if TYPE_CHECKING: import torch import torch.distributed -from vllm.model_executor.models.minimax_cache import MinimaxCacheParams - class MiniMaxText01RMSNormTP(CustomOp): name = "MiniMaxText01RMSNormTP" @@ -225,11 +222,10 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): self.tp_heads:(self.tp_rank + 1) * self.tp_heads].contiguous() - if envs.VLLM_USE_V1: - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self @staticmethod def weight_direct_load(param: torch.Tensor, @@ -268,8 +264,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): break if _prefill_idx >= len(state_indices_tensor): break - # prefills are packed at end of batch in V1 - offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0 + offset = attn_metadata.num_decode_tokens _start = attn_metadata.query_start_loc[offset + _prefill_idx] _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1] slot_id = state_indices_tensor[offset + _prefill_idx] @@ -291,10 +286,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): hidden_decode = self._decode_infer(q, k, v, kv_cache, state_indices_tensor, attn_metadata) - if envs.VLLM_USE_V1: - hidden.insert(0, hidden_decode) - else: - hidden.append(hidden_decode) + hidden.insert(0, hidden_decode) if not hidden: return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) @@ -304,40 +296,28 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata): - if not envs.VLLM_USE_V1: - q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - num_prefills = getattr(attn_metadata, "num_prefills", 0) - slot_id = state_indices_tensor[num_prefills:] - else: - q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - slot_id = state_indices_tensor[:attn_metadata.num_decodes] + q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + slot_id = state_indices_tensor[:attn_metadata.num_decodes] hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope, slot_id, 32) return hidden def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, - positions: torch.Tensor, - kv_caches: MinimaxCacheParams) -> None: - if not envs.VLLM_USE_V1: - self._forward(hidden_states, output, positions, kv_caches) - else: - torch.ops.vllm.linear_attention( - hidden_states, - output, - positions, - self.prefix, - ) + positions: torch.Tensor) -> None: + torch.ops.vllm.linear_attention( + hidden_states, + output, + positions, + self.prefix, + ) def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[MinimaxCacheParams]) -> None: + positions: torch.Tensor) -> None: forward_context = get_forward_context() attn_metadata: AttentionMetadata = forward_context.attn_metadata - if envs.VLLM_USE_V1 and attn_metadata is not None: + if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata, LinearAttentionMetadata) @@ -351,32 +331,26 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): qkvact = torch.nn.functional.silu(qkv32) qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) - if envs.VLLM_USE_V1: - if attn_metadata is not None: - kv_cache = self.kv_cache[forward_context.virtual_engine][0] - state_indices_tensor = attn_metadata.state_indices_tensor + if attn_metadata is not None: + kv_cache = self.kv_cache[forward_context.virtual_engine][0] + state_indices_tensor = attn_metadata.state_indices_tensor - num_prefills = getattr(attn_metadata, "num_prefills", 0) - if num_prefills > 0: - num_decode_tokens = getattr(attn_metadata, - "num_decode_tokens", 0) - for prefill_idx in range(num_prefills): - q_start = attn_metadata.query_start_loc[ - num_decode_tokens + prefill_idx] - q_end = attn_metadata.query_start_loc[num_decode_tokens - + prefill_idx + - 1] - query_len = q_end - q_start - context_len = attn_metadata.seq_lens[ - num_decode_tokens + prefill_idx] - query_len - if context_len == 0: - block_to_clear = state_indices_tensor[ - num_decode_tokens + prefill_idx] - kv_cache[block_to_clear, ...] = 0 - else: - assert kv_caches is not None - kv_cache = kv_caches.minimax_cache - state_indices_tensor = kv_caches.state_indices_tensor + num_prefills = getattr(attn_metadata, "num_prefills", 0) + if num_prefills > 0: + num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", + 0) + for prefill_idx in range(num_prefills): + q_start = attn_metadata.query_start_loc[num_decode_tokens + + prefill_idx] + q_end = attn_metadata.query_start_loc[num_decode_tokens + + prefill_idx + 1] + query_len = q_end - q_start + context_len = attn_metadata.seq_lens[ + num_decode_tokens + prefill_idx] - query_len + if context_len == 0: + block_to_clear = state_indices_tensor[num_decode_tokens + + prefill_idx] + kv_cache[block_to_clear, ...] = 0 decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 if attn_metadata is None: @@ -410,8 +384,7 @@ def linear_attention( self = forward_context.no_compile_layers[layer_name] self._forward(hidden_states=hidden_states, output=output, - positions=positions, - kv_caches=None) + positions=positions) def linear_attention_fake( diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py deleted file mode 100644 index 7f376b70a7ae0..0000000000000 --- a/vllm/model_executor/layers/mamba/mamba2_metadata.py +++ /dev/null @@ -1,177 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass -from typing import Optional, Union - -import numpy as np -import torch - -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.backends.placeholder_attn import ( - PlaceholderAttentionMetadata) -from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.platforms import current_platform -from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata -from vllm.v1.attention.backends.mamba2_attn import ( - Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets) - - -@dataclass -class Mamba2Metadata: - prep_initial_states: bool - chunk_size: int - - has_initial_states_p: torch.Tensor - seq_idx_p: torch.Tensor - chunk_indices_p: torch.Tensor - chunk_offsets_p: torch.Tensor - """ - With continuous batching layout of `x` in vLLM, to enable a Triton program - to handle a request in parallel, two supporting tensors are used - (batch_ptr, token_chunk_offset_ptr) - BLOCK_M = the # tokens to be handled by a Triton program - (can be customized for different hardware) - - nums_dict: - tracks the data associated with a given value of BLOCK_M - BLOCK_M = #tokens handled by a Triton program - cu_seqlen: total tokens per batch - (used as flag to update other data at each new input) - batch_ptr: tracks batch-id handled by the Triton program - token_chunk_offset_ptr: tracks token group_idx handled by the Triton program - (Triton implementation of causal_conv1d handles parallelism in 3-axes - - feature-axis - - batch-axis - - sequence-axis) - """ - nums_dict: Optional[dict] = None - cu_seqlen: Optional[int] = None - batch_ptr: Optional[torch.Tensor] = None - token_chunk_offset_ptr: Optional[torch.Tensor] = None - - -def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: - """Returns the appropriate metadata classes for the current platform.""" - if current_platform.is_rocm(): - from vllm.v1.attention.backends.rocm_aiter_fa import ( - AiterFlashAttentionMetadata) - from vllm.v1.attention.backends.triton_attn import ( - TritonAttentionMetadata) - return (AiterFlashAttentionMetadata, TritonAttentionMetadata, - PlaceholderAttentionMetadata) - if current_platform.is_cuda(): - from vllm.v1.attention.backends.flash_attn import ( - FlashAttentionMetadata) - from vllm.v1.attention.backends.xformers import ( - XFormersAttentionMetadata) - return (FlashAttentionMetadata, XFormersAttentionMetadata, - PlaceholderAttentionMetadata) - raise ValueError( - f"Unsupported platform for Mamba2: {current_platform.device_type}") - - -def prepare_mamba2_metadata( - chunk_size: int, - attn_metadata: AttentionMetadata, -) -> Mamba2Metadata: - - # compute number of prefill and decode requests - # NOTE: in V0 we assume prefills are before decodes - num_prefills = attn_metadata.num_prefills - num_prefill_tokens = attn_metadata.num_prefill_tokens - - seq_idx_p = None - chunk_indices_p, chunk_offsets_p = None, None - # Need flags to indicate if there are initial states - # currently we really only support the FlashAttention backend - has_initial_states_p = None - prep_initial_states = False - - # Compute seq_idx, chunk_indices and chunk_offsets for prefill only - if num_prefills > 0: - attn_metadata_instances = get_platform_metadata_classes() - if (isinstance(attn_metadata, attn_metadata_instances) - and attn_metadata.context_lens_tensor is not None): - # precompute flag to avoid device syncs later in mamba2 layer - # forwards - # prep is only needed for mamba2 ssd prefill processing - has_initial_states_p = ( - attn_metadata.context_lens_tensor[:num_prefills] > 0) - prep_initial_states = torch.any(has_initial_states_p).item() - query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1] - seq_idx_p = torch.repeat_interleave(torch.arange( - num_prefills, dtype=torch.int32, device=query_start_loc_p.device), - query_start_loc_p.diff(), - output_size=num_prefill_tokens) - seq_idx_p.unsqueeze_(0) - - # We compute metadata for chunked prefill once at the top level model - # forward and reuse them in mamba layers. If not needed, they will be - # ignored inside mamba kernels. - if prep_initial_states: - chunk_indices_p, chunk_offsets_p = \ - _query_start_loc_to_chunk_indices_offsets( - query_start_loc_p, chunk_size, num_prefill_tokens) - - return Mamba2Metadata(has_initial_states_p=has_initial_states_p, - prep_initial_states=prep_initial_states, - chunk_size=chunk_size, - seq_idx_p=seq_idx_p, - chunk_indices_p=chunk_indices_p, - chunk_offsets_p=chunk_offsets_p) - - -def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor, - mamba2_metadata: Union[Mamba2Metadata, - Mamba2AttentionMetadata, - GDNAttentionMetadata]): - """ - this is triggered upon handling a new input at the first layer - """ - dim, cu_seqlen = x.shape - mamba2_metadata.cu_seqlen = cu_seqlen - seqlens = np.diff(query_start_loc.to('cpu')) - nums_dict = {} # type: ignore - for BLOCK_M in [8]: # cover all BLOCK_M values - nums = -(-seqlens // BLOCK_M) - nums_dict[BLOCK_M] = {} - nums_dict[BLOCK_M]['nums'] = nums - nums_dict[BLOCK_M]['tot'] = nums.sum().item() - mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums)) - nums_dict[BLOCK_M]['mlist'] = mlist - mlist_len = len(nums_dict[BLOCK_M]['mlist']) - nums_dict[BLOCK_M]['mlist_len'] = mlist_len - MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2 - offsetlist = [] # type: ignore - for idx, num in enumerate(nums): - offsetlist.extend(range(num)) - offsetlist = torch.tensor(offsetlist, dtype=torch.int32) - nums_dict[BLOCK_M]['offsetlist'] = offsetlist - - if mamba2_metadata.batch_ptr is None: - # Update default value after class definition - #mamba2_metadata.MAX_NUM_PROGRAMS *= 2 - mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device='cuda') - mamba2_metadata.token_chunk_offset_ptr = torch.full( - (MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device='cuda') - else: - if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS: - mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_( - PAD_SLOT_ID) - mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore - MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) - - mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist) - mamba2_metadata.token_chunk_offset_ptr[ # type: ignore - 0:mlist_len].copy_(offsetlist) - nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr - nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = ( - mamba2_metadata.token_chunk_offset_ptr) # type: ignore - mamba2_metadata.nums_dict = nums_dict - return mamba2_metadata diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index e704bfd451bce..a56ee13a63804 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -10,8 +10,6 @@ import torch from torch import nn from torch.nn.parameter import Parameter -from vllm import envs -from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -28,7 +26,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) -from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -149,16 +146,12 @@ class MambaMixer(MambaBase, CustomOp): has_weight=rms_norm_has_weight, ) if use_rms_norm else None - if envs.VLLM_USE_V1: - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - # The outer list is for v0 PP virtual engine. Though this code path - # only runs for v1, we have to do this to unify with the interface - # of Attention + v0 PP. - # The inner tuple is (conv_state, ssm_state) - self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The inner tuple is (conv_state, ssm_state) + self.kv_cache = (torch.tensor([]), torch.tensor([])) self.model_config = model_config self.cache_config = cache_config @@ -186,29 +179,18 @@ class MambaMixer(MambaBase, CustomOp): discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) return discrete_time_step, B, C - def forward(self, - hidden_states: torch.Tensor, - output: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None): - if not envs.VLLM_USE_V1: - CustomOp.forward(self, hidden_states, output, mamba_cache_params) - else: - torch.ops.vllm.mamba_mixer( - hidden_states, - output, - self.prefix, - ) + def forward(self, hidden_states: torch.Tensor, output: torch.Tensor): + torch.ops.vllm.mamba_mixer( + hidden_states, + output, + self.prefix, + ) - def forward_native(self, - hidden_states: torch.Tensor, - output: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None): + def forward_native(self, hidden_states: torch.Tensor, + output: torch.Tensor): pass - def forward_cuda(self, - hidden_states: torch.Tensor, - output: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None): + def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): """ Run the Mamba-1 SSM pipeline. @@ -234,31 +216,18 @@ class MambaMixer(MambaBase, CustomOp): forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata - if envs.VLLM_USE_V1: - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - mamba1_metadata = attn_metadata - assert isinstance(mamba1_metadata, Mamba1AttentionMetadata) - query_start_loc = mamba1_metadata.query_start_loc - state_indices_tensor = mamba1_metadata.state_indices_tensor - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] - has_initial_states = mamba1_metadata.has_initial_states - num_padded_decodes = mamba1_metadata.num_padded_decodes - else: - assert isinstance(attn_metadata, AttentionMetadata) - assert mamba_cache_params is not None - conv_state = mamba_cache_params.conv_state - ssm_state = mamba_cache_params.ssm_state - state_indices_tensor = mamba_cache_params.state_indices_tensor - query_start_loc = attn_metadata.query_start_loc - context_lens_tensor = attn_metadata.context_lens_tensor - has_initial_states = None - if context_lens_tensor is not None: - has_initial_states = context_lens_tensor > 0 - num_padded_decodes = attn_metadata.num_decode_tokens + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + mamba1_metadata = attn_metadata + assert isinstance(mamba1_metadata, Mamba1AttentionMetadata) + query_start_loc = mamba1_metadata.query_start_loc + state_indices_tensor = mamba1_metadata.state_indices_tensor + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + has_initial_states = mamba1_metadata.has_initial_states + num_padded_decodes = mamba1_metadata.num_padded_decodes # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -267,7 +236,7 @@ class MambaMixer(MambaBase, CustomOp): conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if envs.VLLM_USE_V1 and attn_metadata is None: + if attn_metadata is None: # V1 profile run hidden_states_BC = hidden_states_BC.contiguous() return self.out_proj(hidden_states_BC.transpose(-2, -1))[0] @@ -368,10 +337,7 @@ class MambaMixer(MambaBase, CustomOp): out=scan_outputs_d) scan_outputs_d = scan_outputs_d.transpose(0, 1) - if envs.VLLM_USE_V1: - ssm_outputs.insert(0, scan_outputs_d) - else: - ssm_outputs.append(scan_outputs_d) + ssm_outputs.insert(0, scan_outputs_d) scan_outputs_combined = ssm_outputs[0] if len( ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1) @@ -441,40 +407,27 @@ def split_batch_to_prefill_and_decode( num_decodes: int, num_padded_decodes: int, ) -> PrefillDecodeSplit: + num_actual_tokens = num_prefill_tokens + num_padded_decodes - if envs.VLLM_USE_V1: - # In v1, decode tokens come first, then prefill tokens. - hidden_states_BC_d, hidden_states_BC_p = torch.split( - hidden_states_BC[..., :num_actual_tokens], - [num_padded_decodes, num_prefill_tokens], - dim=-1) - gate_d, gate_p = torch.split(gate[..., :num_actual_tokens], - [num_padded_decodes, num_prefill_tokens], - dim=-1) + # In v1, decode tokens come first, then prefill tokens. + hidden_states_BC_d, hidden_states_BC_p = torch.split( + hidden_states_BC[..., :num_actual_tokens], + [num_padded_decodes, num_prefill_tokens], + dim=-1) + gate_d, gate_p = torch.split(gate[..., :num_actual_tokens], + [num_padded_decodes, num_prefill_tokens], + dim=-1) - # num_padded_decodes accounts for CUDA graph padding when applicable - state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor[:num_padded_decodes + num_prefills], - [num_padded_decodes, num_prefills], - dim=0) - query_start_loc_p = (query_start_loc[-num_prefills - 1:] - - num_padded_decodes if num_prefills > 0 else None) - has_initial_states_p = has_initial_states[-num_prefills:] if ( - has_initial_states is not None and num_prefills > 0) else None - else: - # In v0, prefill tokens come first, then decode tokens. - hidden_states_BC_p, hidden_states_BC_d = torch.split( - hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1) - gate_p, gate_d = torch.split(gate, - [num_prefill_tokens, num_decode_tokens], - dim=-1) - state_indices_tensor_p, state_indices_tensor_d = torch.split( - state_indices_tensor, [num_prefills, num_decodes], dim=0) - query_start_loc_p = (query_start_loc[:num_prefills + - 1] if num_prefills > 0 else None) - has_initial_states_p = has_initial_states[:num_prefills] if ( - has_initial_states is not None and num_prefills > 0) else None + # num_padded_decodes accounts for CUDA graph padding when applicable + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor[:num_padded_decodes + num_prefills], + [num_padded_decodes, num_prefills], + dim=0) + query_start_loc_p = (query_start_loc[-num_prefills - 1:] - + num_padded_decodes if num_prefills > 0 else None) + has_initial_states_p = has_initial_states[-num_prefills:] if ( + has_initial_states is not None and num_prefills > 0) else None return PrefillDecodeSplit( hidden_states_BC_p=hidden_states_BC_p, @@ -495,9 +448,7 @@ def mamba_mixer( ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.forward_cuda(hidden_states=hidden_states, - output=output, - mamba_cache_params=None) + self.forward_cuda(hidden_states=hidden_states, output=output) def mamba_mixer_fake( diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 02e6a9138c05f..047ce4c4c43d0 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: import torch from torch import nn -from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -22,8 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, - update_metadata) from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( @@ -36,7 +33,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import ( LoaderFunction, composed_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -449,16 +445,12 @@ class MambaMixer2(MambaBase, CustomOp): self.use_rms_norm, eps=rms_norm_eps) - if envs.VLLM_USE_V1: - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - # The outer list is for v0 PP virtual engine. Though this code path - # only runs for v1, we have to do this to unify with the interface - # of Attention + v0 PP. - # The inner tuple is (conv_state, ssm_state) - self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The tuple is (conv_state, ssm_state) + self.kv_cache = (torch.tensor([]), torch.tensor([])) self.model_config = model_config self.cache_config = cache_config @@ -468,8 +460,6 @@ class MambaMixer2(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, mup_vector: Optional[torch.Tensor] = None, ): pass @@ -478,59 +468,43 @@ class MambaMixer2(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, mup_vector: Optional[torch.Tensor] = None, ): - if not envs.VLLM_USE_V1: - CustomOp.forward(self, hidden_states, output, mamba_cache_params, - mamba2_metadata, mup_vector) - else: - torch.ops.vllm.mamba_mixer2( - hidden_states, - output, - self.prefix, - mup_vector, - ) + torch.ops.vllm.mamba_mixer2( + hidden_states, + output, + self.prefix, + mup_vector, + ) def forward_cuda( self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, mup_vector: Optional[torch.Tensor] = None, ): forward_context = get_forward_context() - # mamba2_metadata contains metadata necessary for the mamba2 triton + # attn_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration attn_metadata: AttentionMetadata = forward_context.attn_metadata - if envs.VLLM_USE_V1: - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - mamba2_metadata = attn_metadata - assert isinstance(attn_metadata, Mamba2AttentionMetadata) - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - # conv_state = (..., dim, width-1) yet contiguous along 'dim' - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] - state_indices_tensor = attn_metadata.state_indices_tensor - else: - conv_state = mamba_cache_params.conv_state - ssm_state = mamba_cache_params.ssm_state - state_indices_tensor = mamba_cache_params.state_indices_tensor - # Common members between V1 metadata and V0 metadata - if mamba2_metadata is not None: - has_initial_states_p = mamba2_metadata.has_initial_states_p - prep_initial_states = mamba2_metadata.prep_initial_states - chunk_size = mamba2_metadata.chunk_size - seq_idx_p = mamba2_metadata.seq_idx_p - chunk_indices_p = mamba2_metadata.chunk_indices_p - chunk_offsets_p = mamba2_metadata.chunk_offsets_p + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, Mamba2AttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + # conv_state = (..., dim, width-1) yet contiguous along 'dim' + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states_p + prep_initial_states = attn_metadata.prep_initial_states + chunk_size = attn_metadata.chunk_size + seq_idx_p = attn_metadata.seq_idx_p + chunk_indices_p = attn_metadata.chunk_indices_p + chunk_offsets_p = attn_metadata.chunk_offsets_p # 1. Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) @@ -562,8 +536,8 @@ class MambaMixer2(MambaBase, CustomOp): dim=-1, ) - if envs.VLLM_USE_V1 and attn_metadata is None: - # V1 profile run + if attn_metadata is None: + # profile run hidden_states_B_C = (hidden_states_B_C.transpose( 0, 1).clone().transpose(0, 1)).contiguous() hidden_states, _B, _C = split_hidden_states_B_C_fn( @@ -579,49 +553,27 @@ class MambaMixer2(MambaBase, CustomOp): has_decode = num_decodes > 0 num_actual_tokens = num_prefill_tokens + num_decodes - # NOTE: V0 put prefill before decode, v1 puts decode before prefill # Separate prefill and decode by splitting varlen input # Split along token dimension - if envs.VLLM_USE_V1: - hidden_states_B_C_d, hidden_states_B_C_p = torch.split( - hidden_states_B_C[:num_actual_tokens], - [num_decodes, num_prefill_tokens], - dim=0, - ) - dt_d, dt_p = torch.split( - dt[:num_actual_tokens], - [num_decodes, num_prefill_tokens], - dim=0, - ) - # Split along batch dimension - state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor[:num_actual_tokens], - [num_decodes, num_prefills], - dim=0, - ) - query_start_loc_p = ( - attn_metadata.query_start_loc[-num_prefills - 1:] - - num_decodes if has_prefill else None) - else: - hidden_states_B_C_p, hidden_states_B_C_d = torch.split( - hidden_states_B_C, - [num_prefill_tokens, num_decodes], - dim=0, - ) - dt_p, dt_d = torch.split( - dt, - [num_prefill_tokens, num_decodes], - dim=0, - ) - # Split along batch dimension - state_indices_tensor_p, state_indices_tensor_d = torch.split( - state_indices_tensor, - [num_prefills, num_decodes], - dim=0, - ) - query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + - 1] - if has_prefill else None) + hidden_states_B_C_d, hidden_states_B_C_p = torch.split( + hidden_states_B_C[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + dt_d, dt_p = torch.split( + dt[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor[:num_actual_tokens], + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs @@ -633,18 +585,11 @@ class MambaMixer2(MambaBase, CustomOp): dtype=hidden_states.dtype, device=hidden_states.device, ) - if envs.VLLM_USE_V1: - preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( - preallocated_ssm_out, - [num_decodes, num_prefill_tokens], - dim=0, - ) - else: - preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( - preallocated_ssm_out, - [num_prefill_tokens, num_decodes], - dim=0, - ) + preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( + preallocated_ssm_out, + [num_decodes, num_prefill_tokens], + dim=0, + ) # Process prefill requests if has_prefill: @@ -653,9 +598,6 @@ class MambaMixer2(MambaBase, CustomOp): # pointed to by "state_indices_tensor" x = hidden_states_B_C_p.transpose( 0, 1) # this is the form that causal-conv see - if mamba2_metadata.cu_seqlen is None: - mamba2_metadata = update_metadata(x, query_start_loc_p, - mamba2_metadata) hidden_states_B_C_p = causal_conv1d_fn( x, conv_weights, @@ -664,7 +606,7 @@ class MambaMixer2(MambaBase, CustomOp): conv_states=conv_state, has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, - metadata=mamba2_metadata, + metadata=attn_metadata, query_start_loc=query_start_loc_p).transpose( 0, 1)[:num_prefill_tokens] @@ -806,8 +748,6 @@ def mamba_mixer2( self = forward_context.no_compile_layers[layer_name] self.forward_cuda(hidden_states=hidden_states, output=output, - mamba_cache_params=None, - mamba2_metadata=None, mup_vector=mup_vector) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index a6c1af91de421..677a4b9d87fc7 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -100,7 +100,6 @@ class MambaStateShapeCalculator: intermediate_size: int, state_size: int, conv_kernel: int, - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int]]: conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1) @@ -108,11 +107,7 @@ class MambaStateShapeCalculator: temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size) - # In V0, the conv_state shape was swapped during allocation in - # MambaCacheManager, but in V1 it needs to be determined here at the - # calculation level - if use_v1: - conv_state_shape = conv_state_shape[1], conv_state_shape[0] + conv_state_shape = conv_state_shape[1], conv_state_shape[0] return conv_state_shape, temporal_state_shape @@ -126,7 +121,6 @@ class MambaStateShapeCalculator: head_dim: int, state_size: int, conv_kernel: int, - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: # if n_groups is not divisible by world_size, need to extend the shards # to ensure all groups needed by a head is sharded along with it @@ -137,8 +131,6 @@ class MambaStateShapeCalculator: # contiguous along 'dim' axis conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size)) - if not use_v1: - conv_state_shape = conv_state_shape[1], conv_state_shape[0] # These are not TP-ed as they depend on A, dt_bias, D # - they are typically small @@ -153,12 +145,9 @@ class MambaStateShapeCalculator: tp_world_size: int, intermediate_size: int, conv_kernel: int, - use_v1: bool = True, ) -> tuple[tuple[int, int]]: conv_dim = divide(intermediate_size, tp_world_size) conv_state_shape = (conv_kernel - 1, conv_dim) - if not use_v1: - conv_state_shape = conv_state_shape[1], conv_state_shape[0] return (conv_state_shape, ) @classmethod @@ -183,7 +172,6 @@ class MambaStateShapeCalculator: head_v_dim: int, conv_kernel_size: int, num_spec: int = 0, - use_v1: bool = True, ): conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads) conv_state_shape = ( @@ -191,11 +179,7 @@ class MambaStateShapeCalculator: conv_kernel_size - 1 + num_spec, ) - # In V0, the conv_state shape was swapped during allocation in - # MambaCacheManager, but in V1 it needs to be determined here at the - # calculation level - if use_v1: - conv_state_shape = conv_state_shape[1], conv_state_shape[0] + conv_state_shape = conv_state_shape[1], conv_state_shape[0] temporal_state_shape = (divide(num_v_heads, tp_world_size), head_k_dim, head_v_dim) diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 8cfd0962c5bfe..010fcdda156c2 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -420,9 +420,7 @@ def causal_conv1d_fn( x = x.to(conv_states.dtype) out = torch.empty_like(x) if metadata is not None: - cu_seqlen = metadata.cu_seqlen nums_dict = metadata.nums_dict - #x = metadata.x args = nums_dict batch_ptr = metadata.batch_ptr token_chunk_offset_ptr = metadata.token_chunk_offset_ptr @@ -926,7 +924,6 @@ def causal_conv1d_update( query_start_loc: Optional[torch.Tensor] = None, max_query_len: int = -1, pad_slot_id: int = PAD_SLOT_ID, - metadata=None, validate_data=False, ): """ diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index 335191a5c82c1..ffdcd702aab40 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -8,7 +8,6 @@ if TYPE_CHECKING: import torch -from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size @@ -18,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( @@ -71,15 +69,11 @@ class ShortConv(MambaBase, CustomOp): prefix=f"{prefix}.out_proj", ) - assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1") compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self - # The outer list is for v0 PP virtual engine. Though this code path - # only runs for v1, we have to do this to unify with the interface - # of Attention + v0 PP. - self.kv_cache = [(torch.tensor([]), )] + self.kv_cache = (torch.tensor([]), ) self.model_config = model_config self.cache_config = cache_config @@ -89,7 +83,6 @@ class ShortConv(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - conv_metadata: ShortConvAttentionMetadata, ): return @@ -97,7 +90,6 @@ class ShortConv(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - conv_metadata: ShortConvAttentionMetadata, ): torch.ops.vllm.short_conv( hidden_states, @@ -109,7 +101,6 @@ class ShortConv(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - conv_metadata: ShortConvAttentionMetadata, ): forward_context = get_forward_context() # ShortConvAttentionMetadata contains metadata necessary for the @@ -121,7 +112,6 @@ class ShortConv(MambaBase, CustomOp): if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] - conv_metadata = attn_metadata assert isinstance(attn_metadata, ShortConvAttentionMetadata) self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) @@ -181,9 +171,6 @@ class ShortConv(MambaBase, CustomOp): if has_prefill: Bx_p = (B_p * x_p).transpose(0, 1) - if conv_metadata.cu_seqlen is None: - conv_metadata = update_metadata(Bx_p, query_start_loc_p, - conv_metadata) Bx = causal_conv1d_fn(Bx_p, conv_weights, self.conv.bias, @@ -191,7 +178,7 @@ class ShortConv(MambaBase, CustomOp): conv_states=conv_state, has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, - metadata=conv_metadata, + metadata=attn_metadata, query_start_loc=query_start_loc_p).transpose( 0, 1)[:num_prefill_tokens] @@ -248,9 +235,7 @@ def short_conv( ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.forward_cuda(hidden_states=hidden_states, - output=output, - conv_metadata=None) + self.forward_cuda(hidden_states=hidden_states, output=output) def short_conv_fake( diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 584981ef3ebfd..4a6154dc548aa 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -9,21 +9,17 @@ import torch from torch import nn from transformers import BambaConfig -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) @@ -32,10 +28,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant) @@ -115,8 +108,6 @@ class BambaMixerDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): if residual is None: @@ -127,7 +118,7 @@ class BambaMixerDecoderLayer(nn.Module): hidden_states, residual) output = torch.empty_like(hidden_states) - self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata) + self.mamba(hidden_states, output) # Fully Connected hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states = self.feed_forward(hidden_states) @@ -315,22 +306,10 @@ class BambaModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -343,23 +322,11 @@ class BambaModel(nn.Module): residual = intermediate_tensors["residual"] residual = None - num_attn = 0 for i, layer in enumerate(self.layers): - if isinstance(layer, BambaAttentionDecoderLayer): - num_attn += 1 - - layer_mamba_cache_params = None - if isinstance(layer, - BambaMixerDecoderLayer) and mamba_cache_params: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - i - num_attn) - hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) if not get_pp_group().is_last_rank: @@ -457,13 +424,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -482,7 +447,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, head_dim=hf_config.mamba_d_head, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -515,8 +479,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, if not lora_config else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -534,39 +496,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_mamba_layers = \ - self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, - LayerBlockType.mamba - ) - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/constant_size_cache.py b/vllm/model_executor/models/constant_size_cache.py deleted file mode 100644 index f03c58a12932f..0000000000000 --- a/vllm/model_executor/models/constant_size_cache.py +++ /dev/null @@ -1,137 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod -from typing import Any - -import torch - -from vllm.attention.backends.utils import PAD_SLOT_ID - - -class ConstantSizeCache(ABC): - """ - Abstract base class for managing constant size caches - like Mamba and Minimax. - """ - - def __init__(self, max_batch_size: int): - # Maps between the request id and a dict that maps between the seq_id - # and its index inside the cache - self.cache_indices_mapping: dict[str, dict[int, int]] = {} - self.free_cache_indices = list(range(max_batch_size)) - - @property - @abstractmethod - def cache(self) -> Any: - """Return the underlying cache tensor(s)""" - pass - - @abstractmethod - def _copy_cache(self, from_index: int, to_index: int): - """Copy cache data from one index to another""" - pass - - def current_run_tensors(self, **kwargs) -> tuple: - """ - Return the tensors for the current run's conv and ssm state. - """ - if "seqlen_agnostic_capture_inputs" not in kwargs: - # We get here only on Prefill/Eager mode runs - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - finished_requests_ids = kwargs["finished_requests_ids"] - - self._release_finished_requests(finished_requests_ids) - state_indices = self._prepare_current_run_cache( - request_ids_to_seq_ids, finished_requests_ids) - - state_indices_tensor = torch.as_tensor(state_indices, - dtype=torch.int32, - device="cuda") - cache_tensors = self.cache - else: - # CUDA graph capturing runs - cache_tensors, state_indices_tensor = kwargs[ - "seqlen_agnostic_capture_inputs"] - - return (cache_tensors, state_indices_tensor) - - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - """ - Copy the relevant state_indices into the CUDA graph input buffer - """ - assert all( - key in kwargs - for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) - finished_requests_ids = kwargs["finished_requests_ids"] - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - assert "seqlen_agnostic_capture_inputs" in input_buffers - _, input_state_indices_buffer = input_buffers[ - "seqlen_agnostic_capture_inputs"] - - self._release_finished_requests(finished_requests_ids) - state_indices = self._prepare_current_run_cache( - request_ids_to_seq_ids, finished_requests_ids) - cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len( - state_indices) - state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len) - - input_state_indices_buffer.copy_( - torch.as_tensor(state_indices, dtype=torch.int32, device="cuda")) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - """ - Provide the CUDA graph capture runs with a buffer in adjusted size. - The buffer is used to maintain the Cache during the CUDA graph replay - runs. - """ - state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size, - dtype=torch.int32, - device="cuda") - return (self.cache, state_indices_tensor) - - def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int, - finished_requests_ids) -> int: - """ - Assign (req_id,seq_id) pair to a `destination_index` index, if - already occupied, move the occupying index to a free index. - """ - if cur_rid in finished_requests_ids: - # set as pad, do not allocate destination index - return PAD_SLOT_ID - elif cur_rid not in self.cache_indices_mapping: - destination_index = self.free_cache_indices.pop() - self.cache_indices_mapping[cur_rid] = {seq_id: destination_index} - return destination_index - elif seq_id not in (seq_ids2indices := - self.cache_indices_mapping[cur_rid]): - # parallel sampling , where n > 1, assume prefill have - # already happened, so we copy the - # existing cache into the siblings seq_ids caches - index_exists = next(iter(seq_ids2indices.values())) - # case of decoding n>1, copy prefill cache to decoding indices - destination_index = self.free_cache_indices.pop() - self._copy_cache(from_index=index_exists, - to_index=destination_index) - self.cache_indices_mapping[cur_rid][seq_id] = destination_index - return destination_index - else: - return self.cache_indices_mapping[cur_rid][seq_id] - - def _prepare_current_run_cache( - self, request_ids_to_seq_ids: dict[str, list[int]], - finished_requests_ids: list[str]) -> list[int]: - return [ - self._assign_seq_id_to_cache_index(req_id, seq_id, - finished_requests_ids) - for req_id, seq_ids in request_ids_to_seq_ids.items() - for seq_id in seq_ids - ] - - def _release_finished_requests(self, - finished_seq_groups_req_ids: list[str]): - for req_id in finished_seq_groups_req_ids: - if req_id in self.cache_indices_mapping: - for seq_id in self.cache_indices_mapping[req_id]: - self.free_cache_indices.append( - self.cache_indices_mapping[req_id][seq_id]) - self.cache_indices_mapping.pop(req_id) diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 83efdd2e433fc..f382018e2222c 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -8,21 +8,17 @@ import torch from torch import nn from transformers import FalconH1Config -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) @@ -31,8 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP @@ -179,16 +173,12 @@ class FalconH1SSMDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): output = torch.empty_like(hidden_states) self.mamba( hidden_states, output, - mamba_cache_params, - mamba2_metadata=mamba2_metadata, mup_vector=self.mup_vector, ) return output, residual @@ -364,8 +354,6 @@ class FalconH1ParallelHybrid(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): residual = hidden_states @@ -382,12 +370,10 @@ class FalconH1ParallelHybrid(nn.Module): # Process input through the SSM branch. # FalconH1SSMDecoderLayer expects hidden_states, attn_metadata, - # residual, mamba_cache_params, and sequence_idx. + # residual, and sequence_idx. ssm_hidden, _ = self.mamba( hidden_states=hidden_states * self.ssm_in_multiplier, residual=residual, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, **kwargs, ) # Sum the outputs from both branches. @@ -464,25 +450,10 @@ class FalconH1Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # pass a sequence index tensor, that is required for - # proper continuous batching computation including - # chunked prefill - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds * self.embedding_multiplier @@ -495,14 +466,9 @@ class FalconH1Model(nn.Module): for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - layer_mamba_cache_params = None - if mamba_cache_params: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i) hidden_states = layer( positions=positions, hidden_states=hidden_states, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -541,13 +507,11 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -570,7 +534,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, head_dim=hf_config.mamba_d_head, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -592,7 +555,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, prefix=maybe_prefix(prefix, "model")) self.tie_word_embeddings = config.tie_word_embeddings self.unpadded_vocab_size = config.vocab_size - self.mamba_cache: Optional[MambaCacheManager] = None if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size if get_pp_group().is_last_rank: @@ -637,40 +599,15 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, **kwargs, ): - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager( - self.vllm_config, - self.config.num_hidden_layers, - *mamba_state_shape, - *mamba_state_dtype, - ) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - hidden_states = self.model( input_ids, positions, - mamba_cache_params, intermediate_tensors, inputs_embeds, ) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index e89a1a4a0f7d3..f5751fe47bb8b 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -9,19 +9,15 @@ import torch from torch import nn from transformers import GraniteMoeHybridConfig -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) @@ -30,10 +26,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType from .granitemoe import GraniteMoeMoE from .granitemoeshared import GraniteMoeSharedMLP @@ -102,14 +95,12 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) output = torch.empty_like(hidden_states) - self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata) + self.mamba(hidden_states, output) hidden_states = residual + output * self.residual_multiplier residual = hidden_states @@ -182,8 +173,6 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module): positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -366,22 +355,10 @@ class GraniteMoeHybridModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -399,20 +376,9 @@ class GraniteMoeHybridModel(nn.Module): for i, layer in enumerate(self.layers): if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer): num_attn += 1 - - layer_mamba_cache_params = None - if isinstance( - layer, - GraniteMoeHybridMambaDecoderLayer) and mamba_cache_params: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - i - num_attn) - - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata) + hidden_states, residual = layer(positions=positions, + hidden_states=hidden_states, + residual=residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -552,13 +518,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -577,7 +541,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, head_dim=hf_config.mamba_d_head, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -620,9 +583,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, scale=1 / self.config.logits_scaling) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -636,38 +596,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_mamba_layers = ( - self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, - LayerBlockType.mamba)) - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 12a49029195ff..e8277e259bc5b 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -9,7 +9,6 @@ import torch from torch import nn from transformers import JambaConfig -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig @@ -30,10 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaMLP as JambaMLP -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, @@ -145,7 +141,6 @@ class JambaMambaDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, **kwargs, ): if residual is None: @@ -156,7 +151,7 @@ class JambaMambaDecoderLayer(nn.Module): hidden_states, residual) output = torch.empty_like(hidden_states) - self.mamba(hidden_states, output, mamba_cache_params) + self.mamba(hidden_states, output) # Fully Connected hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states = self.feed_forward(hidden_states) @@ -333,7 +328,6 @@ class JambaModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -348,24 +342,11 @@ class JambaModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - kv_cache_index = 0 - mamba_cache_index = 0 for layer in islice(self.layers, self.start_layer, self.end_layer): - layer_mamba_cache_params = None - if isinstance(layer, JambaAttentionDecoderLayer): - kv_cache_index += 1 - if isinstance(layer, - JambaMambaDecoderLayer) and mamba_cache_params: - current_state_layer = mamba_cache_index - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - current_state_layer) - mamba_cache_index += 1 + hidden_states, residual = layer(positions=positions, + hidden_states=hidden_states, + residual=residual) - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - mamba_cache_params=layer_mamba_cache_params) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -503,8 +484,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, if not lora_config else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -521,24 +500,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - # NOTE: mamba_cache_params is not needed for v1 - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - state_shape = self.get_mamba_state_shape_from_config( - self.vllm_config) - state_dtype = self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_layers, *state_shape, - *state_dtype) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): @@ -574,7 +538,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, intermediate_size=hf_config.mamba_expand * hidden_size, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=envs.VLLM_USE_V1, ) def compute_logits( diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index dd97afbeb668a..53c36e4e52d81 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -8,7 +8,6 @@ import torch import torch.nn as nn from transformers import Lfm2Config -from vllm import envs from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig @@ -297,7 +296,6 @@ class Lfm2ShortConvDecoderLayer(nn.Module): self.conv( hidden_states, output, - conv_metadata=None, ) hidden_states, residual = self.ffn_norm(output, residual) hidden_states = self.feed_forward(hidden_states) @@ -459,13 +457,11 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int]]: """ Calculate shapes for LFM2's convolutional cache. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -478,7 +474,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, tp_world_size=parallel_config.tensor_parallel_size, intermediate_size=hf_config.conv_dim, conv_kernel=hf_config.conv_L_cache, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: @@ -489,8 +484,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, scheduler_config = vllm_config.scheduler_config assert (not cache_config.enable_prefix_caching ), "Lfm2 currently does not support prefix caching" - assert envs.VLLM_USE_V1, ( - "Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1") super().__init__() self.config = config diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 36141a5d50641..5bd268291c7d9 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -8,7 +8,6 @@ import torch from torch import nn from transformers import MambaConfig -from vllm import envs from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group @@ -24,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (HasInnerState, IsAttentionFree, SupportsPP) -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -72,7 +68,6 @@ class MambaDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, **kwargs, ): if residual is None: @@ -82,7 +77,7 @@ class MambaDecoderLayer(nn.Module): hidden_states, residual = self.norm(hidden_states, residual) output = torch.empty_like(hidden_states) - self.mixer(hidden_states, output, mamba_cache_params) + self.mixer(hidden_states, output) return output, residual @@ -134,7 +129,6 @@ class MambaModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -151,17 +145,9 @@ class MambaModel(nn.Module): for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - - layer_cache_params = None - if mamba_cache_params is not None: - layer_cache_params = mamba_cache_params.at_layer_idx( - i - self.start_layer) - - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - mamba_cache_params=layer_cache_params) + hidden_states, residual = layer(positions=positions, + hidden_states=hidden_states, + residual=residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -225,9 +211,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): prefix=maybe_prefix(prefix, "lm_head"), ) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -244,22 +227,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - state_shape = self.get_mamba_state_shape_from_config( - self.vllm_config) - state_dtype = self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_layers, *state_shape, - *state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.backbone(input_ids, positions, mamba_cache_params, + hidden_states = self.backbone(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states @@ -288,8 +256,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): tp_world_size=parallel_config.tensor_parallel_size, intermediate_size=hf_config.intermediate_size, state_size=hf_config.state_size, - conv_kernel=hf_config.conv_kernel, - use_v1=envs.VLLM_USE_V1) + conv_kernel=hf_config.conv_kernel) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): return self.mamba_cache.copy_inputs_before_cuda_graphs( diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 9c3108146d2e5..97e9c5785e726 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -8,16 +8,11 @@ import torch from torch import nn from transformers import MambaConfig -from vllm import envs -from vllm.attention.backends.abstract import AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) @@ -28,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (HasInnerState, IsAttentionFree) -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -74,8 +66,6 @@ class Mamba2DecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): if residual is None: @@ -85,7 +75,7 @@ class Mamba2DecoderLayer(nn.Module): hidden_states, residual = self.norm(hidden_states, residual) output = torch.empty_like(hidden_states) - self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata) + self.mixer(hidden_states, output) return output, residual @@ -137,7 +127,6 @@ class Mamba2Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -152,25 +141,10 @@ class Mamba2Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - for i, layer in enumerate(self.layers): - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - mamba_cache_params=mamba_cache_params.at_layer_idx( - i - self.start_layer) if mamba_cache_params else None, - mamba2_metadata=mamba2_metadata) + hidden_states, residual = layer(positions=positions, + hidden_states=hidden_states, + residual=residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -222,13 +196,11 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -247,7 +219,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): head_dim=hf_config.head_dim, state_size=hf_config.state_size, conv_kernel=hf_config.conv_kernel, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -282,9 +253,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -300,29 +268,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_mamba_layers = ( - self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, - LayerBlockType.mamba)) - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - else: - # NOTE: mamba_cache_params is not needed for v1 - mamba_cache_params = None - - hidden_states = self.backbone(input_ids, positions, mamba_cache_params, + hidden_states = self.backbone(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py deleted file mode 100644 index 6b16e3ce7d984..0000000000000 --- a/vllm/model_executor/models/mamba_cache.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass - -import torch - -from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.config import VllmConfig -from vllm.model_executor.models.constant_size_cache import ConstantSizeCache - - -@dataclass -class MambaCacheParams: - conv_state: torch.Tensor = torch.Tensor() - ssm_state: torch.Tensor = torch.Tensor() - state_indices_tensor: torch.Tensor = torch.Tensor() - - def at_layer_idx(self, layer_idx): - return MambaCacheParams(self.conv_state[layer_idx], - self.ssm_state[layer_idx], - self.state_indices_tensor) - - -class MambaCacheManager(ConstantSizeCache): - - def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int, - conv_state_shape: tuple[int, int], - temporal_state_shape: tuple[int, int], - conv_state_dtype: torch.dtype, - temporal_state_dtype: torch.dtype): - - self.conv_state_dtype = conv_state_dtype - self.temporal_state_dtype = temporal_state_dtype - - # Determine max batch size to set size of MambaCache - max_batch_size = vllm_config.scheduler_config.max_num_seqs - if not vllm_config.model_config.enforce_eager: - max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size) - - # Initialize parent class - super().__init__(max_batch_size) - - # assume conv_state = (dim, state_len) - assert conv_state_shape[0] > conv_state_shape[1] - conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + - (conv_state_shape[1], conv_state_shape[0]), - dtype=self.conv_state_dtype, - device="cuda").transpose(-1, -2) - temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + - temporal_state_shape, - dtype=self.temporal_state_dtype, - device="cuda") - - self._mamba_cache = (conv_state, temporal_state) - - @property - def cache(self): - return self._mamba_cache - - def _copy_cache(self, from_index: int, to_index: int): - for cache_t in self.cache: - cache_t[:, to_index].copy_(cache_t[:, from_index], - non_blocking=True) - - def current_run_tensors(self, **kwargs) -> MambaCacheParams: - """ - Return the tensors for the current run's conv and ssm state. - """ - cache_tensors, state_indices_tensor = super().current_run_tensors( - **kwargs) - return MambaCacheParams(cache_tensors[0], cache_tensors[1], - state_indices_tensor) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - """ - Provide the CUDA graph capture runs with a buffer in adjusted size. - The buffer is used to maintain the Mamba Cache during the CUDA graph - replay runs. - """ - return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size, - dtype=torch.int32, - device="cuda") diff --git a/vllm/model_executor/models/minimax_cache.py b/vllm/model_executor/models/minimax_cache.py deleted file mode 100644 index 9164ac06a3b0a..0000000000000 --- a/vllm/model_executor/models/minimax_cache.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass - -import torch - -from vllm.model_executor.models.constant_size_cache import ConstantSizeCache - - -@dataclass -class MinimaxCacheParams: - minimax_cache: torch.Tensor = torch.Tensor() - state_indices_tensor: torch.Tensor = torch.Tensor() - - def at_layer_idx(self, layer_idx): - return MinimaxCacheParams(self.minimax_cache[layer_idx, ...], - self.state_indices_tensor) - - -class MinimaxCacheManager(ConstantSizeCache): - - def __init__(self, dtype, cache_shape): - super().__init__(cache_shape[1]) # max_batch_size is cache_shape[1] - self._minimax_cache = torch.empty(size=cache_shape, - dtype=dtype, - device="cuda") - - @property - def cache(self): - return self._minimax_cache - - def _copy_cache(self, from_index: int, to_index: int): - assert len(self.cache) > 0 - for cache_t in self.cache: - cache_t[:, to_index].copy_(cache_t[:, from_index], - non_blocking=True) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 1d2c7dea811e0..cc9a959f63313 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -14,7 +14,6 @@ import torch.distributed from torch import nn from transformers import MiniMaxConfig -from vllm import envs from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig @@ -44,7 +43,6 @@ from vllm.model_executor.models.utils import maybe_prefix from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid -from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers @@ -404,7 +402,6 @@ class MiniMaxText01DecoderLayer(nn.Module): def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - kv_caches: Union[list[dict], Optional[torch.Tensor]], attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], is_warmup: bool = False, @@ -418,7 +415,6 @@ class MiniMaxText01DecoderLayer(nn.Module): hidden_states=layernorm_output, output=self_attention_output, positions=positions, - kv_caches=kv_caches, ) residual = residual * self.layernorm_attention_alpha @@ -563,10 +559,6 @@ class MiniMaxText01Model(nn.Module): self._dtype = _dummy.dtype del _dummy - if not envs.VLLM_USE_V1: - self.minimax_cache = MinimaxCacheManager( - dtype=torch.float32, cache_shape=self.cache_shape) - norm_kwargs = {} if hasattr(config, "rms_norm_eps"): norm_kwargs["eps"] = config.rms_norm_eps @@ -614,25 +606,6 @@ class MiniMaxText01Model(nn.Module): **kwargs) -> Union[torch.Tensor, IntermediateTensors]: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata - if not envs.VLLM_USE_V1 and attn_metadata is None: - return None - if not envs.VLLM_USE_V1: - if "request_ids_to_seq_ids" not in kwargs: - kwargs["request_ids_to_seq_ids"] = {} - if "finished_requests_ids" not in kwargs: - kwargs["finished_requests_ids"] = [] - ( - minimax_cache_tensors, - state_indices_tensor, - ) = self.minimax_cache.current_run_tensors(**kwargs) - if getattr(attn_metadata, "num_prefills", 0) > 0: - self._clear_prefill_cache(attn_metadata, minimax_cache_tensors, - **kwargs) - - minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors, - state_indices_tensor) - else: - minimax_cache_params = None if get_pp_group().is_first_rank: if inputs_embeds is None: @@ -645,20 +618,10 @@ class MiniMaxText01Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - minimax_cache_index = 0 - for layer in islice(self.layers, self.start_layer, self.end_layer): - _caches = None - if not envs.VLLM_USE_V1 and isinstance( - layer.self_attn, MiniMaxText01LinearAttention): - current_state_layer = minimax_cache_index - _caches = minimax_cache_params.at_layer_idx( - current_state_layer) - minimax_cache_index += 1 hidden_states, residual = layer( hidden_states=hidden_states, positions=positions, - kv_caches=_caches, attn_metadata=attn_metadata, residual=residual, ) @@ -1003,13 +966,11 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, ...], ...]: """Calculate shape for MiniMaxText01LinearAttention cache. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index ff571541a60a5..987920ecc3310 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -23,21 +23,17 @@ from typing import Optional import torch from torch import nn -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) @@ -49,14 +45,11 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant) -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.model_executor.models.utils import ( AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronHConfig -from vllm.utils import LayerBlockType class NemotronHMLP(nn.Module): @@ -181,8 +174,6 @@ class NemotronHMambaDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): if residual is None: @@ -192,7 +183,7 @@ class NemotronHMambaDecoderLayer(nn.Module): hidden_states, residual = self.norm(hidden_states, residual) output = torch.empty_like(hidden_states) - self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata) + self.mixer(hidden_states, output) return output, residual @@ -370,22 +361,10 @@ class NemotronHModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -398,22 +377,11 @@ class NemotronHModel(nn.Module): residual = intermediate_tensors["residual"] residual = None - num_non_mamba_layers = 0 for i, layer in enumerate(self.layers): - layer_mamba_cache_params = None - if isinstance(layer, - NemotronHMambaDecoderLayer) and mamba_cache_params: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - i - num_non_mamba_layers) - else: - num_non_mamba_layers += 1 - hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) if not get_pp_group().is_last_rank: @@ -508,13 +476,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -533,7 +499,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, head_dim=hf_config.mamba_head_dim, state_size=hf_config.ssm_state_size, conv_kernel=hf_config.conv_kernel, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -566,8 +531,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, if not lora_config else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -584,40 +547,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - - num_mamba_layers = \ - self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, - LayerBlockType.mamba - ) - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py deleted file mode 100644 index ae153558e37aa..0000000000000 --- a/vllm/model_executor/models/phi4flash.py +++ /dev/null @@ -1,731 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math -from collections.abc import Iterable -from typing import Optional, Union - -import torch -import torch.nn as nn -from transformers.activations import ACT2FN - -import vllm.envs as envs -from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.attention.selector import _Backend -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.forward_context import ForwardContext, get_forward_context -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) -from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_scan_fn, selective_state_update) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, - SupportsV0Only) -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) -from vllm.sequence import IntermediateTensors - -from .utils import make_layers, maybe_prefix - -logger = init_logger(__name__) - - -class SwiGLUActivation(nn.Module): - - def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - return x1 * nn.functional.silu(x2) - - -class SambaYMLP(nn.Module): - """Gated Linear Unit. - - Reference: - Language Modeling with Gated Convolutional Networks. - https://arxiv.org/pdf/1612.08083v3.pdf. - - """ - - def __init__(self, config): - super().__init__() - - self.config = config - self.fc1 = nn.Linear(config.hidden_size, - 2 * config.intermediate_size, - bias=False) - self.fc2 = nn.Linear(config.intermediate_size, - config.hidden_size, - bias=False) - - self.activation_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states): - y = self.fc1(hidden_states) - gate, y = y.chunk(2, dim=-1) - y = y * self.activation_fn(gate) - return self.fc2(y) - - -def get_virtual_engine(): - forward_context: ForwardContext = get_forward_context() - return forward_context.virtual_engine - - -class SambaYAttention(nn.Module): - - def __init__(self, - config, - layer_idx: Optional[int] = None, - yoco_cross: bool = False, - cache_config: Optional[CacheConfig] = None, - prefix: str = ""): - super().__init__() - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing " - "a `layer_idx` is not recommended and will lead to errors " - "during the forward call if caching is used. Please make " - "sure to provide a `layer_idx` when creating this class.") - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.yoco_cross = yoco_cross - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError("hidden_size must be divisible by num_heads " - f"(got `hidden_size`: {self.hidden_size} and " - f"`num_heads`: {self.num_heads}).") - - op_size = self.num_heads * self.head_dim + 2 * ( - self.num_key_value_heads * self.head_dim) - self.out_proj = nn.Linear(self.num_heads * self.head_dim, - self.hidden_size, - bias=True) - if yoco_cross: - self.Wqkv = nn.Linear(self.hidden_size, - self.num_heads * self.head_dim, - bias=True) - else: - self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True) - - # disable sliding window for the second half of the model - is_sliding = config.layer_types[layer_idx] == "sliding_attention" - sliding_window = config.sliding_window if is_sliding else None - - assert self.num_heads % 2 == 0, 'num_heads should be even' - assert self.num_key_value_heads % 2 == 0, 'num_heads should be even' - - self.lambda_init = self.lambda_init_fn(layer_idx) - self.lambda_q1 = nn.Parameter( - torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, - std=0.1)) - self.lambda_k1 = nn.Parameter( - torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, - std=0.1)) - self.lambda_q2 = nn.Parameter( - torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, - std=0.1)) - self.lambda_k2 = nn.Parameter( - torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, - std=0.1)) - self.subln = nn.RMSNorm(2 * self.head_dim, - eps=1e-5, - elementwise_affine=True) - - params = { - 'differential_flash_attention_config': { - 'lambda_init': self.lambda_init, - 'lambda_q1': self.lambda_q1, - 'lambda_k1': self.lambda_k1, - 'lambda_q2': self.lambda_q2, - 'lambda_k2': self.lambda_k2, - "subln": self.subln, - } - } - - if yoco_cross: - kv_shared_layer_index = config.num_hidden_layers // 2 + 1 - kv_sharing_target_layer_name = \ - f"model.layers.{kv_shared_layer_index}.self_attn.attn" - else: - kv_sharing_target_layer_name = None - - self.attn = Attention( - self.num_heads, - self.head_dim, - self.head_dim**-0.5, - num_kv_heads=self.num_key_value_heads, - cache_config=cache_config, - per_layer_sliding_window=sliding_window, - prefix=f"{prefix}.attn", - attn_type=AttentionType.DECODER, - kv_sharing_target_layer_name=kv_sharing_target_layer_name, - **params) - assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN,\ - "DIFFERENTIAL_FLASH_ATTN required" - - def lambda_init_fn(self, depth): - return 0.8 - 0.6 * math.exp(-0.3 * depth) - - def forward( - self, - hidden_states: torch.Tensor, - ): - - if not self.yoco_cross: # need to generate kv-cache - qkv = self.Wqkv(hidden_states) - q, k, v = qkv.split([ - self.hidden_size, self.num_key_value_heads * self.head_dim, - self.num_key_value_heads * self.head_dim - ], - dim=-1) - attn_output = self.attn(q, k, v) - else: # reuse the kv cache, full attention - q = self.Wqkv(hidden_states) - attn_output = self.attn(q, None, None) - attn_output = attn_output.view(-1, self.num_heads * self.head_dim) - return self.out_proj(attn_output) - - -class Phi4Mamba(nn.Module): - - def __init__( - self, - d_model, - d_state=16, - d_conv=4, - expand=2, - dt_rank="auto", - dt_min=0.001, - dt_max=0.1, - dt_init="random", # difference - dt_scale=1.0, # difference - dt_init_floor=1e-4, - conv_bias=True, - bias=False, - use_fast_path=True, # Fused kernel options - layer_idx=None, - device=None, - dtype=None, - yoco_cross=False, - yoco_kv=False, - ): - factory_kwargs = {"params_dtype": dtype} # difference - super().__init__() - self.yoco_cross = yoco_cross - self.yoco_kv = yoco_kv - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = int(self.expand * self.d_model) - self.dt_rank = math.ceil(self.d_model / - 16) if dt_rank == "auto" else dt_rank - self.use_fast_path = use_fast_path - self.layer_idx = layer_idx - self.swiGluActivation = SwiGLUActivation() - if self.yoco_cross: - self.in_proj = MergedColumnParallelLinear(self.d_model, - [self.d_inner], - bias=bias, - **factory_kwargs) - self.out_proj = RowParallelLinear(self.d_inner, - self.d_model, - bias=bias, - **factory_kwargs) - return - self.conv1d = ColumnParallelLinear( - input_size=d_conv, - output_size=self.d_inner, - bias=conv_bias, - params_dtype=dtype, - ) - # unsqueeze to fit conv1d weights shape into the linear weights shape. - # Can't do this in `weight_loader` since it already exists in - # `ColumnParallelLinear` and `set_weight_attrs` - # doesn't allow to override it - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - self.in_proj = MergedColumnParallelLinear( - self.d_model, - [self.d_inner] * 2, - bias=bias, - params_dtype=dtype, - ) - - # selective projection used to make dt, B and C input dependent - self.x_proj = RowParallelLinear( - self.d_inner, - self.dt_rank + self.d_state * 2, - bias=False, - params_dtype=dtype, - ) - - # time step projection (discretization) - - # In the forward we need to apply dt_proj without the bias, - # as the bias is added in the selective scan kernel. - self.dt_proj = ColumnParallelLinear( - self.dt_rank, - self.d_inner, - bias=True, - skip_bias_add=True, - params_dtype=dtype, - ) - - # # D "skip" parameter - # self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32 - self.A = nn.Parameter( - torch.empty( - self.d_inner, - self.d_state, - dtype=torch.float32, - )) - self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32)) - - self.out_proj = RowParallelLinear( - self.d_inner, - self.d_model, - bias=bias, - input_is_parallel=True, - params_dtype=dtype, - ) - self.activation = "silu" - - def forward(self, - hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - mamba_cache_params: MambaCacheParams, - yoco_key_values=None) -> torch.Tensor: - - if self.yoco_cross: - out = self.in_proj(hidden_states)[0] - out = self.swiGluActivation(yoco_key_values, out) - out = self.out_proj(out) - return out[0], yoco_key_values - - # 1. Gated MLP's linear projection - # projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) - projected_states = self.in_proj( - hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1) - hidden_states, gate = projected_states.chunk(2, dim=-2) - - # 2. Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - hidden_states = causal_conv1d_fn( - hidden_states, - conv_weights, - self.conv1d.bias, - activation=self.activation, - conv_states=mamba_cache_params.conv_state, - has_initial_state=attn_metadata.context_lens_tensor > 0, - cache_indices=mamba_cache_params.state_indices_tensor, - query_start_loc=attn_metadata.query_start_loc) - else: - hidden_states = causal_conv1d_update( - hidden_states.transpose(0, 1), - mamba_cache_params.conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=mamba_cache_params.state_indices_tensor) - hidden_states = hidden_states.transpose(0, 1) - - # 3. State Space Model sequence transformation - # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] - - time_step, B, C = torch.split( - ssm_parameters, - [self.dt_rank, self.d_state, self.d_state], - dim=-1, - ) - - # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't. - - discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = (self.dt_proj.bias.float() if hasattr( - self.dt_proj, "bias") else None) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - scan_outputs = selective_scan_fn( - hidden_states, - mamba_cache_params.ssm_state, - discrete_time_step, - self.A, - B.transpose(-2, -1), - C.transpose(-2, -1), - self.D.float(), - # z, - None if self.yoco_kv else gate, - time_proj_bias, - delta_softplus=True, - cache_indices=mamba_cache_params.state_indices_tensor, - has_initial_state=attn_metadata.context_lens_tensor > 0, - query_start_loc=attn_metadata.query_start_loc) - else: - scan_outputs = torch.empty_like(hidden_states.transpose(0, 1)) - selective_state_update( - mamba_cache_params.ssm_state, - hidden_states.transpose(0, 1), - discrete_time_step.transpose(0, 1), - self.A, - B, - C, - self.D, - # z - # gate.transpose(0, 1), - None if self.yoco_kv else gate.transpose(0, 1), - time_proj_bias, - dt_softplus=True, - state_batch_indices=mamba_cache_params.state_indices_tensor, - out=scan_outputs) - scan_outputs = scan_outputs.transpose(0, 1) - - # 4. Final linear projection - if self.yoco_kv: - # gate = gate.transpose(-1,-2).contiguous() - yoco_key_values = scan_outputs.transpose(-2, -1) - scan_outputs = self.swiGluActivation(scan_outputs, gate) - - contextualized_states = self.out_proj(scan_outputs.transpose(-2, - -1))[0] - - return contextualized_states, yoco_key_values - - -class SambaYDecoderLayer(nn.Module): - - def __init__( - self, - config, - layer_idx, - cache_config, - prefix: str = "", - ) -> None: - super().__init__() - - self.config = config - self.layer_idx = layer_idx - - self.mlp = SambaYMLP(config) - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - - self.yoco_mb = False - self.yoco_cross = False - if layer_idx >= config.num_hidden_layers // 2: - self.yoco_mb = True - self.yoco_cross = (layer_idx - >= (config.num_hidden_layers // 2 + 2)) - self.use_mamba = config.mb_per_layer > 0 and \ - layer_idx % config.mb_per_layer == 0 - if self.use_mamba: - factory_kwargs = {"dtype": None} - self.attn = Phi4Mamba(config.hidden_size, - layer_idx=layer_idx, - yoco_cross=self.yoco_cross, - yoco_kv=self.yoco_mb, - **factory_kwargs) - else: - self.attn = SambaYAttention(config, - layer_idx=layer_idx, - yoco_cross=self.yoco_cross, - cache_config=cache_config, - prefix=f"{prefix}.self_attn") - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - positions: torch.Tensor, - attn_metadata: AttentionMetadata, - mamba_cache_params: MambaCacheParams, - ssm_output: Optional[torch.LongTensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if self.use_mamba: - assert mamba_cache_params is not None - else: - assert mamba_cache_params is None - - residual = hidden_states - hidden_states = self.input_layernorm( - hidden_states.to(dtype=self.input_layernorm.weight.dtype)) - - if self.use_mamba: - attn_outputs, ssm_output = self.attn(hidden_states, - attn_metadata, - mamba_cache_params, - yoco_key_values=ssm_output) - residual = residual.to(torch.float32) - else: - attn_outputs = self.attn(hidden_states, ) - hidden_states = residual + attn_outputs - residual = hidden_states - hidden_states = self.post_attention_layernorm( - hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype)) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states, ssm_output - - -class SambaYModel(nn.Module): - - def __init__(self, - config, - cache_config=None, - quant_config=None, - lora_config=None, - prefix: str = "") -> None: - super().__init__() - self.config = config - self.vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - ) - - # Pipeline parallel is not supported since the second half of - # the layers share the kv cache. - if get_pp_group().world_size != 1: - raise ValueError("Pipeline Parallel not supported") - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: SambaYDecoderLayer(config, - int(prefix.split('.')[-1]), - cache_config, - prefix=prefix), - prefix=f"{prefix}.layers") - self.final_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - attn_metadata: AttentionMetadata, - mamba_cache_params: MambaCacheParams, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - - mamba_state_idx = 0 - ssm_output = None - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - if i == self.config.num_hidden_layers // 2 + 2: - # profile run - kv_cache_idx = self.config.num_hidden_layers // 2 + 1 - cache_layer = self.layers[kv_cache_idx] - kv_cache = cache_layer.attn.attn.kv_cache - if kv_cache[0].numel() == 0: - break - - # Starting from this layer, we do not need to calculate - # the kv cache since we reuse the kv cache from last layer. - # If in prefill phase, we can prune> truncate - # the hidden state to save computation cost. - if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1: - selected_token_indices = torch.cumsum( - attn_metadata.seq_lens_tensor, dim=0) - 1 - hidden_states = hidden_states.index_select( - 0, selected_token_indices) - ssm_output = ssm_output.index_select( - 0, selected_token_indices) - - if layer.use_mamba: - if i < self.config.num_hidden_layers // 2 or \ - not layer.yoco_cross: - mamba_cache = mamba_cache_params.at_layer_idx( - mamba_state_idx) - mamba_state_idx += 1 - else: - mamba_cache = mamba_cache_params.at_layer_idx( - mamba_state_idx - 1) - - hidden_states, ssm_output = layer(hidden_states, - positions, - attn_metadata, - mamba_cache, - ssm_output=ssm_output) - else: - hidden_states, ssm_output = layer( - hidden_states, - positions, - attn_metadata, - None, # mamba_cache_params - ssm_output=ssm_output) - - hidden_states = self.final_layernorm( - hidden_states.to(dtype=self.final_layernorm.weight.dtype)) - return hidden_states - - -class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - lora_config = vllm_config.lora_config - quant_config = vllm_config.quant_config - scheduler_config = vllm_config.scheduler_config - self.compilation_config = vllm_config.compilation_config - self.vllm_config = vllm_config - # Prefix caching and chunked prefill is not supported for this model. - assert not cache_config.enable_prefix_caching, \ - "Phi4flash currently does not support prefix caching" - assert not scheduler_config.chunked_prefill_enabled, \ - "Phi4Flash currently does not support prefix caching" - super().__init__() - self.config = config - self.model_config = vllm_config.model_config - self.scheduler_config = scheduler_config - self.model = SambaYModel(config, - cache_config=cache_config, - prefix=maybe_prefix(prefix, "model")) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - self.embedding_bias = None - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logits_as_input=False) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - if self.mamba_cache is None: - num_mamba_layers = self.config.num_hidden_layers \ - // 2 // self.config.mb_per_layer + 1 - self.mamba_cache = MambaCacheManager( - self.vllm_config, - num_mamba_layers, - *self._get_mamba_cache_shape(), - self.lm_head.weight.dtype, - self.lm_head.weight.dtype, - ) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - attn_metadata = get_forward_context().attn_metadata - # input_ids and hidden_states isn't a one-to-one mapping in prefill - # stage due to YOCO optimization. - hidden_states = self.model(input_ids, positions, attn_metadata, - mamba_cache_params, intermediate_tensors, - inputs_embeds) - return hidden_states - - def _get_mamba_cache_shape( - self - ) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - mamba_expand = self.config.mamba_expand # 2 - mamba_d_conv = self.config.mamba_d_conv # 4 - mamba_d_state = self.config.mamba_d_state # 16 - conv_state_shape = ( - mamba_expand * hidden_size // world_size, - mamba_d_conv - 1, - ) - temporal_state_shape = ( - mamba_expand * hidden_size // world_size, - mamba_d_state, - ) - return conv_state_shape, temporal_state_shape - - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: - processed_logits = self.logits_processor( - self.lm_head, - hidden_states, - self.embedding_bias, - ) - return processed_logits - - def load_weights( - self, - weights: Iterable[tuple[str, torch.Tensor]], - ): - weights = {name: weight for name, weight in weights} - adjusted_weights = {} - for name, weight in weights.items(): - if "A_log" in name: - name = name.replace("A_log", "A") - weight = -torch.exp(weight.float()) - if "inner_cross_attn." in name: - name = name.replace("inner_cross_attn.", "") - adjusted_weights[name] = weight - adjusted_weights["lm_head.weight"] = weights[ - "model.embed_tokens.weight"] - loaded_params: set[str] = set() - for name, param in self.named_parameters(): - weight = adjusted_weights.get(name) - if weight is not None and weight.shape != param.shape: - logger.warning("Shape mismatch: %s %s %s", name, weight.shape, - param.shape) - loaded_params.add(name) - missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights, - strict=False) - assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}" - assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" - return loaded_params diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 33ee1cf44afd1..0292f3bf8317d 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -12,7 +12,6 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile @@ -29,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata, update_metadata) from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( @@ -47,15 +44,13 @@ from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, SupportsPP) -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.model_executor.models.utils import ( is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType, direct_register_custom_op +from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata @@ -194,17 +189,13 @@ class Plamo2MambaMixer(MambaBase, CustomOp): self.chunk_size = self.config.mamba_chunk_size - if envs.VLLM_USE_V1: - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - # The outer list is for v0 PP virtual engine. Though this code path - # only runs for v1, we have to do this to unify with the interface - # of Attention + v0 PP. - # The inner tuple is (conv_state, ssm_state) - self.kv_cache = [(torch.tensor([]), torch.tensor([]))] - assert self.chunk_size != -1, "chunk_size must be set for v1" + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The tuple is (conv_state, ssm_state) + self.kv_cache = (torch.tensor([]), torch.tensor([])) + assert self.chunk_size != -1, "chunk_size must be set for v1" self.prefix = prefix @@ -227,8 +218,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): pass @@ -237,59 +226,43 @@ class Plamo2MambaMixer(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): - if not envs.VLLM_USE_V1: - CustomOp.forward(self, hidden_states, output, mamba_cache_params, - mamba2_metadata) - else: - torch.ops.vllm.plamo2_mamba_mixer( - hidden_states, - output, - self.prefix, - ) + torch.ops.vllm.plamo2_mamba_mixer( + hidden_states, + output, + self.prefix, + ) def forward_cuda( self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): forward_context = get_forward_context() - # mamba2_metadata contains metadata necessary for the mamba2 triton + # attn_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration attn_metadata: AttentionMetadata = forward_context.attn_metadata - if envs.VLLM_USE_V1: - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - mamba2_metadata = attn_metadata - assert isinstance(attn_metadata, Mamba2AttentionMetadata) - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - # conv_state = (..., dim, width-1) yet contiguous along 'dim' - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] - state_indices_tensor = attn_metadata.state_indices_tensor - else: - conv_state = mamba_cache_params.conv_state - ssm_state = mamba_cache_params.ssm_state - state_indices_tensor = mamba_cache_params.state_indices_tensor - # Common members between V1 metadata and V0 metadata - if mamba2_metadata is not None: - has_initial_states_p = mamba2_metadata.has_initial_states_p - prep_initial_states = mamba2_metadata.prep_initial_states - chunk_size = mamba2_metadata.chunk_size - seq_idx_p = mamba2_metadata.seq_idx_p - chunk_indices_p = mamba2_metadata.chunk_indices_p - chunk_offsets_p = mamba2_metadata.chunk_offsets_p + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, Mamba2AttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + # conv_state = (..., dim, width-1) yet contiguous along 'dim' + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states_p + prep_initial_states = attn_metadata.prep_initial_states + chunk_size = attn_metadata.chunk_size + seq_idx_p = attn_metadata.seq_idx_p + chunk_indices_p = attn_metadata.chunk_indices_p + chunk_offsets_p = attn_metadata.chunk_offsets_p # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states) @@ -299,8 +272,8 @@ class Plamo2MambaMixer(MambaBase, CustomOp): conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if envs.VLLM_USE_V1 and attn_metadata is None: - # V1 profile run + if attn_metadata is None: + # profile run hidden_states = (hidden_states.transpose(0, 1).clone().transpose( 0, 1)).contiguous() output[:] = self.out_proj(hidden_states) @@ -316,42 +289,23 @@ class Plamo2MambaMixer(MambaBase, CustomOp): # NOTE: V0 put prefill before decode, v1 puts decode before prefill # Separate prefill and decode by splitting varlen input # Split along token dimension - if envs.VLLM_USE_V1: - hidden_states_d, hidden_states_p = torch.split( - hidden_states[:num_actual_tokens], - [num_decodes, num_prefill_tokens], - dim=0, - ) - gate_d, gate_p = torch.split(gate[:num_actual_tokens], - [num_decodes, num_prefill_tokens], - dim=0) - # Split along batch dimension - state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor, - [num_decodes, num_prefills], - dim=0, - ) - query_start_loc_p = ( - attn_metadata.query_start_loc[-num_prefills - 1:] - - num_decodes if has_prefill else None) - else: - hidden_states_p, hidden_states_d = torch.split( - hidden_states, - [num_prefill_tokens, num_decodes], - dim=0, - ) - gate_p, gate_d = torch.split(gate, - [num_prefill_tokens, num_decodes], - dim=0) - # Split along batch dimension - state_indices_tensor_p, state_indices_tensor_d = torch.split( - state_indices_tensor, - [num_prefills, num_decodes], - dim=0, - ) - query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + - 1] - if has_prefill else None) + hidden_states_d, hidden_states_p = torch.split( + hidden_states[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + gate_d, gate_p = torch.split(gate[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs @@ -363,18 +317,11 @@ class Plamo2MambaMixer(MambaBase, CustomOp): dtype=hidden_states.dtype, device=hidden_states.device, ) - if envs.VLLM_USE_V1: - preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( - preallocated_ssm_out, - [num_decodes, num_prefill_tokens], - dim=0, - ) - else: - preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( - preallocated_ssm_out, - [num_prefill_tokens, num_decodes], - dim=0, - ) + preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( + preallocated_ssm_out, + [num_decodes, num_prefill_tokens], + dim=0, + ) # Process prefill requests if has_prefill: @@ -383,9 +330,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp): # pointed to by "state_indices_tensor" x = hidden_states_p.transpose( 0, 1) # this is the form that causal-conv see - if mamba2_metadata.cu_seqlen is None: - mamba2_metadata = update_metadata(x, query_start_loc_p, - mamba2_metadata) hidden_states_p = causal_conv1d_fn( x, conv_weights, @@ -394,7 +338,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp): conv_states=conv_state, has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, - metadata=mamba2_metadata, + metadata=attn_metadata, query_start_loc=query_start_loc_p) hidden_states_p = hidden_states_p.transpose(0, 1) hidden_states_p = hidden_states_p[:num_prefill_tokens] @@ -470,7 +414,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp): -1, self.num_heads // self.tp_size, self.head_dim) # - the hidden is reshaped into (bs, num_heads, head_dim) - # - mamba_cache_params.ssm_state's slots will be selected + # - ssm_state's slots will be selected # using state_indices_tensor_d # NOTE: final output is an in-place update of out tensor @@ -530,10 +474,7 @@ def plamo2_mamba_mixer( ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.forward_cuda(hidden_states=hidden_states, - output=output, - mamba_cache_params=None, - mamba2_metadata=None) + self.forward_cuda(hidden_states=hidden_states, output=output) def plamo2_mamba_mixer_fake( @@ -731,8 +672,6 @@ class Plamo2DecoderLayer(nn.Module): positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): if residual is None: @@ -747,8 +686,6 @@ class Plamo2DecoderLayer(nn.Module): output = torch.empty_like(hidden_states) mixer_kwargs = { "output": output, - "mamba_cache_params": mamba_cache_params, - "mamba2_metadata": mamba2_metadata, } else: mixer_kwargs = { @@ -790,23 +727,12 @@ class Plamo2Decoder(torch.nn.Module): positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, ) -> torch.Tensor: - mamba_cache_index = 0 for layer in islice(self.layers, self.start_layer, self.end_layer): - layer_mamba_cache_params = None - if layer.is_mamba and mamba_cache_params is not None: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - mamba_cache_index) - mamba_cache_index += 1 - hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) return hidden_states, residual @@ -844,7 +770,6 @@ class Plamo2Model(torch.nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -859,23 +784,10 @@ class Plamo2Model(torch.nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - if not envs.VLLM_USE_V1: - attn_metadata: AttentionMetadata = get_forward_context( - ).attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - hidden_states, residual = self.layers( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -925,9 +837,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): if self.config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.config.vocab_size) self.make_empty_intermediate_tensors = ( @@ -942,39 +851,11 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_mamba_layers = ( - self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, - LayerBlockType.mamba)) - mamba_state_shape = self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - else: - # NOTE: mamba_cache_params is not needed for v1 - mamba_cache_params = None - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - @classmethod def get_mamba_state_dtype_from_config( cls, @@ -991,12 +872,10 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: - conv_state_shape: Shape for convolutional state cache @@ -1015,7 +894,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): head_dim=hf_config.hidden_size_per_head, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def compute_logits( diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 24cebc5bfdd82..ab23b494e561e 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -11,7 +11,6 @@ from einops import rearrange from torch import nn from transformers.activations import ACT2FN -from vllm import envs from vllm.attention import Attention, AttentionBackend, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, @@ -35,7 +34,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata from vllm.model_executor.layers.mamba.mamba_mixer2 import ( mamba_v2_sharded_weight_loader) from vllm.model_executor.layers.mamba.mamba_utils import ( @@ -51,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -198,14 +195,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.gated_delta_net_state_shape( - self.tp_size, - self.num_k_heads, - self.num_v_heads, - self.head_k_dim, - self.head_v_dim, - self.conv_kernel_size, - self.num_spec, - use_v1=True) + self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim, + self.head_v_dim, self.conv_kernel_size, self.num_spec) def __init__( self, @@ -394,7 +385,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): self, hidden_states: torch.Tensor, output: torch.Tensor, - cache_params: Optional[MambaCacheParams] = None, ): return torch.ops.vllm.gdn_attention( hidden_states, @@ -416,7 +406,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] - conv_metadata = attn_metadata assert isinstance(attn_metadata, GDNAttentionMetadata) has_initial_state = attn_metadata.has_initial_state spec_query_start_loc = attn_metadata.spec_query_start_loc @@ -479,12 +468,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): # 2.2: process the remaining part if attn_metadata.num_prefills > 0: mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) - if conv_metadata.cu_seqlen is None: - conv_metadata = update_metadata(mixed_qkv_non_spec_T, - non_spec_query_start_loc, - conv_metadata) # - "cache_indices" updates the conv_state cache in positions - # pointed to by "mamba_cache_params.state_indices_tensor" + # pointed to by "state_indices_tensor" mixed_qkv_non_spec = causal_conv1d_fn( mixed_qkv_non_spec_T, conv_weights, @@ -494,7 +479,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): has_initial_state=has_initial_state, cache_indices=non_spec_state_indices_tensor, query_start_loc=non_spec_query_start_loc, - metadata=conv_metadata, + metadata=attn_metadata, ).transpose(0, 1) elif attn_metadata.num_decodes > 0: mixed_qkv_non_spec = causal_conv1d_update( @@ -1075,7 +1060,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, scheduler_config = vllm_config.scheduler_config assert not cache_config.enable_prefix_caching, \ "Qwen3Next currently does not support prefix caching" - assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1" self.quant_config = vllm_config.quant_config super().__init__() @@ -1195,14 +1179,10 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, num_spec = (vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0) return MambaStateShapeCalculator.gated_delta_net_state_shape( - tp_size, - hf_config.linear_num_key_heads, - hf_config.linear_num_value_heads, - hf_config.linear_key_head_dim, - hf_config.linear_value_head_dim, - hf_config.linear_conv_kernel_dim, - num_spec, - use_v1=True) + tp_size, hf_config.linear_num_key_heads, + hf_config.linear_num_value_heads, hf_config.linear_key_head_dim, + hf_config.linear_value_head_dim, hf_config.linear_conv_kernel_dim, + num_spec) def compute_logits( self, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 86123bc092b9b..6ab3fa902c387 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -134,7 +134,6 @@ _TEXT_GENERATION_MODELS = { "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), - "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"), "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 4350e38e02f96..a0d93045b74cf 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -15,12 +15,10 @@ import torch from torch import nn from transformers import Zamba2Config -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -29,8 +27,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) @@ -39,8 +35,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid @@ -515,8 +509,6 @@ class Zamba2MambaDecoderLayer(nn.Module): def forward( self, hidden_states: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, transformer_hidden_states: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None, original_hidden_states: Optional[torch.Tensor] = None, @@ -525,8 +517,6 @@ class Zamba2MambaDecoderLayer(nn.Module): Args: hidden_states: Input tensor [batch_size, seq_len, hidden_size] - mamba_cache_params: Parameters for Mamba's state caches - (one for conv, one for ssm) transformer_hidden_states: Optional output from transformer path Added to input if provided (used in hybrid architecture) positions: Optional position IDs (unused in Mamba) @@ -555,8 +545,6 @@ class Zamba2MambaDecoderLayer(nn.Module): self.mamba( hidden_states, output, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) # residual connection after mamba @@ -607,8 +595,6 @@ class Zamba2HybridLayer(nn.Module): hidden_states: torch.Tensor, original_hidden_states: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, ) -> torch.Tensor: """Forward pass through the hybrid layer. @@ -623,8 +609,6 @@ class Zamba2HybridLayer(nn.Module): original_hidden_states: Original input for transformer residual connection positions: Position IDs for positional embeddings - mamba_cache_params: Parameters for Mamba's state caches - (one for conv, one for ssm) Returns: Output tensor combining transformer and Mamba representations @@ -644,8 +628,6 @@ class Zamba2HybridLayer(nn.Module): layer_outputs = self.mamba_decoder( hidden_states, transformer_hidden_states=transformer_hidden_states, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) return layer_outputs @@ -752,7 +734,6 @@ class Zamba2Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: """Forward pass through the model. @@ -760,8 +741,6 @@ class Zamba2Model(nn.Module): Args: input_ids: Input token IDs positions: Position IDs for embeddings - mamba_cache_params: Parameters for Mamba's state caches - (one for conv, one for ssm) inputs_embeds: Optional pre-computed input embeddings Returns: @@ -773,33 +752,13 @@ class Zamba2Model(nn.Module): inputs_embeds = self.get_input_embeddings(input_ids) hidden_states = inputs_embeds - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - # Process through layers original_hidden_states = torch.clone(hidden_states) for layer_idx, layer in enumerate(self.layers): - - layer_mamba_cache_params = None - if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer)) - and mamba_cache_params): - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - layer_idx) - layer_outputs = layer( hidden_states, original_hidden_states=original_hidden_states, positions=positions, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) hidden_states = layer_outputs @@ -870,13 +829,11 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -896,7 +853,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): head_dim=hf_config.mamba_headdim, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: @@ -945,9 +901,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): # Tie weights with input embeddings if using same dimensions self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - # Initialize logits processing and sampling self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -977,61 +930,15 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): Returns: Output hidden states """ - # Initialize Mamba cache if needed - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_mamba_layers = self.config.num_hidden_layers - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - # Get cache parameters for current run - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - # Forward pass through model hidden_states = self.model( input_ids, positions, - mamba_cache_params, inputs_embeds, ) return hidden_states - def copy_inputs_before_cuda_graphs( - self, input_buffers: dict[str, torch.Tensor], - **kwargs: Any) -> dict[str, torch.Tensor]: - """Copy inputs before CUDA graph capture. - - Args: - input_buffers: Dictionary of input tensors - **kwargs: Additional arguments passed to cache manager - - Returns: - Updated input buffers - """ - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs( - self, batch_size: int) -> dict[str, torch.Tensor]: - """Get inputs for sequence-length-agnostic graph capture. - - Args: - batch_size: Size of batch to capture - Returns: - Dictionary of capture inputs - """ - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 5dadc52d0fb1c..06a87a4a3c8b2 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -12,6 +12,7 @@ from vllm.config import VllmConfig from vllm.v1.attention.backends.utils import (AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + compute_causal_conv1d_metadata, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec @@ -52,7 +53,6 @@ class GDNAttentionMetadata: # The following attributes are for triton implementation of causal_conv1d nums_dict: Optional[dict] = None - cu_seqlen: Optional[int] = None batch_ptr: Optional[torch.Tensor] = None token_chunk_offset_ptr: Optional[torch.Tensor] = None @@ -134,6 +134,7 @@ class GDNAttentionMetadataBuilder( context_lens = m.num_computed_tokens_cpu context_lens_tensor = context_lens.to(query_start_loc.device) seq_lens_tensor = m.seq_lens + nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None if (not self.use_spec_decode or num_draft_tokens is None or num_draft_tokens.sum().item() == 0): @@ -210,6 +211,8 @@ class GDNAttentionMetadataBuilder( has_initial_state = context_lens_tensor > 0 if spec_sequence_masks is not None: has_initial_state = has_initial_state[~spec_sequence_masks] + nums_dict, batch_ptr, token_chunk_offset_ptr = \ + compute_causal_conv1d_metadata(non_spec_query_start_loc) else: has_initial_state = None num_actual_tokens = num_prefill_tokens + num_decode_tokens + \ @@ -297,6 +300,9 @@ class GDNAttentionMetadataBuilder( spec_sequence_masks=spec_sequence_masks, spec_token_masks=spec_token_masks, num_accepted_tokens=num_accepted_tokens, + nums_dict=nums_dict, + batch_ptr=batch_ptr, + token_chunk_offset_ptr=token_chunk_offset_ptr, ) return attn_metadata diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 2fe1f14ca1db0..f45fc75334a21 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -7,11 +7,12 @@ from typing import Optional import torch from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig from vllm.v1.attention.backends.mamba_attn import ( BaseMambaAttentionMetadataBuilder) -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, +from vllm.v1.attention.backends.utils import (PAD_SLOT_ID, + CommonAttentionMetadata, + compute_causal_conv1d_metadata, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec @@ -131,7 +132,6 @@ class Mamba2AttentionMetadata: # The following attributes are for triton implementation of causal_conv1d nums_dict: Optional[dict] = None - cu_seqlen: Optional[int] = None batch_ptr: Optional[torch.Tensor] = None token_chunk_offset_ptr: Optional[torch.Tensor] = None @@ -161,6 +161,9 @@ class Mamba2AttentionMetadataBuilder( has_initial_states_p = None prep_initial_states = False + # for causal_conv1d + nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( @@ -198,6 +201,9 @@ class Mamba2AttentionMetadataBuilder( query_start_loc_p, self.chunk_size, num_prefill_tokens)) + nums_dict, batch_ptr, token_chunk_offset_ptr = \ + compute_causal_conv1d_metadata(query_start_loc_p) + elif num_decodes <= self.decode_cudagraph_max_bs: # Pad state tensor for CUDA graph num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) @@ -220,5 +226,8 @@ class Mamba2AttentionMetadataBuilder( chunk_indices_p=chunk_indices_p, chunk_offsets_p=chunk_offsets_p, state_indices_tensor=state_indices_tensor, + nums_dict=nums_dict, + batch_ptr=batch_ptr, + token_chunk_offset_ptr=token_chunk_offset_ptr, ) return attn_metadata diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index 717c40b37ecfb..428e409659798 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata, + compute_causal_conv1d_metadata, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec @@ -33,7 +34,6 @@ class ShortConvAttentionMetadata: # For causal_conv1d nums_dict: Optional[dict] = None - cu_seqlen: Optional[int] = None batch_ptr: Optional[torch.Tensor] = None token_chunk_offset_ptr: Optional[torch.Tensor] = None @@ -57,6 +57,9 @@ class ShortConvAttentionMetadataBuilder( state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + # for causal_conv1d + nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( common_attn_metadata, @@ -70,6 +73,12 @@ class ShortConvAttentionMetadataBuilder( has_initial_states = has_initial_states_cpu.to( query_start_loc.device) + query_start_loc_p = common_attn_metadata.query_start_loc[ + -num_prefills - 1:] - num_decode_tokens + + nums_dict, batch_ptr, token_chunk_offset_ptr = \ + compute_causal_conv1d_metadata(query_start_loc_p) + attn_metadata = ShortConvAttentionMetadata( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, @@ -78,5 +87,8 @@ class ShortConvAttentionMetadataBuilder( query_start_loc=query_start_loc, has_initial_states=has_initial_states, state_indices_tensor=state_indices_tensor, + nums_dict=nums_dict, + batch_ptr=batch_ptr, + token_chunk_offset_ptr=token_chunk_offset_ptr, ) return attn_metadata diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 63326d19194f0..6ef489f5a7a28 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -34,6 +34,8 @@ logger = init_logger(__name__) KVCacheLayoutType = Literal["NHD", "HND"] _KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None +PAD_SLOT_ID = -1 + def is_valid_kv_cache_layout(value: str) -> bool: return value in get_args(KVCacheLayoutType) @@ -838,3 +840,52 @@ def create_fast_prefill_custom_backend( builder_cls=FastPrefillAttentionBuilder) return attn_backend + + +def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): + + # Needed for causal_conv1d + seqlens = query_start_loc_p.diff().to('cpu') + nums_dict = {} # type: ignore + batch_ptr = None + token_chunk_offset_ptr = None + for BLOCK_M in [8]: # cover all BLOCK_M values + nums = -(-seqlens // BLOCK_M) + nums_dict[BLOCK_M] = {} + nums_dict[BLOCK_M]['nums'] = nums + nums_dict[BLOCK_M]['tot'] = nums.sum().item() + mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums)) + nums_dict[BLOCK_M]['mlist'] = mlist + mlist_len = len(nums_dict[BLOCK_M]['mlist']) + nums_dict[BLOCK_M]['mlist_len'] = mlist_len + MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2 + offsetlist = [] # type: ignore + for idx, num in enumerate(nums): + offsetlist.extend(range(num)) + offsetlist = torch.tensor(offsetlist, dtype=torch.int32) + nums_dict[BLOCK_M]['offsetlist'] = offsetlist + + if batch_ptr is None: + # Update default value after class definition + batch_ptr = torch.full((MAX_NUM_PROGRAMS, ), + PAD_SLOT_ID, + dtype=torch.int32, + device='cuda') + token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ), + PAD_SLOT_ID, + dtype=torch.int32, + device='cuda') + else: + if batch_ptr.nelement() < MAX_NUM_PROGRAMS: + batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) + token_chunk_offset_ptr.resize_( # type: ignore + MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) + + batch_ptr[0:mlist_len].copy_(mlist) + token_chunk_offset_ptr[ # type: ignore + 0:mlist_len].copy_(offsetlist) + nums_dict[BLOCK_M]['batch_ptr'] = batch_ptr + nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (token_chunk_offset_ptr + ) # type: ignore + + return nums_dict, batch_ptr, token_chunk_offset_ptr