[V1] Remove V0 code paths for Hybrid models (#25400)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Authored by Thomas Parnell on 2025-09-23 17:26:13 +02:00; committed by yewentao256
parent 02134245a9
commit f97da2c732
31 changed files with 352 additions and 2296 deletions


@@ -20,7 +20,9 @@ pytestmark = pytest.mark.hybrid_model
 SSM_MODELS = [
     "state-spaces/mamba-130m-hf",
     "tiiuae/falcon-mamba-tiny-dev",
-    "yujiepan/mamba2-codestral-v0.1-tiny-random",
+    # mamba2-codestral in transformers is broken pending:
+    # https://github.com/huggingface/transformers/pull/40861
+    #"yujiepan/mamba2-codestral-v0.1-tiny-random",
 ]

 HYBRID_MODELS = [
@@ -31,18 +33,7 @@ HYBRID_MODELS = [
     "ibm-granite/granite-4.0-tiny-preview",
     "tiiuae/Falcon-H1-0.5B-Base",
     "LiquidAI/LFM2-1.2B",
-]
-
-V1_SUPPORTED_MODELS = [
-    "state-spaces/mamba-130m-hf",
-    "ai21labs/Jamba-tiny-dev",
-    "pfnet/plamo-2-1b",
-    "yujiepan/mamba2-codestral-v0.1-tiny-random",
-    "Zyphra/Zamba2-1.2B-instruct",
-    "hmellor/tiny-random-BambaForCausalLM",
-    "ibm-granite/granite-4.0-tiny-preview",
-    "tiiuae/Falcon-H1-0.5B-Base",
-    "LiquidAI/LFM2-1.2B",
+    "tiny-random/qwen3-next-moe",
 ]

 FULL_CUDA_GRAPH_MODELS = [
@@ -51,10 +42,6 @@ FULL_CUDA_GRAPH_MODELS = [
     "Zyphra/Zamba2-1.2B-instruct",
 ]

-V0_UNSUPPORTED_MODELS = [
-    "LiquidAI/LFM2-1.2B",
-]
-
 FP32_STATE_MODELS = [
     "state-spaces/mamba-130m-hf",
     "Zyphra/Zamba2-1.2B-instruct",
@@ -88,19 +75,15 @@ def test_models(
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)

-    if model in V1_SUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v1_outputs = None
-
-    if model in V1_SUPPORTED_MODELS:
-        check_logprobs_close(
-            outputs_0_lst=hf_outputs,
-            outputs_1_lst=vllm_v1_outputs,
-            name_0="hf",
-            name_1="vllm-v1",
-        )
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
@@ -299,14 +282,14 @@ def test_full_cuda_graph(
         example_prompts, max_tokens, num_logprobs)

     with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)

     check_logprobs_close(
         outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_v1_outputs,
+        outputs_1_lst=vllm_outputs,
         name_0="hf",
-        name_1="vllm-v1",
+        name_1="vllm",
     )
@@ -340,12 +323,12 @@ def test_fp32_cache_state(
     with vllm_runner(model,
                      max_num_seqs=MAX_NUM_SEQS,
                      **{cache_dtype_param: "float32"}) as vllm_model:
-        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

     check_logprobs_close(
         outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_v1_outputs,
+        outputs_1_lst=vllm_outputs,
         name_0="hf",
-        name_1="vllm-v1",
+        name_1="vllm",
     )
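Note: with the V1_SUPPORTED_MODELS gate gone, every entry in SSM_MODELS and HYBRID_MODELS now exercises the same comparison path. A condensed sketch of that flow, assuming the hf_runner/vllm_runner fixtures and the check_logprobs_close helper this test file already uses (the max_tokens/num_logprobs values here are illustrative only):

import pytest

@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
def test_models_sketch(hf_runner, vllm_runner, example_prompts, model):
    max_tokens, num_logprobs = 32, 5
    # reference: HuggingFace greedy generation with logprobs
    with hf_runner(model) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)
    # single vLLM path, no per-model engine-version check anymore
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)
    check_logprobs_close(outputs_0_lst=hf_outputs,
                         outputs_1_lst=vllm_outputs,
                         name_0="hf",
                         name_1="vllm")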


@@ -312,13 +312,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
     "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
     "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
-    "Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning",  # noqa: E501
-                                            trust_remote_code=True,
-                                            v0_only=True,
-                                            max_model_len=10240),
     "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
                                          trust_remote_code=True),
     "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
+                                         max_transformers_version="4.55.4",
+                                         transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers",  # noqa: E501
                                          trust_remote_code=True),
     "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
                                        max_transformers_version="4.53",
@@ -330,7 +328,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
     "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
     "Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
-                                            min_transformers_version="4.56.2"),
+                                            extras={"tiny-random": "tiny-random/qwen3-next-moe"},  # noqa: E501
+                                            min_transformers_version="4.56.3"),
     "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
     "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct",  # noqa: E501
                                           trust_remote_code=True,
@@ -644,7 +643,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
                                     trust_remote_code=True,
                                     speculative_model="XiaomiMiMo/MiMo-7B-RL"),
     "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
-                                    min_transformers_version="4.56.2"),
+                                    min_transformers_version="4.56.3"),
 }

 _TRANSFORMERS_BACKEND_MODELS = {


@@ -20,10 +20,7 @@ class MambaBase(AttentionLayerBase):

     # Contains the KV cache (mamba state) for the layer
     # in the shape specified by `self.get_state_shape`.
-    # The outer list is for v0 PP virtual engine. Though this code path
-    # only runs for v1, we have to do this to unify with the interface
-    # of Attention + v0 PP.
-    kv_cache: list[Iterable[torch.Tensor]]
+    kv_cache: tuple[torch.Tensor, ...]

     @abstractmethod
     def get_state_shape(self) -> Iterable[tuple[int, ...]]:
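Note: a minimal sketch of a layer conforming to the slimmed-down MambaBase contract (a toy class, not vLLM source): kv_cache is now a flat tuple of state tensors, initialized to empty placeholders and rebound by the V1 cache machinery before the first real forward pass.

from collections.abc import Iterable

import torch


class ToyConvStateLayer:
    # empty placeholders; real buffers are attached by the cache engine later
    kv_cache: tuple[torch.Tensor, ...] = (torch.tensor([]), torch.tensor([]))

    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        # e.g. (conv_state_shape, ssm_state_shape) for a Mamba-style mixer;
        # the concrete numbers here are illustrative only
        return ((3, 128), (128, 16))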


@@ -15,7 +15,6 @@ import torch.nn.functional as F
 from einops import rearrange
 from torch import nn

-from vllm import envs
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
@@ -42,8 +41,6 @@ if TYPE_CHECKING:
     import torch
     import torch.distributed

-from vllm.model_executor.models.minimax_cache import MinimaxCacheParams
-

 class MiniMaxText01RMSNormTP(CustomOp):
     name = "MiniMaxText01RMSNormTP"
@@ -225,7 +222,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
                               self.tp_heads:(self.tp_rank + 1) *
                               self.tp_heads].contiguous()

-        if envs.VLLM_USE_V1:
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
@@ -268,8 +264,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
                 break
             if _prefill_idx >= len(state_indices_tensor):
                 break
-            # prefills are packed at end of batch in V1
-            offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
+            offset = attn_metadata.num_decode_tokens
             _start = attn_metadata.query_start_loc[offset + _prefill_idx]
             _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
             slot_id = state_indices_tensor[offset + _prefill_idx]
@@ -291,10 +286,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
             hidden_decode = self._decode_infer(q, k, v, kv_cache,
                                                state_indices_tensor,
                                                attn_metadata)
-            if envs.VLLM_USE_V1:
-                hidden.insert(0, hidden_decode)
-            else:
-                hidden.append(hidden_decode)
+            hidden.insert(0, hidden_decode)

         if not hidden:
             return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
@@ -304,13 +296,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):

     def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                       attn_metadata):
-        if not envs.VLLM_USE_V1:
-            q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            num_prefills = getattr(attn_metadata, "num_prefills", 0)
-            slot_id = state_indices_tensor[num_prefills:]
-        else:
         q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
         k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
         v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
@@ -320,11 +305,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         return hidden

     def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                positions: torch.Tensor,
-                kv_caches: MinimaxCacheParams) -> None:
-        if not envs.VLLM_USE_V1:
-            self._forward(hidden_states, output, positions, kv_caches)
-        else:
+                positions: torch.Tensor) -> None:
         torch.ops.vllm.linear_attention(
             hidden_states,
             output,
@@ -333,11 +314,10 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         )

     def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                 positions: torch.Tensor,
-                 kv_caches: Optional[MinimaxCacheParams]) -> None:
+                 positions: torch.Tensor) -> None:
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
-        if envs.VLLM_USE_V1 and attn_metadata is not None:
+        if attn_metadata is not None:
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
             assert isinstance(attn_metadata, LinearAttentionMetadata)
@@ -351,32 +331,26 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         qkvact = torch.nn.functional.silu(qkv32)
         qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
-        if envs.VLLM_USE_V1:
-            if attn_metadata is not None:
-                kv_cache = self.kv_cache[forward_context.virtual_engine][0]
-                state_indices_tensor = attn_metadata.state_indices_tensor
-                num_prefills = getattr(attn_metadata, "num_prefills", 0)
-                if num_prefills > 0:
-                    num_decode_tokens = getattr(attn_metadata,
-                                                "num_decode_tokens", 0)
-                    for prefill_idx in range(num_prefills):
-                        q_start = attn_metadata.query_start_loc[
-                            num_decode_tokens + prefill_idx]
-                        q_end = attn_metadata.query_start_loc[num_decode_tokens
-                                                              + prefill_idx +
-                                                              1]
-                        query_len = q_end - q_start
-                        context_len = attn_metadata.seq_lens[
-                            num_decode_tokens + prefill_idx] - query_len
-                        if context_len == 0:
-                            block_to_clear = state_indices_tensor[
-                                num_decode_tokens + prefill_idx]
-                            kv_cache[block_to_clear, ...] = 0
-        else:
-            assert kv_caches is not None
-            kv_cache = kv_caches.minimax_cache
-            state_indices_tensor = kv_caches.state_indices_tensor
+        if attn_metadata is not None:
+            kv_cache = self.kv_cache[forward_context.virtual_engine][0]
+            state_indices_tensor = attn_metadata.state_indices_tensor
+            num_prefills = getattr(attn_metadata, "num_prefills", 0)
+            if num_prefills > 0:
+                num_decode_tokens = getattr(attn_metadata, "num_decode_tokens",
+                                            0)
+                for prefill_idx in range(num_prefills):
+                    q_start = attn_metadata.query_start_loc[num_decode_tokens +
+                                                            prefill_idx]
+                    q_end = attn_metadata.query_start_loc[num_decode_tokens +
                                                          prefill_idx + 1]
+                    query_len = q_end - q_start
+                    context_len = attn_metadata.seq_lens[
+                        num_decode_tokens + prefill_idx] - query_len
+                    if context_len == 0:
+                        block_to_clear = state_indices_tensor[num_decode_tokens
+                                                              + prefill_idx]
+                        kv_cache[block_to_clear, ...] = 0

         decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
         if attn_metadata is None:
@@ -410,8 +384,7 @@ def linear_attention(
     self = forward_context.no_compile_layers[layer_name]
     self._forward(hidden_states=hidden_states,
                   output=output,
-                  positions=positions,
-                  kv_caches=None)
+                  positions=positions)


 def linear_attention_fake(
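Note: all converted layers now dispatch through a registered custom op and then look themselves up again via the forward context, as linear_attention() above does. A self-contained toy sketch of that indirection (stand-in classes, not vLLM's API):

import torch


class _ForwardContext:
    # stands in for vllm.forward_context.get_forward_context()
    no_compile_layers: dict[str, "ToyLinearAttention"] = {}


class ToyLinearAttention(torch.nn.Module):
    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix
        _ForwardContext.no_compile_layers[prefix] = self

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor):
        # real vLLM code calls torch.ops.vllm.linear_attention(...) here
        toy_linear_attention(hidden_states, output, self.prefix)

    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor):
        output.copy_(hidden_states)  # placeholder for the linear-attention math


def toy_linear_attention(hidden_states: torch.Tensor, output: torch.Tensor,
                         layer_name: str) -> None:
    # mirrors linear_attention(): resolve the layer by name via the context
    layer = _ForwardContext.no_compile_layers[layer_name]
    layer._forward(hidden_states, output)


layer = ToyLinearAttention("model.layers.0.linear_attn")
x = torch.randn(4, 8)
out = torch.empty_like(x)
layer(x, out)  # the call round-trips through the op and back into _forward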


@@ -1,177 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.platforms import current_platform
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import (
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
@dataclass
class Mamba2Metadata:
prep_initial_states: bool
chunk_size: int
has_initial_states_p: torch.Tensor
seq_idx_p: torch.Tensor
chunk_indices_p: torch.Tensor
chunk_offsets_p: torch.Tensor
"""
With continuous batching layout of `x` in vLLM, to enable a Triton program
to handle a request in parallel, two supporting tensors are used
(batch_ptr, token_chunk_offset_ptr)
BLOCK_M = the # tokens to be handled by a Triton program
(can be customized for different hardware)
nums_dict:
tracks the data associated with a given value of BLOCK_M
BLOCK_M = #tokens handled by a Triton program
cu_seqlen: total tokens per batch
(used as flag to update other data at each new input)
batch_ptr: tracks batch-id handled by the Triton program
token_chunk_offset_ptr: tracks token group_idx handled by the Triton program
(Triton implementation of causal_conv1d handles parallelism in 3-axes
- feature-axis
- batch-axis
- sequence-axis)
"""
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
"""Returns the appropriate metadata classes for the current platform."""
if current_platform.is_rocm():
from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata)
from vllm.v1.attention.backends.triton_attn import (
TritonAttentionMetadata)
return (AiterFlashAttentionMetadata, TritonAttentionMetadata,
PlaceholderAttentionMetadata)
if current_platform.is_cuda():
from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata)
from vllm.v1.attention.backends.xformers import (
XFormersAttentionMetadata)
return (FlashAttentionMetadata, XFormersAttentionMetadata,
PlaceholderAttentionMetadata)
raise ValueError(
f"Unsupported platform for Mamba2: {current_platform.device_type}")
def prepare_mamba2_metadata(
chunk_size: int,
attn_metadata: AttentionMetadata,
) -> Mamba2Metadata:
# compute number of prefill and decode requests
# NOTE: in V0 we assume prefills are before decodes
num_prefills = attn_metadata.num_prefills
num_prefill_tokens = attn_metadata.num_prefill_tokens
seq_idx_p = None
chunk_indices_p, chunk_offsets_p = None, None
# Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend
has_initial_states_p = None
prep_initial_states = False
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if num_prefills > 0:
attn_metadata_instances = get_platform_metadata_classes()
if (isinstance(attn_metadata, attn_metadata_instances)
and attn_metadata.context_lens_tensor is not None):
# precompute flag to avoid device syncs later in mamba2 layer
# forwards
# prep is only needed for mamba2 ssd prefill processing
has_initial_states_p = (
attn_metadata.context_lens_tensor[:num_prefills] > 0)
prep_initial_states = torch.any(has_initial_states_p).item()
query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1]
seq_idx_p = torch.repeat_interleave(torch.arange(
num_prefills, dtype=torch.int32, device=query_start_loc_p.device),
query_start_loc_p.diff(),
output_size=num_prefill_tokens)
seq_idx_p.unsqueeze_(0)
# We compute metadata for chunked prefill once at the top level model
# forward and reuse them in mamba layers. If not needed, they will be
# ignored inside mamba kernels.
if prep_initial_states:
chunk_indices_p, chunk_offsets_p = \
_query_start_loc_to_chunk_indices_offsets(
query_start_loc_p, chunk_size, num_prefill_tokens)
return Mamba2Metadata(has_initial_states_p=has_initial_states_p,
prep_initial_states=prep_initial_states,
chunk_size=chunk_size,
seq_idx_p=seq_idx_p,
chunk_indices_p=chunk_indices_p,
chunk_offsets_p=chunk_offsets_p)
def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
mamba2_metadata: Union[Mamba2Metadata,
Mamba2AttentionMetadata,
GDNAttentionMetadata]):
"""
this is triggered upon handling a new input at the first layer
"""
dim, cu_seqlen = x.shape
mamba2_metadata.cu_seqlen = cu_seqlen
seqlens = np.diff(query_start_loc.to('cpu'))
nums_dict = {} # type: ignore
for BLOCK_M in [8]: # cover all BLOCK_M values
nums = -(-seqlens // BLOCK_M)
nums_dict[BLOCK_M] = {}
nums_dict[BLOCK_M]['nums'] = nums
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
nums_dict[BLOCK_M]['mlist'] = mlist
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
offsetlist = [] # type: ignore
for idx, num in enumerate(nums):
offsetlist.extend(range(num))
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
if mamba2_metadata.batch_ptr is None:
# Update default value after class definition
#mamba2_metadata.MAX_NUM_PROGRAMS *= 2
mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
mamba2_metadata.token_chunk_offset_ptr = torch.full(
(MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
else:
if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS:
mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(
PAD_SLOT_ID)
mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore
MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist)
mamba2_metadata.token_chunk_offset_ptr[ # type: ignore
0:mlist_len].copy_(offsetlist)
nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (
mamba2_metadata.token_chunk_offset_ptr) # type: ignore
mamba2_metadata.nums_dict = nums_dict
return mamba2_metadata
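Note: this whole V0 helper module is deleted. Based on the hunks below, the bookkeeping its docstring describes (nums_dict, batch_ptr, token_chunk_offset_ptr) is now expected to come from the V1 per-layer attention metadata, which the layers pass straight through as metadata=attn_metadata. A minimal sketch of that contract (toy names; the real objects are the V1 backend metadata classes such as Mamba2AttentionMetadata):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyConvKernelMetadata:
    # fields the causal_conv1d prefill path reads off `metadata` (see the
    # causal_conv1d hunk further down)
    nums_dict: Optional[dict] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None


def run_prefill_conv(x: torch.Tensor,
                     metadata: ToyConvKernelMetadata) -> torch.Tensor:
    # real code: causal_conv1d_fn(x, ..., metadata=attn_metadata, ...)
    if metadata.nums_dict is not None:
        pass  # kernel grid bookkeeping would be taken from metadata here
    return x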


@@ -10,8 +10,6 @@ import torch
 from torch import nn
 from torch.nn.parameter import Parameter

-from vllm import envs
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -28,7 +26,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
-from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
@@ -149,16 +146,12 @@ class MambaMixer(MambaBase, CustomOp):
             has_weight=rms_norm_has_weight,
         ) if use_rms_norm else None

-        if envs.VLLM_USE_V1:
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
-        # The outer list is for v0 PP virtual engine. Though this code path
-        # only runs for v1, we have to do this to unify with the interface
-        # of Attention + v0 PP.
         # The inner tuple is (conv_state, ssm_state)
-        self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
+        self.kv_cache = (torch.tensor([]), torch.tensor([]))

         self.model_config = model_config
         self.cache_config = cache_config
@@ -186,29 +179,18 @@ class MambaMixer(MambaBase, CustomOp):
         discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
         return discrete_time_step, B, C

-    def forward(self,
-                hidden_states: torch.Tensor,
-                output: torch.Tensor,
-                mamba_cache_params: Optional[MambaCacheParams] = None):
-        if not envs.VLLM_USE_V1:
-            CustomOp.forward(self, hidden_states, output, mamba_cache_params)
-        else:
+    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor):
         torch.ops.vllm.mamba_mixer(
             hidden_states,
             output,
             self.prefix,
         )

-    def forward_native(self,
-                       hidden_states: torch.Tensor,
-                       output: torch.Tensor,
-                       mamba_cache_params: Optional[MambaCacheParams] = None):
+    def forward_native(self, hidden_states: torch.Tensor,
+                       output: torch.Tensor):
         pass

-    def forward_cuda(self,
-                     hidden_states: torch.Tensor,
-                     output: torch.Tensor,
-                     mamba_cache_params: Optional[MambaCacheParams] = None):
+    def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
         """
         Run the Mamba-1 SSM pipeline.
@@ -234,7 +216,6 @@ class MambaMixer(MambaBase, CustomOp):
         forward_context: ForwardContext = get_forward_context()
         attn_metadata = forward_context.attn_metadata

-        if envs.VLLM_USE_V1:
         if attn_metadata is not None:
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
@@ -247,18 +228,6 @@ class MambaMixer(MambaBase, CustomOp):
             ssm_state = self_kv_cache[1]
             has_initial_states = mamba1_metadata.has_initial_states
             num_padded_decodes = mamba1_metadata.num_padded_decodes
-        else:
-            assert isinstance(attn_metadata, AttentionMetadata)
-            assert mamba_cache_params is not None
-            conv_state = mamba_cache_params.conv_state
-            ssm_state = mamba_cache_params.ssm_state
-            state_indices_tensor = mamba_cache_params.state_indices_tensor
-            query_start_loc = attn_metadata.query_start_loc
-            context_lens_tensor = attn_metadata.context_lens_tensor
-            has_initial_states = None
-            if context_lens_tensor is not None:
-                has_initial_states = context_lens_tensor > 0
-            num_padded_decodes = attn_metadata.num_decode_tokens

         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -267,7 +236,7 @@ class MambaMixer(MambaBase, CustomOp):
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))

-        if envs.VLLM_USE_V1 and attn_metadata is None:
+        if attn_metadata is None:
             # V1 profile run
             hidden_states_BC = hidden_states_BC.contiguous()
             return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
@@ -368,10 +337,7 @@ class MambaMixer(MambaBase, CustomOp):
                 out=scan_outputs_d)
             scan_outputs_d = scan_outputs_d.transpose(0, 1)

-            if envs.VLLM_USE_V1:
-                ssm_outputs.insert(0, scan_outputs_d)
-            else:
-                ssm_outputs.append(scan_outputs_d)
+            ssm_outputs.insert(0, scan_outputs_d)

         scan_outputs_combined = ssm_outputs[0] if len(
             ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
@@ -441,9 +407,9 @@ def split_batch_to_prefill_and_decode(
     num_decodes: int,
     num_padded_decodes: int,
 ) -> PrefillDecodeSplit:
     num_actual_tokens = num_prefill_tokens + num_padded_decodes
-    if envs.VLLM_USE_V1:
     # In v1, decode tokens come first, then prefill tokens.
     hidden_states_BC_d, hidden_states_BC_p = torch.split(
         hidden_states_BC[..., :num_actual_tokens],
@@ -462,19 +428,6 @@ def split_batch_to_prefill_and_decode(
         num_padded_decodes if num_prefills > 0 else None)
     has_initial_states_p = has_initial_states[-num_prefills:] if (
         has_initial_states is not None and num_prefills > 0) else None
-    else:
-        # In v0, prefill tokens come first, then decode tokens.
-        hidden_states_BC_p, hidden_states_BC_d = torch.split(
-            hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
-        gate_p, gate_d = torch.split(gate,
-                                     [num_prefill_tokens, num_decode_tokens],
-                                     dim=-1)
-        state_indices_tensor_p, state_indices_tensor_d = torch.split(
-            state_indices_tensor, [num_prefills, num_decodes], dim=0)
-        query_start_loc_p = (query_start_loc[:num_prefills +
-                                             1] if num_prefills > 0 else None)
-        has_initial_states_p = has_initial_states[:num_prefills] if (
-            has_initial_states is not None and num_prefills > 0) else None

     return PrefillDecodeSplit(
         hidden_states_BC_p=hidden_states_BC_p,
@@ -495,9 +448,7 @@ def mamba_mixer(
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self.forward_cuda(hidden_states=hidden_states,
-                      output=output,
-                      mamba_cache_params=None)
+    self.forward_cuda(hidden_states=hidden_states, output=output)


 def mamba_mixer_fake(
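Note: with the V0 branch removed, the only remaining batch layout is decode tokens first, then prefill tokens, which is what the torch.split calls above encode. A tiny worked example of that split (the numbers are illustrative, not from the PR):

import torch

num_decodes, num_prefill_tokens = 2, 5        # 2 decode requests, one 5-token prefill
hidden = torch.arange(7).unsqueeze(0).float()  # fake (dim=1, tokens=7) activations

# dim=-1 split mirrors split_batch_to_prefill_and_decode for Mamba-1
decode_part, prefill_part = torch.split(
    hidden, [num_decodes, num_prefill_tokens], dim=-1)

assert decode_part.shape[-1] == 2 and prefill_part.shape[-1] == 5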


@@ -9,7 +9,6 @@ if TYPE_CHECKING:
 import torch
 from torch import nn

-from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -22,8 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.mamba.abstract import MambaBase
-from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
-                                                              update_metadata)
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -36,7 +33,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import (
     LoaderFunction, composed_weight_loader, sharded_weight_loader)
-from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
@@ -449,16 +445,12 @@ class MambaMixer2(MambaBase, CustomOp):
                                        self.use_rms_norm,
                                        eps=rms_norm_eps)

-        if envs.VLLM_USE_V1:
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
-        # The outer list is for v0 PP virtual engine. Though this code path
-        # only runs for v1, we have to do this to unify with the interface
-        # of Attention + v0 PP.
-        # The inner tuple is (conv_state, ssm_state)
-        self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
+        # The tuple is (conv_state, ssm_state)
+        self.kv_cache = (torch.tensor([]), torch.tensor([]))

         self.model_config = model_config
         self.cache_config = cache_config
@@ -468,8 +460,6 @@ class MambaMixer2(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
     ):
         pass
@@ -478,14 +468,8 @@ class MambaMixer2(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
     ):
-        if not envs.VLLM_USE_V1:
-            CustomOp.forward(self, hidden_states, output, mamba_cache_params,
-                             mamba2_metadata, mup_vector)
-        else:
         torch.ops.vllm.mamba_mixer2(
             hidden_states,
             output,
@@ -497,40 +481,30 @@ class MambaMixer2(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
     ):
         forward_context = get_forward_context()
-        # mamba2_metadata contains metadata necessary for the mamba2 triton
+        # attn_metadata contains metadata necessary for the mamba2 triton
         # kernels to operate in continuous batching and in chunked prefill
         # modes; they are computed at top-level model forward since they
         # stay the same and reused for all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
-        if envs.VLLM_USE_V1:
-            if attn_metadata is not None:
-                assert isinstance(attn_metadata, dict)
-                attn_metadata = attn_metadata[self.prefix]
-                mamba2_metadata = attn_metadata
-                assert isinstance(attn_metadata, Mamba2AttentionMetadata)
-                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
-                # conv_state = (..., dim, width-1) yet contiguous along 'dim'
-                conv_state = self_kv_cache[0].transpose(-1, -2)
-                ssm_state = self_kv_cache[1]
-                state_indices_tensor = attn_metadata.state_indices_tensor
-        else:
-            conv_state = mamba_cache_params.conv_state
-            ssm_state = mamba_cache_params.ssm_state
-            state_indices_tensor = mamba_cache_params.state_indices_tensor
-
-        # Common members between V1 metadata and V0 metadata
-        if mamba2_metadata is not None:
-            has_initial_states_p = mamba2_metadata.has_initial_states_p
-            prep_initial_states = mamba2_metadata.prep_initial_states
-            chunk_size = mamba2_metadata.chunk_size
-            seq_idx_p = mamba2_metadata.seq_idx_p
-            chunk_indices_p = mamba2_metadata.chunk_indices_p
-            chunk_offsets_p = mamba2_metadata.chunk_offsets_p
+        if attn_metadata is not None:
+            assert isinstance(attn_metadata, dict)
+            attn_metadata = attn_metadata[self.prefix]
+            assert isinstance(attn_metadata, Mamba2AttentionMetadata)
+            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+            # conv_state = (..., dim, width-1) yet contiguous along 'dim'
+            conv_state = self_kv_cache[0].transpose(-1, -2)
+            ssm_state = self_kv_cache[1]
+            state_indices_tensor = attn_metadata.state_indices_tensor
+            has_initial_states_p = attn_metadata.has_initial_states_p
+            prep_initial_states = attn_metadata.prep_initial_states
+            chunk_size = attn_metadata.chunk_size
+            seq_idx_p = attn_metadata.seq_idx_p
+            chunk_indices_p = attn_metadata.chunk_indices_p
+            chunk_offsets_p = attn_metadata.chunk_offsets_p

         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
@@ -562,8 +536,8 @@ class MambaMixer2(MambaBase, CustomOp):
             dim=-1,
         )

-        if envs.VLLM_USE_V1 and attn_metadata is None:
-            # V1 profile run
+        if attn_metadata is None:
+            # profile run
             hidden_states_B_C = (hidden_states_B_C.transpose(
                 0, 1).clone().transpose(0, 1)).contiguous()
             hidden_states, _B, _C = split_hidden_states_B_C_fn(
@@ -579,10 +553,8 @@ class MambaMixer2(MambaBase, CustomOp):
         has_decode = num_decodes > 0
         num_actual_tokens = num_prefill_tokens + num_decodes

-        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
-        if envs.VLLM_USE_V1:
         hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
             hidden_states_B_C[:num_actual_tokens],
             [num_decodes, num_prefill_tokens],
@@ -602,26 +574,6 @@ class MambaMixer2(MambaBase, CustomOp):
             query_start_loc_p = (
                 attn_metadata.query_start_loc[-num_prefills - 1:] -
                 num_decodes if has_prefill else None)
-        else:
-            hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
-                hidden_states_B_C,
-                [num_prefill_tokens, num_decodes],
-                dim=0,
-            )
-            dt_p, dt_d = torch.split(
-                dt,
-                [num_prefill_tokens, num_decodes],
-                dim=0,
-            )
-            # Split along batch dimension
-            state_indices_tensor_p, state_indices_tensor_d = torch.split(
-                state_indices_tensor,
-                [num_prefills, num_decodes],
-                dim=0,
-            )
-            query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
-                                                               1]
-                                 if has_prefill else None)

         # Preallocate output tensor to avoid memcpy cost for merging prefill
         # and decode outputs
@@ -633,18 +585,11 @@ class MambaMixer2(MambaBase, CustomOp):
             dtype=hidden_states.dtype,
             device=hidden_states.device,
         )
-        if envs.VLLM_USE_V1:
         preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
             preallocated_ssm_out,
             [num_decodes, num_prefill_tokens],
             dim=0,
         )
-        else:
-            preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
-                preallocated_ssm_out,
-                [num_prefill_tokens, num_decodes],
-                dim=0,
-            )

         # Process prefill requests
         if has_prefill:
@@ -653,9 +598,6 @@ class MambaMixer2(MambaBase, CustomOp):
             # pointed to by "state_indices_tensor"
             x = hidden_states_B_C_p.transpose(
                 0, 1)  # this is the form that causal-conv see
-            if mamba2_metadata.cu_seqlen is None:
-                mamba2_metadata = update_metadata(x, query_start_loc_p,
-                                                  mamba2_metadata)
             hidden_states_B_C_p = causal_conv1d_fn(
                 x,
                 conv_weights,
@@ -664,7 +606,7 @@ class MambaMixer2(MambaBase, CustomOp):
                 conv_states=conv_state,
                 has_initial_state=has_initial_states_p,
                 cache_indices=state_indices_tensor_p,
-                metadata=mamba2_metadata,
+                metadata=attn_metadata,
                 query_start_loc=query_start_loc_p).transpose(
                     0, 1)[:num_prefill_tokens]
@@ -806,8 +748,6 @@ def mamba_mixer2(
     self = forward_context.no_compile_layers[layer_name]
     self.forward_cuda(hidden_states=hidden_states,
                       output=output,
-                      mamba_cache_params=None,
-                      mamba2_metadata=None,
                       mup_vector=mup_vector)
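Note: every converted layer now follows the same V1-only lookup seen above: forward_context.attn_metadata is either None (a memory-profiling run) or a dict keyed by the layer's prefix. A toy sketch of that guard (stand-in function, not vLLM's API):

from typing import Any, Optional


def lookup_layer_metadata(attn_metadata: Optional[dict[str, Any]],
                          prefix: str) -> Optional[Any]:
    if attn_metadata is None:
        return None  # profile run: the layer skips its state update entirely
    assert isinstance(attn_metadata, dict)
    return attn_metadata[prefix]  # e.g. a Mamba2AttentionMetadata instance


meta = lookup_layer_metadata({"model.layers.0.mixer": "metadata-object"},
                             "model.layers.0.mixer")
assert meta == "metadata-object"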


@@ -100,7 +100,6 @@ class MambaStateShapeCalculator:
         intermediate_size: int,
         state_size: int,
         conv_kernel: int,
-        use_v1: bool = True,
     ) -> tuple[tuple[int, int], tuple[int, int]]:
         conv_state_shape = (divide(intermediate_size,
                                    tp_world_size), conv_kernel - 1)
@@ -108,10 +107,6 @@ class MambaStateShapeCalculator:
         temporal_state_shape = (divide(intermediate_size,
                                        tp_world_size), state_size)

-        # In V0, the conv_state shape was swapped during allocation in
-        # MambaCacheManager, but in V1 it needs to be determined here at the
-        # calculation level
-        if use_v1:
         conv_state_shape = conv_state_shape[1], conv_state_shape[0]

         return conv_state_shape, temporal_state_shape
@@ -126,7 +121,6 @@ class MambaStateShapeCalculator:
         head_dim: int,
         state_size: int,
         conv_kernel: int,
-        use_v1: bool = True,
     ) -> tuple[tuple[int, int], tuple[int, int, int]]:
         # if n_groups is not divisible by world_size, need to extend the shards
         # to ensure all groups needed by a head is sharded along with it
@@ -137,8 +131,6 @@ class MambaStateShapeCalculator:
         # contiguous along 'dim' axis
         conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
-        if not use_v1:
-            conv_state_shape = conv_state_shape[1], conv_state_shape[0]

         # These are not TP-ed as they depend on A, dt_bias, D
         # - they are typically small
@@ -153,12 +145,9 @@ class MambaStateShapeCalculator:
         tp_world_size: int,
         intermediate_size: int,
         conv_kernel: int,
-        use_v1: bool = True,
     ) -> tuple[tuple[int, int]]:
         conv_dim = divide(intermediate_size, tp_world_size)
         conv_state_shape = (conv_kernel - 1, conv_dim)
-        if not use_v1:
-            conv_state_shape = conv_state_shape[1], conv_state_shape[0]

         return (conv_state_shape, )

     @classmethod
@@ -183,7 +172,6 @@ class MambaStateShapeCalculator:
         head_v_dim: int,
         conv_kernel_size: int,
         num_spec: int = 0,
-        use_v1: bool = True,
     ):
         conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads)
         conv_state_shape = (
@@ -191,10 +179,6 @@ class MambaStateShapeCalculator:
             conv_kernel_size - 1 + num_spec,
         )

-        # In V0, the conv_state shape was swapped during allocation in
-        # MambaCacheManager, but in V1 it needs to be determined here at the
-        # calculation level
-        if use_v1:
         conv_state_shape = conv_state_shape[1], conv_state_shape[0]

         temporal_state_shape = (divide(num_v_heads,
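Note: a worked example of the Mamba-1 shape helper after the use_v1 flag is gone (illustrative numbers, not from any specific model config); the conv state is now always returned time-major, matching what the V1 cache allocator expects:

tp_world_size = 2
intermediate_size, state_size, conv_kernel = 1536, 16, 4

conv_state_shape = (intermediate_size // tp_world_size, conv_kernel - 1)
conv_state_shape = (conv_state_shape[1], conv_state_shape[0])  # always swapped now
temporal_state_shape = (intermediate_size // tp_world_size, state_size)

assert conv_state_shape == (3, 768)
assert temporal_state_shape == (768, 16)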


@@ -420,9 +420,7 @@ def causal_conv1d_fn(
     x = x.to(conv_states.dtype)
     out = torch.empty_like(x)
     if metadata is not None:
-        cu_seqlen = metadata.cu_seqlen
         nums_dict = metadata.nums_dict
-        #x = metadata.x
         args = nums_dict
         batch_ptr = metadata.batch_ptr
         token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
@@ -926,7 +924,6 @@ def causal_conv1d_update(
     query_start_loc: Optional[torch.Tensor] = None,
     max_query_len: int = -1,
     pad_slot_id: int = PAD_SLOT_ID,
-    metadata=None,
     validate_data=False,
 ):
     """


@@ -8,7 +8,6 @@ if TYPE_CHECKING:

 import torch

-from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -18,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.mamba.abstract import MambaBase
-from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -71,15 +69,11 @@ class ShortConv(MambaBase, CustomOp):
             prefix=f"{prefix}.out_proj",
         )

-        assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1")
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
-        # The outer list is for v0 PP virtual engine. Though this code path
-        # only runs for v1, we have to do this to unify with the interface
-        # of Attention + v0 PP.
-        self.kv_cache = [(torch.tensor([]), )]
+        self.kv_cache = (torch.tensor([]), )

         self.model_config = model_config
         self.cache_config = cache_config
@@ -89,7 +83,6 @@ class ShortConv(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        conv_metadata: ShortConvAttentionMetadata,
     ):
         return

@@ -97,7 +90,6 @@ class ShortConv(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        conv_metadata: ShortConvAttentionMetadata,
     ):
         torch.ops.vllm.short_conv(
             hidden_states,
@@ -109,7 +101,6 @@ class ShortConv(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        conv_metadata: ShortConvAttentionMetadata,
     ):
         forward_context = get_forward_context()
         # ShortConvAttentionMetadata contains metadata necessary for the
@@ -121,7 +112,6 @@ class ShortConv(MambaBase, CustomOp):
         if attn_metadata is not None:
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
-            conv_metadata = attn_metadata
             assert isinstance(attn_metadata, ShortConvAttentionMetadata)
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             conv_state = self_kv_cache[0].transpose(-1, -2)
@@ -181,9 +171,6 @@ class ShortConv(MambaBase, CustomOp):
         if has_prefill:
             Bx_p = (B_p * x_p).transpose(0, 1)
-            if conv_metadata.cu_seqlen is None:
-                conv_metadata = update_metadata(Bx_p, query_start_loc_p,
-                                                conv_metadata)
             Bx = causal_conv1d_fn(Bx_p,
                                   conv_weights,
                                   self.conv.bias,
@@ -191,7 +178,7 @@ class ShortConv(MambaBase, CustomOp):
                                   conv_states=conv_state,
                                   has_initial_state=has_initial_states_p,
                                   cache_indices=state_indices_tensor_p,
-                                  metadata=conv_metadata,
+                                  metadata=attn_metadata,
                                   query_start_loc=query_start_loc_p).transpose(
                                       0, 1)[:num_prefill_tokens]

@@ -248,9 +235,7 @@ def short_conv(
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self.forward_cuda(hidden_states=hidden_states,
-                      output=output,
-                      conv_metadata=None)
+    self.forward_cuda(hidden_states=hidden_states, output=output)


 def short_conv_fake(


@ -9,21 +9,17 @@ import torch
from torch import nn from torch import nn
from transformers import BambaConfig from transformers import BambaConfig
from vllm import envs
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
@ -32,10 +28,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsQuant) SupportsQuant)
@ -115,8 +108,6 @@ class BambaMixerDecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
@ -127,7 +118,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states, residual) hidden_states, residual)
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata) self.mamba(hidden_states, output)
# Fully Connected # Fully Connected
hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states, residual = self.pre_ff_layernorm(output, residual)
hidden_states = self.feed_forward(hidden_states) hidden_states = self.feed_forward(hidden_states)
@ -315,22 +306,10 @@ class BambaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
@ -343,23 +322,11 @@ class BambaModel(nn.Module):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
residual = None residual = None
num_attn = 0
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
if isinstance(layer, BambaAttentionDecoderLayer):
num_attn += 1
layer_mamba_cache_params = None
if isinstance(layer,
BambaMixerDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_attn)
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
@ -457,13 +424,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
@ -482,7 +447,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim=hf_config.mamba_d_head, head_dim=hf_config.mamba_d_head,
state_size=hf_config.mamba_d_state, state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
) )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
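With the V0 path gone, get_mamba_state_shape_from_config takes only the vLLM config; the use_v1 switch, and the matching keyword in the shape-calculator call above, disappear. A minimal sketch of the updated call, assuming the usual vllm.model_executor.models.bamba module path and an already-built VllmConfig named vllm_config (a hypothetical variable, not part of this diff):

    from vllm.config import VllmConfig
    from vllm.model_executor.models.bamba import BambaForCausalLM

    def report_mamba_state_shapes(vllm_config: VllmConfig) -> None:
        # The classmethod now derives everything from the config alone;
        # per the annotation it returns (conv_shape, temporal_state_shape).
        conv_shape, temporal_shape = (
            BambaForCausalLM.get_mamba_state_shape_from_config(vllm_config))
        print("conv state shape:", conv_shape)
        print("temporal (SSM) state shape:", temporal_shape)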
@ -515,8 +479,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
@ -534,39 +496,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
mamba_cache_params = None hidden_states = self.model(input_ids, positions, intermediate_tensors,
if not envs.VLLM_USE_V1: inputs_embeds)
if self.mamba_cache is None:
num_mamba_layers = \
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba
)
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
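Taken together, the Bamba hunks reduce the top-level call chain to a plain delegation: the decoder layer now invokes self.mamba(hidden_states, output), and the causal-LM forward no longer builds a MambaCacheManager or threads MambaCacheParams down the stack. Reconstructed from the right-hand side of the hunks above (the remaining SSM and hybrid model files below follow the same shape of change), the patched forward is roughly:

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        # State handling now lives entirely inside the V1 engine and the
        # mixer layers, so there is no mamba_cache_params left to pass down.
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states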


@ -1,137 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
class ConstantSizeCache(ABC):
"""
Abstract base class for managing constant size caches
like Mamba and Minimax.
"""
def __init__(self, max_batch_size: int):
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the cache
self.cache_indices_mapping: dict[str, dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
@property
@abstractmethod
def cache(self) -> Any:
"""Return the underlying cache tensor(s)"""
pass
@abstractmethod
def _copy_cache(self, from_index: int, to_index: int):
"""Copy cache data from one index to another"""
pass
def current_run_tensors(self, **kwargs) -> tuple:
"""
Return the tensors for the current run's conv and ssm state.
"""
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
cache_tensors = self.cache
else:
# CUDA graph capturing runs
cache_tensors, state_indices_tensor = kwargs[
"seqlen_agnostic_capture_inputs"]
return (cache_tensors, state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Cache during the CUDA graph replay
runs.
"""
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.cache, state_indices_tensor)
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.cache_indices_mapping[cur_rid] = {seq_id: destination_index}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_cache(from_index=index_exists,
to_index=destination_index)
self.cache_indices_mapping[cur_rid][seq_id] = destination_index
return destination_index
else:
return self.cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_cache(
self, request_ids_to_seq_ids: dict[str, list[int]],
finished_requests_ids: list[str]) -> list[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
def _release_finished_requests(self,
finished_seq_groups_req_ids: list[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.cache_indices_mapping:
for seq_id in self.cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.cache_indices_mapping[req_id][seq_id])
self.cache_indices_mapping.pop(req_id)
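The deleted ConstantSizeCache base class carried the V0 slot bookkeeping: finished requests map to PAD_SLOT_ID, a new request pops a free slot, and additional seq_ids of the same request (parallel sampling, n > 1) receive a copy of the slot that already holds the prefill state. A self-contained toy mirroring that assignment logic, shown only to document what the V1 engine now handles internally (names, shapes and the -1 pad value are illustrative, not vLLM APIs):

    import torch

    PAD_SLOT_ID = -1  # plays the role of vllm's PAD_SLOT_ID constant

    class ToySlotCache:
        """Mirrors the removed request/seq -> cache-slot assignment rules."""

        def __init__(self, max_batch_size: int):
            self.mapping: dict[str, dict[int, int]] = {}
            self.free = list(range(max_batch_size))
            self.state = torch.zeros(max_batch_size, 4)  # one row per slot

        def assign(self, req_id: str, seq_id: int, finished: list[str]) -> int:
            if req_id in finished:
                return PAD_SLOT_ID                # finished: pad, no slot
            seqs = self.mapping.setdefault(req_id, {})
            if seq_id in seqs:
                return seqs[seq_id]               # already placed: reuse slot
            slot = self.free.pop()
            if seqs:                              # sibling seq_id: copy its state
                self.state[slot].copy_(self.state[next(iter(seqs.values()))])
            seqs[seq_id] = slot
            return slot

    cache = ToySlotCache(max_batch_size=4)
    print(cache.assign("req-0", 0, finished=[]))          # fresh slot
    print(cache.assign("req-0", 1, finished=[]))          # sibling, copied slot
    print(cache.assign("req-1", 0, finished=["req-1"]))   # finished -> -1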


@ -8,21 +8,17 @@ import torch
from torch import nn from torch import nn
from transformers import FalconH1Config from transformers import FalconH1Config
from vllm import envs
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
@ -31,8 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
@ -179,16 +173,12 @@ class FalconH1SSMDecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.mamba( self.mamba(
hidden_states, hidden_states,
output, output,
mamba_cache_params,
mamba2_metadata=mamba2_metadata,
mup_vector=self.mup_vector, mup_vector=self.mup_vector,
) )
return output, residual return output, residual
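One Falcon-H1-specific detail survives the cleanup: the SSM decoder layer still forwards its mup_vector to the mixer, so that call keeps a single model-specific argument while the cache and metadata arguments disappear. After this hunk the body reduces to:

    output = torch.empty_like(hidden_states)
    self.mamba(
        hidden_states,
        output,
        mup_vector=self.mup_vector,
    )
    return output, residual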
@ -364,8 +354,6 @@ class FalconH1ParallelHybrid(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
residual = hidden_states residual = hidden_states
@ -382,12 +370,10 @@ class FalconH1ParallelHybrid(nn.Module):
# Process input through the SSM branch. # Process input through the SSM branch.
# FalconH1SSMDecoderLayer expects hidden_states, attn_metadata, # FalconH1SSMDecoderLayer expects hidden_states, attn_metadata,
# residual, mamba_cache_params, and sequence_idx. # residual, and sequence_idx.
ssm_hidden, _ = self.mamba( ssm_hidden, _ = self.mamba(
hidden_states=hidden_states * self.ssm_in_multiplier, hidden_states=hidden_states * self.ssm_in_multiplier,
residual=residual, residual=residual,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
**kwargs, **kwargs,
) )
# Sum the outputs from both branches. # Sum the outputs from both branches.
@ -464,25 +450,10 @@ class FalconH1Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds * self.embedding_multiplier hidden_states = inputs_embeds * self.embedding_multiplier
@ -495,14 +466,9 @@ class FalconH1Model(nn.Module):
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
layer_mamba_cache_params = None
if mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
hidden_states = layer( hidden_states = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
@ -541,13 +507,11 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
@ -570,7 +534,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim=hf_config.mamba_d_head, head_dim=hf_config.mamba_d_head,
state_size=hf_config.mamba_d_state, state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
) )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@ -592,7 +555,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.tie_word_embeddings = config.tie_word_embeddings self.tie_word_embeddings = config.tie_word_embeddings
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.mamba_cache: Optional[MambaCacheManager] = None
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
@ -637,40 +599,15 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
**kwargs, **kwargs,
): ):
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(
self.vllm_config,
self.config.num_hidden_layers,
*mamba_state_shape,
*mamba_state_dtype,
)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
positions, positions,
mamba_cache_params,
intermediate_tensors, intermediate_tensors,
inputs_embeds, inputs_embeds,
) )
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,


@ -9,19 +9,15 @@ import torch
from torch import nn from torch import nn
from transformers import GraniteMoeHybridConfig from transformers import GraniteMoeHybridConfig
from vllm import envs
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
@ -30,10 +26,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .granitemoe import GraniteMoeMoE from .granitemoe import GraniteMoeMoE
from .granitemoeshared import GraniteMoeSharedMLP from .granitemoeshared import GraniteMoeSharedMLP
@ -102,14 +95,12 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata) self.mamba(hidden_states, output)
hidden_states = residual + output * self.residual_multiplier hidden_states = residual + output * self.residual_multiplier
residual = hidden_states residual = hidden_states
@ -182,8 +173,6 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
@ -366,22 +355,10 @@ class GraniteMoeHybridModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
@ -399,20 +376,9 @@ class GraniteMoeHybridModel(nn.Module):
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer): if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer):
num_attn += 1 num_attn += 1
hidden_states, residual = layer(positions=positions,
layer_mamba_cache_params = None
if isinstance(
layer,
GraniteMoeHybridMambaDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_attn)
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual)
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
@ -552,13 +518,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
@ -577,7 +541,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
head_dim=hf_config.mamba_d_head, head_dim=hf_config.mamba_d_head,
state_size=hf_config.mamba_d_state, state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
) )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@ -620,9 +583,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
scale=1 / scale=1 /
self.config.logits_scaling) self.config.logits_scaling)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
@ -636,38 +596,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
mamba_cache_params = None hidden_states = self.model(input_ids, positions, intermediate_tensors,
if not envs.VLLM_USE_V1: inputs_embeds)
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,


@ -9,7 +9,6 @@ import torch
from torch import nn from torch import nn
from transformers import JambaConfig from transformers import JambaConfig
from vllm import envs
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
@ -30,10 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaMLP as JambaMLP from vllm.model_executor.models.llama import LlamaMLP as JambaMLP
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
@ -145,7 +141,6 @@ class JambaMambaDecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
@ -156,7 +151,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states, residual) hidden_states, residual)
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params) self.mamba(hidden_states, output)
# Fully Connected # Fully Connected
hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states, residual = self.pre_ff_layernorm(output, residual)
hidden_states = self.feed_forward(hidden_states) hidden_states = self.feed_forward(hidden_states)
@ -333,7 +328,6 @@ class JambaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
@ -348,24 +342,11 @@ class JambaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
kv_cache_index = 0
mamba_cache_index = 0
for layer in islice(self.layers, self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer_mamba_cache_params = None hidden_states, residual = layer(positions=positions,
if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache_index += 1
if isinstance(layer,
JambaMambaDecoderLayer) and mamba_cache_params:
current_state_layer = mamba_cache_index
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
current_state_layer)
mamba_cache_index += 1
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual)
mamba_cache_params=layer_mamba_cache_params)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
@ -503,8 +484,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
@ -521,24 +500,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config)
state_dtype = self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_layers, *state_shape,
*state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
@ -574,7 +538,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_size=hf_config.mamba_expand * hidden_size, intermediate_size=hf_config.mamba_expand * hidden_size,
state_size=hf_config.mamba_d_state, state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
use_v1=envs.VLLM_USE_V1,
) )
def compute_logits( def compute_logits(


@ -8,7 +8,6 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import Lfm2Config from transformers import Lfm2Config
from vllm import envs
from vllm.attention import Attention from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
@ -297,7 +296,6 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
self.conv( self.conv(
hidden_states, hidden_states,
output, output,
conv_metadata=None,
) )
hidden_states, residual = self.ffn_norm(output, residual) hidden_states, residual = self.ffn_norm(output, residual)
hidden_states = self.feed_forward(hidden_states) hidden_states = self.feed_forward(hidden_states)
@ -459,13 +457,11 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int]]: ) -> tuple[tuple[int, int]]:
""" Calculate shapes for LFM2's convolutional cache. """ Calculate shapes for LFM2's convolutional cache.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
@ -478,7 +474,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
tp_world_size=parallel_config.tensor_parallel_size, tp_world_size=parallel_config.tensor_parallel_size,
intermediate_size=hf_config.conv_dim, intermediate_size=hf_config.conv_dim,
conv_kernel=hf_config.conv_L_cache, conv_kernel=hf_config.conv_L_cache,
use_v1=use_v1,
) )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
@ -489,8 +484,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
assert (not cache_config.enable_prefix_caching assert (not cache_config.enable_prefix_caching
), "Lfm2 currently does not support prefix caching" ), "Lfm2 currently does not support prefix caching"
assert envs.VLLM_USE_V1, (
"Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1")
super().__init__() super().__init__()
self.config = config self.config = config


@ -8,7 +8,6 @@ import torch
from torch import nn from torch import nn
from transformers import MambaConfig from transformers import MambaConfig
from vllm import envs
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
@ -24,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState, from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree, SupportsPP) IsAttentionFree, SupportsPP)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .utils import (AutoWeightsLoader, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
@ -72,7 +68,6 @@ class MambaDecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
@ -82,7 +77,7 @@ class MambaDecoderLayer(nn.Module):
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output, mamba_cache_params) self.mixer(hidden_states, output)
return output, residual return output, residual
@ -134,7 +129,6 @@ class MambaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
@ -151,17 +145,9 @@ class MambaModel(nn.Module):
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer(positions=positions,
layer_cache_params = None
if mamba_cache_params is not None:
layer_cache_params = mamba_cache_params.at_layer_idx(
i - self.start_layer)
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual)
mamba_cache_params=layer_cache_params)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
@ -225,9 +211,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
@ -244,22 +227,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
mamba_cache_params = None hidden_states = self.backbone(input_ids, positions,
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config)
state_dtype = self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_layers, *state_shape,
*state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds) intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
@ -288,8 +256,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
tp_world_size=parallel_config.tensor_parallel_size, tp_world_size=parallel_config.tensor_parallel_size,
intermediate_size=hf_config.intermediate_size, intermediate_size=hf_config.intermediate_size,
state_size=hf_config.state_size, state_size=hf_config.state_size,
conv_kernel=hf_config.conv_kernel, conv_kernel=hf_config.conv_kernel)
use_v1=envs.VLLM_USE_V1)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs( return self.mamba_cache.copy_inputs_before_cuda_graphs(


@ -8,16 +8,11 @@ import torch
from torch import nn from torch import nn
from transformers import MambaConfig from transformers import MambaConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
@ -28,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState, from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree) IsAttentionFree)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .utils import (AutoWeightsLoader, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
@ -74,8 +66,6 @@ class Mamba2DecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
@ -85,7 +75,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata) self.mixer(hidden_states, output)
return output, residual return output, residual
@ -137,7 +127,6 @@ class Mamba2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
@ -152,25 +141,10 @@ class Mamba2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(positions=positions,
positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual)
mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer) if mamba_cache_params else None,
mamba2_metadata=mamba2_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
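The branch removed above (and its twins in the other hybrid models) existed only to precompute Mamba2 chunk metadata for V0; the deleted comment already states that under V1 the layers obtain this from the forward context. A hedged sketch of that lookup pattern, not the exact MambaMixer2 internals:

    from vllm.forward_context import get_forward_context

    def current_attn_metadata():
        # Under V1 the per-step metadata travels with the forward context,
        # so a mixer layer can look it up here instead of receiving an
        # explicit mamba2_metadata argument from the model's forward().
        return get_forward_context().attn_metadata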
@ -222,13 +196,11 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
@ -247,7 +219,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
head_dim=hf_config.head_dim, head_dim=hf_config.head_dim,
state_size=hf_config.state_size, state_size=hf_config.state_size,
conv_kernel=hf_config.conv_kernel, conv_kernel=hf_config.conv_kernel,
use_v1=use_v1,
) )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@ -282,9 +253,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings) self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
@ -300,29 +268,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.backbone(input_ids, positions,
else:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds) intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states


@ -1,83 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
@dataclass
class MambaCacheParams:
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()
def at_layer_idx(self, layer_idx):
return MambaCacheParams(self.conv_state[layer_idx],
self.ssm_state[layer_idx],
self.state_indices_tensor)
class MambaCacheManager(ConstantSizeCache):
def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int,
conv_state_shape: tuple[int, int],
temporal_state_shape: tuple[int, int],
conv_state_dtype: torch.dtype,
temporal_state_dtype: torch.dtype):
self.conv_state_dtype = conv_state_dtype
self.temporal_state_dtype = temporal_state_dtype
# Determine max batch size to set size of MambaCache
max_batch_size = vllm_config.scheduler_config.max_num_seqs
if not vllm_config.model_config.enforce_eager:
max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)
# Initialize parent class
super().__init__(max_batch_size)
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
(conv_state_shape[1], conv_state_shape[0]),
dtype=self.conv_state_dtype,
device="cuda").transpose(-1, -2)
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=self.temporal_state_dtype,
device="cuda")
self._mamba_cache = (conv_state, temporal_state)
@property
def cache(self):
return self._mamba_cache
def _copy_cache(self, from_index: int, to_index: int):
for cache_t in self.cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def current_run_tensors(self, **kwargs) -> MambaCacheParams:
"""
Return the tensors for the current run's conv and ssm state.
"""
cache_tensors, state_indices_tensor = super().current_run_tensors(
**kwargs)
return MambaCacheParams(cache_tensors[0], cache_tensors[1],
state_indices_tensor)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")


@ -1,36 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
@dataclass
class MinimaxCacheParams:
minimax_cache: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()
def at_layer_idx(self, layer_idx):
return MinimaxCacheParams(self.minimax_cache[layer_idx, ...],
self.state_indices_tensor)
class MinimaxCacheManager(ConstantSizeCache):
def __init__(self, dtype, cache_shape):
super().__init__(cache_shape[1]) # max_batch_size is cache_shape[1]
self._minimax_cache = torch.empty(size=cache_shape,
dtype=dtype,
device="cuda")
@property
def cache(self):
return self._minimax_cache
def _copy_cache(self, from_index: int, to_index: int):
assert len(self.cache) > 0
for cache_t in self.cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)


@ -14,7 +14,6 @@ import torch.distributed
from torch import nn from torch import nn
from transformers import MiniMaxConfig from transformers import MiniMaxConfig
from vllm import envs
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
@ -44,7 +43,6 @@ from vllm.model_executor.models.utils import maybe_prefix
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid from .interfaces import HasInnerState, IsHybrid
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
@ -404,7 +402,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
def forward(self, def forward(self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: Union[list[dict], Optional[torch.Tensor]],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
is_warmup: bool = False, is_warmup: bool = False,
@ -418,7 +415,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
hidden_states=layernorm_output, hidden_states=layernorm_output,
output=self_attention_output, output=self_attention_output,
positions=positions, positions=positions,
kv_caches=kv_caches,
) )
residual = residual * self.layernorm_attention_alpha residual = residual * self.layernorm_attention_alpha
@ -563,10 +559,6 @@ class MiniMaxText01Model(nn.Module):
self._dtype = _dummy.dtype self._dtype = _dummy.dtype
del _dummy del _dummy
if not envs.VLLM_USE_V1:
self.minimax_cache = MinimaxCacheManager(
dtype=torch.float32, cache_shape=self.cache_shape)
norm_kwargs = {} norm_kwargs = {}
if hasattr(config, "rms_norm_eps"): if hasattr(config, "rms_norm_eps"):
norm_kwargs["eps"] = config.rms_norm_eps norm_kwargs["eps"] = config.rms_norm_eps
@ -614,25 +606,6 @@ class MiniMaxText01Model(nn.Module):
**kwargs) -> Union[torch.Tensor, IntermediateTensors]: **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
forward_context = get_forward_context() forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
if not envs.VLLM_USE_V1 and attn_metadata is None:
return None
if not envs.VLLM_USE_V1:
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []
(
minimax_cache_tensors,
state_indices_tensor,
) = self.minimax_cache.current_run_tensors(**kwargs)
if getattr(attn_metadata, "num_prefills", 0) > 0:
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
**kwargs)
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
state_indices_tensor)
else:
minimax_cache_params = None
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is None: if inputs_embeds is None:
@ -645,20 +618,10 @@ class MiniMaxText01Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
minimax_cache_index = 0
for layer in islice(self.layers, self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
_caches = None
if not envs.VLLM_USE_V1 and isinstance(
layer.self_attn, MiniMaxText01LinearAttention):
current_state_layer = minimax_cache_index
_caches = minimax_cache_params.at_layer_idx(
current_state_layer)
minimax_cache_index += 1
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states=hidden_states, hidden_states=hidden_states,
positions=positions, positions=positions,
kv_caches=_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
residual=residual, residual=residual,
) )
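For MiniMaxText01 the cleanup has a slightly different flavour: the per-layer kv_caches plumbing and the MinimaxCacheManager are gone, while attn_metadata remains an explicit argument to every decoder layer. The loop after this hunk is simply:

    for layer in islice(self.layers, self.start_layer, self.end_layer):
        hidden_states, residual = layer(
            hidden_states=hidden_states,
            positions=positions,
            attn_metadata=attn_metadata,
            residual=residual,
        )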
@ -1003,13 +966,11 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, ...], ...]: ) -> tuple[tuple[int, ...], ...]:
"""Calculate shape for MiniMaxText01LinearAttention cache. """Calculate shape for MiniMaxText01LinearAttention cache.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:


@ -23,21 +23,17 @@ from typing import Optional
import torch import torch
from torch import nn from torch import nn
from vllm import envs
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
@ -49,14 +45,11 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsLoRA, SupportsPP, SupportsLoRA, SupportsPP,
SupportsQuant) SupportsQuant)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.models.utils import ( from vllm.model_executor.models.utils import (
AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory, AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory,
make_layers, maybe_prefix) make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronHConfig from vllm.transformers_utils.configs import NemotronHConfig
from vllm.utils import LayerBlockType
class NemotronHMLP(nn.Module): class NemotronHMLP(nn.Module):
@ -181,8 +174,6 @@ class NemotronHMambaDecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
@ -192,7 +183,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata) self.mixer(hidden_states, output)
return output, residual return output, residual
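A self-contained sketch of the simplified mixer call above (hypothetical _TinyMixer, not the real MambaMixer2): the layer preallocates the output buffer and the mixer fills it in place, fetching cache state and metadata from the forward context instead of taking explicit mamba_cache_params / mamba2_metadata arguments.

import torch

class _TinyMixer(torch.nn.Module):
    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor) -> None:
        # Placeholder compute; the real mixer looks up its state by layer prefix.
        output.copy_(hidden_states * 2.0)

mixer = _TinyMixer()
x = torch.randn(4, 8)
out = torch.empty_like(x)
mixer(x, out)  # mirrors: output = torch.empty_like(hidden_states); self.mixer(hidden_states, output)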
@ -370,22 +361,10 @@ class NemotronHModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
@ -398,22 +377,11 @@ class NemotronHModel(nn.Module):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
residual = None residual = None
num_non_mamba_layers = 0
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
layer_mamba_cache_params = None
if isinstance(layer,
NemotronHMambaDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_non_mamba_layers)
else:
num_non_mamba_layers += 1
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
@ -508,13 +476,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
@ -533,7 +499,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim=hf_config.mamba_head_dim, head_dim=hf_config.mamba_head_dim,
state_size=hf_config.ssm_state_size, state_size=hf_config.ssm_state_size,
conv_kernel=hf_config.conv_kernel, conv_kernel=hf_config.conv_kernel,
use_v1=use_v1,
) )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@ -566,8 +531,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
@ -584,40 +547,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
mamba_cache_params = None hidden_states = self.model(input_ids, positions, intermediate_tensors,
if not envs.VLLM_USE_V1: inputs_embeds)
if self.mamba_cache is None:
num_mamba_layers = \
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba
)
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,

View File

@ -1,731 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.nn as nn
from transformers.activations import ACT2FN
import vllm.envs as envs
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.selector import _Backend
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsV0Only)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors
from .utils import make_layers, maybe_prefix
logger = init_logger(__name__)
class SwiGLUActivation(nn.Module):
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
return x1 * nn.functional.silu(x2)
class SambaYMLP(nn.Module):
"""Gated Linear Unit.
Reference:
Language Modeling with Gated Convolutional Networks.
https://arxiv.org/pdf/1612.08083v3.pdf.
"""
def __init__(self, config):
super().__init__()
self.config = config
self.fc1 = nn.Linear(config.hidden_size,
2 * config.intermediate_size,
bias=False)
self.fc2 = nn.Linear(config.intermediate_size,
config.hidden_size,
bias=False)
self.activation_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
y = self.fc1(hidden_states)
gate, y = y.chunk(2, dim=-1)
y = y * self.activation_fn(gate)
return self.fc2(y)
def get_virtual_engine():
forward_context: ForwardContext = get_forward_context()
return forward_context.virtual_engine
class SambaYAttention(nn.Module):
def __init__(self,
config,
layer_idx: Optional[int] = None,
yoco_cross: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = ""):
super().__init__()
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing "
"a `layer_idx` is not recommended and will lead to errors "
"during the forward call if caching is used. Please make "
"sure to provide a `layer_idx` when creating this class.")
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.yoco_cross = yoco_cross
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError("hidden_size must be divisible by num_heads "
f"(got `hidden_size`: {self.hidden_size} and "
f"`num_heads`: {self.num_heads}).")
op_size = self.num_heads * self.head_dim + 2 * (
self.num_key_value_heads * self.head_dim)
self.out_proj = nn.Linear(self.num_heads * self.head_dim,
self.hidden_size,
bias=True)
if yoco_cross:
self.Wqkv = nn.Linear(self.hidden_size,
self.num_heads * self.head_dim,
bias=True)
else:
self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
# disable sliding window for the second half of the model
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
sliding_window = config.sliding_window if is_sliding else None
assert self.num_heads % 2 == 0, 'num_heads should be even'
assert self.num_key_value_heads % 2 == 0, 'num_key_value_heads should be even'
self.lambda_init = self.lambda_init_fn(layer_idx)
self.lambda_q1 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_k1 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_q2 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.subln = nn.RMSNorm(2 * self.head_dim,
eps=1e-5,
elementwise_affine=True)
params = {
'differential_flash_attention_config': {
'lambda_init': self.lambda_init,
'lambda_q1': self.lambda_q1,
'lambda_k1': self.lambda_k1,
'lambda_q2': self.lambda_q2,
'lambda_k2': self.lambda_k2,
"subln": self.subln,
}
}
if yoco_cross:
kv_shared_layer_index = config.num_hidden_layers // 2 + 1
kv_sharing_target_layer_name = \
f"model.layers.{kv_shared_layer_index}.self_attn.attn"
else:
kv_sharing_target_layer_name = None
self.attn = Attention(
self.num_heads,
self.head_dim,
self.head_dim**-0.5,
num_kv_heads=self.num_key_value_heads,
cache_config=cache_config,
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn",
attn_type=AttentionType.DECODER,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
**params)
assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN,\
"DIFFERENTIAL_FLASH_ATTN required"
def lambda_init_fn(self, depth):
return 0.8 - 0.6 * math.exp(-0.3 * depth)
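# Worked example (editor's note): at depth=0 this is 0.8 - 0.6 * exp(0) = 0.2,
# and it approaches 0.8 as depth grows, so deeper layers start with a larger
# lambda_init.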
def forward(
self,
hidden_states: torch.Tensor,
):
if not self.yoco_cross: # need to generate kv-cache
qkv = self.Wqkv(hidden_states)
q, k, v = qkv.split([
self.hidden_size, self.num_key_value_heads * self.head_dim,
self.num_key_value_heads * self.head_dim
],
dim=-1)
attn_output = self.attn(q, k, v)
else: # reuse the kv cache, full attention
q = self.Wqkv(hidden_states)
attn_output = self.attn(q, None, None)
attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
return self.out_proj(attn_output)
class Phi4Mamba(nn.Module):
def __init__(
self,
d_model,
d_state=16,
d_conv=4,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random", # difference
dt_scale=1.0, # difference
dt_init_floor=1e-4,
conv_bias=True,
bias=False,
use_fast_path=True, # Fused kernel options
layer_idx=None,
device=None,
dtype=None,
yoco_cross=False,
yoco_kv=False,
):
factory_kwargs = {"params_dtype": dtype} # difference
super().__init__()
self.yoco_cross = yoco_cross
self.yoco_kv = yoco_kv
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model /
16) if dt_rank == "auto" else dt_rank
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
self.swiGluActivation = SwiGLUActivation()
if self.yoco_cross:
self.in_proj = MergedColumnParallelLinear(self.d_model,
[self.d_inner],
bias=bias,
**factory_kwargs)
self.out_proj = RowParallelLinear(self.d_inner,
self.d_model,
bias=bias,
**factory_kwargs)
return
self.conv1d = ColumnParallelLinear(
input_size=d_conv,
output_size=self.d_inner,
bias=conv_bias,
params_dtype=dtype,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow overriding it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = MergedColumnParallelLinear(
self.d_model,
[self.d_inner] * 2,
bias=bias,
params_dtype=dtype,
)
# selective projection used to make dt, B and C input dependent
self.x_proj = RowParallelLinear(
self.d_inner,
self.dt_rank + self.d_state * 2,
bias=False,
params_dtype=dtype,
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self.dt_proj = ColumnParallelLinear(
self.dt_rank,
self.d_inner,
bias=True,
skip_bias_add=True,
params_dtype=dtype,
)
# # D "skip" parameter
# self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32
self.A = nn.Parameter(
torch.empty(
self.d_inner,
self.d_state,
dtype=torch.float32,
))
self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32))
self.out_proj = RowParallelLinear(
self.d_inner,
self.d_model,
bias=bias,
input_is_parallel=True,
params_dtype=dtype,
)
self.activation = "silu"
def forward(self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
yoco_key_values=None) -> torch.Tensor:
if self.yoco_cross:
out = self.in_proj(hidden_states)[0]
out = self.swiGluActivation(yoco_key_values, out)
out = self.out_proj(out)
return out[0], yoco_key_values
# 1. Gated MLP's linear projection
# projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
projected_states = self.in_proj(
hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
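# (Editor's note) In the diagram above, seq_len = context_len + query_len;
# e.g. with context_len=7 cached tokens and query_len=5 new tokens, the kernel
# operates on seq_len=12 for that sequence.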
hidden_states = causal_conv1d_fn(
hidden_states,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
time_step, B, C = torch.split(
ssm_parameters,
[self.dt_rank, self.d_state, self.d_state],
dim=-1,
)
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
mamba_cache_params.ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
# z,
None if self.yoco_kv else gate,
time_proj_bias,
delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
# z
# gate.transpose(0, 1),
None if self.yoco_kv else gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor,
out=scan_outputs)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection
if self.yoco_kv:
# gate = gate.transpose(-1,-2).contiguous()
yoco_key_values = scan_outputs.transpose(-2, -1)
scan_outputs = self.swiGluActivation(scan_outputs, gate)
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
return contextualized_states, yoco_key_values
class SambaYDecoderLayer(nn.Module):
def __init__(
self,
config,
layer_idx,
cache_config,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.mlp = SambaYMLP(config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.yoco_mb = False
self.yoco_cross = False
if layer_idx >= config.num_hidden_layers // 2:
self.yoco_mb = True
self.yoco_cross = (layer_idx
>= (config.num_hidden_layers // 2 + 2))
self.use_mamba = config.mb_per_layer > 0 and \
layer_idx % config.mb_per_layer == 0
if self.use_mamba:
factory_kwargs = {"dtype": None}
self.attn = Phi4Mamba(config.hidden_size,
layer_idx=layer_idx,
yoco_cross=self.yoco_cross,
yoco_kv=self.yoco_mb,
**factory_kwargs)
else:
self.attn = SambaYAttention(config,
layer_idx=layer_idx,
yoco_cross=self.yoco_cross,
cache_config=cache_config,
prefix=f"{prefix}.self_attn")
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
ssm_output: Optional[torch.LongTensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if self.use_mamba:
assert mamba_cache_params is not None
else:
assert mamba_cache_params is None
residual = hidden_states
hidden_states = self.input_layernorm(
hidden_states.to(dtype=self.input_layernorm.weight.dtype))
if self.use_mamba:
attn_outputs, ssm_output = self.attn(hidden_states,
attn_metadata,
mamba_cache_params,
yoco_key_values=ssm_output)
residual = residual.to(torch.float32)
else:
attn_outputs = self.attn(hidden_states, )
hidden_states = residual + attn_outputs
residual = hidden_states
hidden_states = self.post_attention_layernorm(
hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype))
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, ssm_output
class SambaYModel(nn.Module):
def __init__(self,
config,
cache_config=None,
quant_config=None,
lora_config=None,
prefix: str = "") -> None:
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
# Pipeline parallel is not supported since the second half of
# the layers share the kv cache.
if get_pp_group().world_size != 1:
raise ValueError("Pipeline Parallel not supported")
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: SambaYDecoderLayer(config,
int(prefix.split('.')[-1]),
cache_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
mamba_state_idx = 0
ssm_output = None
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
if i == self.config.num_hidden_layers // 2 + 2:
# profile run
kv_cache_idx = self.config.num_hidden_layers // 2 + 1
cache_layer = self.layers[kv_cache_idx]
kv_cache = cache_layer.attn.attn.kv_cache
if kv_cache[0].numel() == 0:
break
# Starting from this layer, we do not need to calculate
# the kv cache since we reuse the kv cache from last layer.
# If in prefill phase, we can truncate
# the hidden state to save computation cost.
if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1:
selected_token_indices = torch.cumsum(
attn_metadata.seq_lens_tensor, dim=0) - 1
hidden_states = hidden_states.index_select(
0, selected_token_indices)
ssm_output = ssm_output.index_select(
0, selected_token_indices)
if layer.use_mamba:
if i < self.config.num_hidden_layers // 2 or \
not layer.yoco_cross:
mamba_cache = mamba_cache_params.at_layer_idx(
mamba_state_idx)
mamba_state_idx += 1
else:
mamba_cache = mamba_cache_params.at_layer_idx(
mamba_state_idx - 1)
hidden_states, ssm_output = layer(hidden_states,
positions,
attn_metadata,
mamba_cache,
ssm_output=ssm_output)
else:
hidden_states, ssm_output = layer(
hidden_states,
positions,
attn_metadata,
None, # mamba_cache_params
ssm_output=ssm_output)
hidden_states = self.final_layernorm(
hidden_states.to(dtype=self.final_layernorm.weight.dtype))
return hidden_states
class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
quant_config = vllm_config.quant_config
scheduler_config = vllm_config.scheduler_config
self.compilation_config = vllm_config.compilation_config
self.vllm_config = vllm_config
# Prefix caching and chunked prefill are not supported for this model.
assert not cache_config.enable_prefix_caching, \
"Phi4Flash currently does not support prefix caching"
assert not scheduler_config.chunked_prefill_enabled, \
"Phi4Flash currently does not support chunked prefill"
super().__init__()
self.config = config
self.model_config = vllm_config.model_config
self.scheduler_config = scheduler_config
self.model = SambaYModel(config,
cache_config=cache_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.embedding_bias = None
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logits_as_input=False)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
if self.mamba_cache is None:
num_mamba_layers = self.config.num_hidden_layers \
// 2 // self.config.mb_per_layer + 1
self.mamba_cache = MambaCacheManager(
self.vllm_config,
num_mamba_layers,
*self._get_mamba_cache_shape(),
self.lm_head.weight.dtype,
self.lm_head.weight.dtype,
)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
attn_metadata = get_forward_context().attn_metadata
# input_ids and hidden_states don't map one-to-one in the prefill
# stage due to the YOCO optimization.
hidden_states = self.model(input_ids, positions, attn_metadata,
mamba_cache_params, intermediate_tensors,
inputs_embeds)
return hidden_states
def _get_mamba_cache_shape(
self
) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
mamba_expand = self.config.mamba_expand # 2
mamba_d_conv = self.config.mamba_d_conv # 4
mamba_d_state = self.config.mamba_d_state # 16
conv_state_shape = (
mamba_expand * hidden_size // world_size,
mamba_d_conv - 1,
)
temporal_state_shape = (
mamba_expand * hidden_size // world_size,
mamba_d_state,
)
return conv_state_shape, temporal_state_shape
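# Worked example (editor's note, values assumed): with hidden_size=4096,
# world_size=1, mamba_expand=2, mamba_d_conv=4 and mamba_d_state=16 this returns
# conv_state_shape = (2 * 4096, 4 - 1) = (8192, 3) and
# temporal_state_shape = (2 * 4096, 16) = (8192, 16) per Mamba layer.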
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
processed_logits = self.logits_processor(
self.lm_head,
hidden_states,
self.embedding_bias,
)
return processed_logits
def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],
):
weights = {name: weight for name, weight in weights}
adjusted_weights = {}
for name, weight in weights.items():
if "A_log" in name:
name = name.replace("A_log", "A")
weight = -torch.exp(weight.float())
if "inner_cross_attn." in name:
name = name.replace("inner_cross_attn.", "")
adjusted_weights[name] = weight
adjusted_weights["lm_head.weight"] = weights[
"model.embed_tokens.weight"]
loaded_params: set[str] = set()
for name, param in self.named_parameters():
weight = adjusted_weights.get(name)
if weight is not None and weight.shape != param.shape:
logger.warning("Shape mismatch: %s %s %s", name, weight.shape,
param.shape)
loaded_params.add(name)
missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
strict=False)
assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
return loaded_params

View File

@ -12,7 +12,6 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
@ -29,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata, update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@ -47,15 +44,13 @@ from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader) composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsPP) SupportsPP)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.models.utils import ( from vllm.model_executor.models.utils import (
is_pp_missing_parameter, make_empty_intermediate_tensors_factory, is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
make_layers, maybe_prefix) make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@ -194,16 +189,12 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self.chunk_size = self.config.mamba_chunk_size self.chunk_size = self.config.mamba_chunk_size
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context: if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}") raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path # The tuple is (conv_state, ssm_state)
# only runs for v1, we have to do this to unify with the interface self.kv_cache = (torch.tensor([]), torch.tensor([]))
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
assert self.chunk_size != -1, "chunk_size must be set for v1" assert self.chunk_size != -1, "chunk_size must be set for v1"
self.prefix = prefix self.prefix = prefix
@ -227,8 +218,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
pass pass
@ -237,14 +226,8 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
if not envs.VLLM_USE_V1:
CustomOp.forward(self, hidden_states, output, mamba_cache_params,
mamba2_metadata)
else:
torch.ops.vllm.plamo2_mamba_mixer( torch.ops.vllm.plamo2_mamba_mixer(
hidden_states, hidden_states,
output, output,
@ -255,41 +238,31 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
forward_context = get_forward_context() forward_context = get_forward_context()
# mamba2_metadata contains metadata necessary for the mamba2 triton # attn_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill # kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they # modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration # stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = forward_context.attn_metadata attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None: if attn_metadata is not None:
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
mamba2_metadata = attn_metadata
assert isinstance(attn_metadata, Mamba2AttentionMetadata) assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# conv_state = (..., dim, width-1) yet contiguous along 'dim' # conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor = attn_metadata.state_indices_tensor
else: has_initial_states_p = attn_metadata.has_initial_states_p
conv_state = mamba_cache_params.conv_state prep_initial_states = attn_metadata.prep_initial_states
ssm_state = mamba_cache_params.ssm_state chunk_size = attn_metadata.chunk_size
state_indices_tensor = mamba_cache_params.state_indices_tensor seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
# Common members between V1 metadata and V0 metadata chunk_offsets_p = attn_metadata.chunk_offsets_p
if mamba2_metadata is not None:
has_initial_states_p = mamba2_metadata.has_initial_states_p
prep_initial_states = mamba2_metadata.prep_initial_states
chunk_size = mamba2_metadata.chunk_size
seq_idx_p = mamba2_metadata.seq_idx_p
chunk_indices_p = mamba2_metadata.chunk_indices_p
chunk_offsets_p = mamba2_metadata.chunk_offsets_p
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states) projected_states = self.in_proj(hidden_states)
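For illustration only, a minimal sketch of the lookup pattern above (hypothetical names, not the real vLLM ForwardContext API): the per-layer Mamba/attention metadata is computed once per forward pass, stored in a dict keyed by the layer prefix, and each mixer pulls out its own entry.

from dataclasses import dataclass

@dataclass
class _TinyMamba2Meta:
    chunk_size: int
    state_indices: list[int]

# computed once at the top-level model forward
_forward_metadata = {"model.layers.0.mixer": _TinyMamba2Meta(256, [0, 1])}

def _lookup(prefix: str) -> _TinyMamba2Meta:
    # mirrors: attn_metadata = attn_metadata[self.prefix]
    return _forward_metadata[prefix]

print(_lookup("model.layers.0.mixer").chunk_size)  # 256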
@ -299,8 +272,8 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2)) self.conv1d.weight.size(2))
if envs.VLLM_USE_V1 and attn_metadata is None: if attn_metadata is None:
# V1 profile run # profile run
hidden_states = (hidden_states.transpose(0, 1).clone().transpose( hidden_states = (hidden_states.transpose(0, 1).clone().transpose(
0, 1)).contiguous() 0, 1)).contiguous()
output[:] = self.out_proj(hidden_states) output[:] = self.out_proj(hidden_states)
@ -316,7 +289,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
# NOTE: V0 put prefill before decode, v1 puts decode before prefill # NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input # Separate prefill and decode by splitting varlen input
# Split along token dimension # Split along token dimension
if envs.VLLM_USE_V1:
hidden_states_d, hidden_states_p = torch.split( hidden_states_d, hidden_states_p = torch.split(
hidden_states[:num_actual_tokens], hidden_states[:num_actual_tokens],
[num_decodes, num_prefill_tokens], [num_decodes, num_prefill_tokens],
@ -334,24 +306,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
query_start_loc_p = ( query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] - attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None) num_decodes if has_prefill else None)
else:
hidden_states_p, hidden_states_d = torch.split(
hidden_states,
[num_prefill_tokens, num_decodes],
dim=0,
)
gate_p, gate_d = torch.split(gate,
[num_prefill_tokens, num_decodes],
dim=0)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
1]
if has_prefill else None)
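# (Editor's sketch, not part of the diff) V1 places decode tokens before prefill
# tokens, so the varlen split uses [num_decodes, num_prefill_tokens]; e.g. with
# 2 decode and 5 prefill tokens, torch.split(torch.arange(7), [2, 5], dim=0)
# yields tensor([0, 1]) for decode and tensor([2, 3, 4, 5, 6]) for prefill.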
# Preallocate output tensor to avoid memcpy cost for merging prefill # Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs # and decode outputs
@ -363,18 +317,11 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
device=hidden_states.device, device=hidden_states.device,
) )
if envs.VLLM_USE_V1:
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
preallocated_ssm_out, preallocated_ssm_out,
[num_decodes, num_prefill_tokens], [num_decodes, num_prefill_tokens],
dim=0, dim=0,
) )
else:
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Process prefill requests # Process prefill requests
if has_prefill: if has_prefill:
@ -383,9 +330,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
# pointed to by "state_indices_tensor" # pointed to by "state_indices_tensor"
x = hidden_states_p.transpose( x = hidden_states_p.transpose(
0, 1) # this is the form that causal-conv see 0, 1) # this is the form that causal-conv see
if mamba2_metadata.cu_seqlen is None:
mamba2_metadata = update_metadata(x, query_start_loc_p,
mamba2_metadata)
hidden_states_p = causal_conv1d_fn( hidden_states_p = causal_conv1d_fn(
x, x,
conv_weights, conv_weights,
@ -394,7 +338,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
conv_states=conv_state, conv_states=conv_state,
has_initial_state=has_initial_states_p, has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p, cache_indices=state_indices_tensor_p,
metadata=mamba2_metadata, metadata=attn_metadata,
query_start_loc=query_start_loc_p) query_start_loc=query_start_loc_p)
hidden_states_p = hidden_states_p.transpose(0, 1) hidden_states_p = hidden_states_p.transpose(0, 1)
hidden_states_p = hidden_states_p[:num_prefill_tokens] hidden_states_p = hidden_states_p[:num_prefill_tokens]
@ -470,7 +414,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
-1, self.num_heads // self.tp_size, self.head_dim) -1, self.num_heads // self.tp_size, self.head_dim)
# - the hidden is reshaped into (bs, num_heads, head_dim) # - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected # - ssm_state's slots will be selected
# using state_indices_tensor_d # using state_indices_tensor_d
# NOTE: final output is an in-place update of out tensor # NOTE: final output is an in-place update of out tensor
@ -530,10 +474,7 @@ def plamo2_mamba_mixer(
) -> None: ) -> None:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states, self.forward_cuda(hidden_states=hidden_states, output=output)
output=output,
mamba_cache_params=None,
mamba2_metadata=None)
def plamo2_mamba_mixer_fake( def plamo2_mamba_mixer_fake(
@ -731,8 +672,6 @@ class Plamo2DecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
@ -747,8 +686,6 @@ class Plamo2DecoderLayer(nn.Module):
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
mixer_kwargs = { mixer_kwargs = {
"output": output, "output": output,
"mamba_cache_params": mamba_cache_params,
"mamba2_metadata": mamba2_metadata,
} }
else: else:
mixer_kwargs = { mixer_kwargs = {
@ -790,23 +727,12 @@ class Plamo2Decoder(torch.nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor: ) -> torch.Tensor:
mamba_cache_index = 0
for layer in islice(self.layers, self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer_mamba_cache_params = None
if layer.is_mamba and mamba_cache_params is not None:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
mamba_cache_index)
mamba_cache_index += 1
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
return hidden_states, residual return hidden_states, residual
@ -844,7 +770,6 @@ class Plamo2Model(torch.nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
@ -859,23 +784,10 @@ class Plamo2Model(torch.nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
if not envs.VLLM_USE_V1:
attn_metadata: AttentionMetadata = get_forward_context(
).attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
hidden_states, residual = self.layers( hidden_states, residual = self.layers(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
@ -925,9 +837,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.config.vocab_size) self.config.vocab_size)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
@ -942,39 +851,11 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
mamba_state_shape = self.get_mamba_state_shape_from_config( hidden_states = self.model(input_ids, positions, intermediate_tensors,
self.vllm_config, use_v1=False) inputs_embeds)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
else:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
@classmethod @classmethod
def get_mamba_state_dtype_from_config( def get_mamba_state_dtype_from_config(
cls, cls,
@ -991,12 +872,10 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
- conv_state_shape: Shape for convolutional state cache - conv_state_shape: Shape for convolutional state cache
@ -1015,7 +894,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
head_dim=hf_config.hidden_size_per_head, head_dim=hf_config.hidden_size_per_head,
state_size=hf_config.mamba_d_state, state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
) )
def compute_logits( def compute_logits(

View File

@ -11,7 +11,6 @@ from einops import rearrange
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from vllm import envs
from vllm.attention import Attention, AttentionBackend, AttentionMetadata from vllm.attention import Attention, AttentionBackend, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
@ -35,7 +34,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
from vllm.model_executor.layers.mamba.mamba_mixer2 import ( from vllm.model_executor.layers.mamba.mamba_mixer2 import (
mamba_v2_sharded_weight_loader) mamba_v2_sharded_weight_loader)
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
@ -51,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader) default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -198,14 +195,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.gated_delta_net_state_shape( return MambaStateShapeCalculator.gated_delta_net_state_shape(
self.tp_size, self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim,
self.num_k_heads, self.head_v_dim, self.conv_kernel_size, self.num_spec)
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
self.conv_kernel_size,
self.num_spec,
use_v1=True)
def __init__( def __init__(
self, self,
@ -394,7 +385,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
cache_params: Optional[MambaCacheParams] = None,
): ):
return torch.ops.vllm.gdn_attention( return torch.ops.vllm.gdn_attention(
hidden_states, hidden_states,
@ -416,7 +406,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
conv_metadata = attn_metadata
assert isinstance(attn_metadata, GDNAttentionMetadata) assert isinstance(attn_metadata, GDNAttentionMetadata)
has_initial_state = attn_metadata.has_initial_state has_initial_state = attn_metadata.has_initial_state
spec_query_start_loc = attn_metadata.spec_query_start_loc spec_query_start_loc = attn_metadata.spec_query_start_loc
@ -479,12 +468,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
# 2.2: process the remaining part # 2.2: process the remaining part
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
if conv_metadata.cu_seqlen is None:
conv_metadata = update_metadata(mixed_qkv_non_spec_T,
non_spec_query_start_loc,
conv_metadata)
# - "cache_indices" updates the conv_state cache in positions # - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor" # pointed to by "state_indices_tensor"
mixed_qkv_non_spec = causal_conv1d_fn( mixed_qkv_non_spec = causal_conv1d_fn(
mixed_qkv_non_spec_T, mixed_qkv_non_spec_T,
conv_weights, conv_weights,
@ -494,7 +479,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
has_initial_state=has_initial_state, has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor, cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc, query_start_loc=non_spec_query_start_loc,
metadata=conv_metadata, metadata=attn_metadata,
).transpose(0, 1) ).transpose(0, 1)
elif attn_metadata.num_decodes > 0: elif attn_metadata.num_decodes > 0:
mixed_qkv_non_spec = causal_conv1d_update( mixed_qkv_non_spec = causal_conv1d_update(
@ -1075,7 +1060,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \ assert not cache_config.enable_prefix_caching, \
"Qwen3Next currently does not support prefix caching" "Qwen3Next currently does not support prefix caching"
assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1"
self.quant_config = vllm_config.quant_config self.quant_config = vllm_config.quant_config
super().__init__() super().__init__()
@ -1195,14 +1179,10 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
num_spec = (vllm_config.speculative_config.num_speculative_tokens num_spec = (vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config else 0) if vllm_config.speculative_config else 0)
return MambaStateShapeCalculator.gated_delta_net_state_shape( return MambaStateShapeCalculator.gated_delta_net_state_shape(
tp_size, tp_size, hf_config.linear_num_key_heads,
hf_config.linear_num_key_heads, hf_config.linear_num_value_heads, hf_config.linear_key_head_dim,
hf_config.linear_num_value_heads, hf_config.linear_value_head_dim, hf_config.linear_conv_kernel_dim,
hf_config.linear_key_head_dim, num_spec)
hf_config.linear_value_head_dim,
hf_config.linear_conv_kernel_dim,
num_spec,
use_v1=True)
def compute_logits( def compute_logits(
self, self,

View File

@ -134,7 +134,6 @@ _TEXT_GENERATION_MODELS = {
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),

View File

@ -15,12 +15,10 @@ import torch
from torch import nn from torch import nn
from transformers import Zamba2Config from transformers import Zamba2Config
from vllm import envs
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -29,8 +27,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
@ -39,8 +35,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid from .interfaces import HasInnerState, IsHybrid
@ -515,8 +509,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
transformer_hidden_states: Optional[torch.Tensor] = None, transformer_hidden_states: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None,
original_hidden_states: Optional[torch.Tensor] = None, original_hidden_states: Optional[torch.Tensor] = None,
@ -525,8 +517,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
Args: Args:
hidden_states: Input tensor [batch_size, seq_len, hidden_size] hidden_states: Input tensor [batch_size, seq_len, hidden_size]
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
transformer_hidden_states: Optional output from transformer path transformer_hidden_states: Optional output from transformer path
Added to input if provided (used in hybrid architecture) Added to input if provided (used in hybrid architecture)
positions: Optional position IDs (unused in Mamba) positions: Optional position IDs (unused in Mamba)
@ -555,8 +545,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
self.mamba( self.mamba(
hidden_states, hidden_states,
output, output,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
# residual connection after mamba # residual connection after mamba
@ -607,8 +595,6 @@ class Zamba2HybridLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
original_hidden_states: torch.Tensor, original_hidden_states: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass through the hybrid layer. """Forward pass through the hybrid layer.
@ -623,8 +609,6 @@ class Zamba2HybridLayer(nn.Module):
original_hidden_states: Original input for transformer residual original_hidden_states: Original input for transformer residual
connection connection
positions: Position IDs for positional embeddings positions: Position IDs for positional embeddings
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
Returns: Returns:
Output tensor combining transformer and Mamba representations Output tensor combining transformer and Mamba representations
@ -644,8 +628,6 @@ class Zamba2HybridLayer(nn.Module):
layer_outputs = self.mamba_decoder( layer_outputs = self.mamba_decoder(
hidden_states, hidden_states,
transformer_hidden_states=transformer_hidden_states, transformer_hidden_states=transformer_hidden_states,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
return layer_outputs return layer_outputs
@ -752,7 +734,6 @@ class Zamba2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Forward pass through the model. """Forward pass through the model.
@ -760,8 +741,6 @@ class Zamba2Model(nn.Module):
Args: Args:
input_ids: Input token IDs input_ids: Input token IDs
positions: Position IDs for embeddings positions: Position IDs for embeddings
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
inputs_embeds: Optional pre-computed input embeddings inputs_embeds: Optional pre-computed input embeddings
Returns: Returns:
@ -773,33 +752,13 @@ class Zamba2Model(nn.Module):
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.get_input_embeddings(input_ids)
hidden_states = inputs_embeds hidden_states = inputs_embeds
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
# Process through layers # Process through layers
original_hidden_states = torch.clone(hidden_states) original_hidden_states = torch.clone(hidden_states)
for layer_idx, layer in enumerate(self.layers): for layer_idx, layer in enumerate(self.layers):
layer_mamba_cache_params = None
if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer))
and mamba_cache_params):
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
layer_idx)
layer_outputs = layer( layer_outputs = layer(
hidden_states, hidden_states,
original_hidden_states=original_hidden_states, original_hidden_states=original_hidden_states,
positions=positions, positions=positions,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
hidden_states = layer_outputs hidden_states = layer_outputs
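After this hunk the layer loop calls every decoder layer with a uniform signature; whatever state a Mamba layer needs is resolved inside the layer itself (via vLLM V1's forward context) instead of being threaded through as mamba_cache_params / mamba2_metadata arguments. A minimal, self-contained sketch of that pattern (illustrative only, not vLLM code):

    import torch
    from torch import nn

    class TinyLayerStack(nn.Module):
        """Minimal sketch (not vLLM code): every layer is invoked uniformly,
        with no per-layer cache params or metadata passed in."""

        def __init__(self, num_layers: int = 4, hidden_size: int = 8) -> None:
            super().__init__()
            self.layers = nn.ModuleList(
                nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            original_hidden_states = torch.clone(hidden_states)
            for layer in self.layers:
                # Any per-layer state lookup happens inside the layer
                # (vLLM V1 does this through the forward context).
                hidden_states = layer(hidden_states) + original_hidden_states
            return hidden_states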
@ -870,13 +829,11 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
@ -896,7 +853,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
head_dim=hf_config.mamba_headdim, head_dim=hf_config.mamba_headdim,
state_size=hf_config.mamba_d_state, state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
) )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
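For orientation, a rough sketch of how the two shapes in the get_mamba_state_shape_from_config return annotation are commonly derived for a Mamba2 layer. This is an approximation under simplifying assumptions (even tensor-parallel divisibility, one particular axis order for the conv state); the actual layout comes from MambaStateShapeCalculator, not from this snippet:

    def approx_mamba2_state_shapes(intermediate_size: int, n_groups: int,
                                   state_size: int, conv_kernel: int,
                                   num_heads: int, head_dim: int,
                                   tp_size: int = 1):
        """Rough sketch only; axis order and TP handling may differ in vLLM."""
        # Conv cache keeps the last (conv_kernel - 1) inputs per conv channel.
        conv_dim = intermediate_size + 2 * n_groups * state_size
        conv_state_shape = (conv_dim // tp_size, conv_kernel - 1)
        # SSM cache keeps one (head_dim x state_size) state per head.
        temporal_state_shape = (num_heads // tp_size, head_dim, state_size)
        return conv_state_shape, temporal_state_shape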
@ -945,9 +901,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
# Tie weights with input embeddings if using same dimensions # Tie weights with input embeddings if using same dimensions
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
# Initialize logits processing and sampling # Initialize logits processing and sampling
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
@ -977,61 +930,15 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
Returns: Returns:
Output hidden states Output hidden states
""" """
# Initialize Mamba cache if needed
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = self.config.num_hidden_layers
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
# Get cache parameters for current run
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
# Forward pass through model # Forward pass through model
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
positions, positions,
mamba_cache_params,
inputs_embeds, inputs_embeds,
) )
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(
self, input_buffers: dict[str, torch.Tensor],
**kwargs: Any) -> dict[str, torch.Tensor]:
"""Copy inputs before CUDA graph capture.
Args:
input_buffers: Dictionary of input tensors
**kwargs: Additional arguments passed to cache manager
Returns:
Updated input buffers
"""
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(
self, batch_size: int) -> dict[str, torch.Tensor]:
"""Get inputs for sequence-length-agnostic graph capture.
Args:
batch_size: Size of batch to capture
Returns:
Dictionary of capture inputs
"""
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,

View File

@ -12,6 +12,7 @@ from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionCGSupport, from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills) split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@ -52,7 +53,6 @@ class GDNAttentionMetadata:
# The following attributes are for triton implementation of causal_conv1d # The following attributes are for triton implementation of causal_conv1d
nums_dict: Optional[dict] = None nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None token_chunk_offset_ptr: Optional[torch.Tensor] = None
@ -134,6 +134,7 @@ class GDNAttentionMetadataBuilder(
context_lens = m.num_computed_tokens_cpu context_lens = m.num_computed_tokens_cpu
context_lens_tensor = context_lens.to(query_start_loc.device) context_lens_tensor = context_lens.to(query_start_loc.device)
seq_lens_tensor = m.seq_lens seq_lens_tensor = m.seq_lens
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if (not self.use_spec_decode or num_draft_tokens is None if (not self.use_spec_decode or num_draft_tokens is None
or num_draft_tokens.sum().item() == 0): or num_draft_tokens.sum().item() == 0):
@ -210,6 +211,8 @@ class GDNAttentionMetadataBuilder(
has_initial_state = context_lens_tensor > 0 has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None: if spec_sequence_masks is not None:
has_initial_state = has_initial_state[~spec_sequence_masks] has_initial_state = has_initial_state[~spec_sequence_masks]
nums_dict, batch_ptr, token_chunk_offset_ptr = \
compute_causal_conv1d_metadata(non_spec_query_start_loc)
else: else:
has_initial_state = None has_initial_state = None
num_actual_tokens = num_prefill_tokens + num_decode_tokens + \ num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
@ -297,6 +300,9 @@ class GDNAttentionMetadataBuilder(
spec_sequence_masks=spec_sequence_masks, spec_sequence_masks=spec_sequence_masks,
spec_token_masks=spec_token_masks, spec_token_masks=spec_token_masks,
num_accepted_tokens=num_accepted_tokens, num_accepted_tokens=num_accepted_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
) )
return attn_metadata return attn_metadata
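A toy illustration of the boolean filtering the GDN builder applies above when speculative decoding is active: rows belonging to speculative sequences are dropped before the causal_conv1d metadata is computed. The values below are made up:

    import torch

    has_initial_state = torch.tensor([True, False, True, True])
    spec_sequence_masks = torch.tensor([False, True, False, True])

    # Keep only the non-speculative requests.
    non_spec = has_initial_state[~spec_sequence_masks]
    print(non_spec)  # tensor([True, True])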

View File

@ -7,11 +7,12 @@ from typing import Optional
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backends.mamba_attn import ( from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadataBuilder) BaseMambaAttentionMetadataBuilder)
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills) split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
@ -131,7 +132,6 @@ class Mamba2AttentionMetadata:
# The following attributes are for triton implementation of causal_conv1d # The following attributes are for triton implementation of causal_conv1d
nums_dict: Optional[dict] = None nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None token_chunk_offset_ptr: Optional[torch.Tensor] = None
@ -161,6 +161,9 @@ class Mamba2AttentionMetadataBuilder(
has_initial_states_p = None has_initial_states_p = None
prep_initial_states = False prep_initial_states = False
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
@ -198,6 +201,9 @@ class Mamba2AttentionMetadataBuilder(
query_start_loc_p, self.chunk_size, query_start_loc_p, self.chunk_size,
num_prefill_tokens)) num_prefill_tokens))
nums_dict, batch_ptr, token_chunk_offset_ptr = \
compute_causal_conv1d_metadata(query_start_loc_p)
elif num_decodes <= self.decode_cudagraph_max_bs: elif num_decodes <= self.decode_cudagraph_max_bs:
# Pad state tensor for CUDA graph # Pad state tensor for CUDA graph
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
@ -220,5 +226,8 @@ class Mamba2AttentionMetadataBuilder(
chunk_indices_p=chunk_indices_p, chunk_indices_p=chunk_indices_p,
chunk_offsets_p=chunk_offsets_p, chunk_offsets_p=chunk_offsets_p,
state_indices_tensor=state_indices_tensor, state_indices_tensor=state_indices_tensor,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
) )
return attn_metadata return attn_metadata

View File

@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills) split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@ -33,7 +34,6 @@ class ShortConvAttentionMetadata:
# For causal_conv1d # For causal_conv1d
nums_dict: Optional[dict] = None nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None token_chunk_offset_ptr: Optional[torch.Tensor] = None
@ -57,6 +57,9 @@ class ShortConvAttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills( split_decodes_and_prefills(
common_attn_metadata, common_attn_metadata,
@ -70,6 +73,12 @@ class ShortConvAttentionMetadataBuilder(
has_initial_states = has_initial_states_cpu.to( has_initial_states = has_initial_states_cpu.to(
query_start_loc.device) query_start_loc.device)
query_start_loc_p = common_attn_metadata.query_start_loc[
-num_prefills - 1:] - num_decode_tokens
nums_dict, batch_ptr, token_chunk_offset_ptr = \
compute_causal_conv1d_metadata(query_start_loc_p)
attn_metadata = ShortConvAttentionMetadata( attn_metadata = ShortConvAttentionMetadata(
num_prefills=num_prefills, num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
@ -78,5 +87,8 @@ class ShortConvAttentionMetadataBuilder(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
has_initial_states=has_initial_states, has_initial_states=has_initial_states,
state_indices_tensor=state_indices_tensor, state_indices_tensor=state_indices_tensor,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
) )
return attn_metadata return attn_metadata
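To make the prefill-only offsets concrete, here is a toy walk-through of the slicing used above; the numbers are illustrative only:

    import torch

    # Two decode requests (1 token each) followed by two prefills of 4 and 5 tokens.
    query_start_loc = torch.tensor([0, 1, 2, 6, 11])
    num_prefills, num_decode_tokens = 2, 2

    query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decode_tokens
    print(query_start_loc_p)         # tensor([0, 4, 9])
    print(query_start_loc_p.diff())  # tensor([4, 5]) -> per-prefill lengths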

View File

@ -34,6 +34,8 @@ logger = init_logger(__name__)
KVCacheLayoutType = Literal["NHD", "HND"] KVCacheLayoutType = Literal["NHD", "HND"]
_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None _KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None
PAD_SLOT_ID = -1
def is_valid_kv_cache_layout(value: str) -> bool: def is_valid_kv_cache_layout(value: str) -> bool:
return value in get_args(KVCacheLayoutType) return value in get_args(KVCacheLayoutType)
@ -838,3 +840,52 @@ def create_fast_prefill_custom_backend(
builder_cls=FastPrefillAttentionBuilder) builder_cls=FastPrefillAttentionBuilder)
return attn_backend return attn_backend
def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
# Needed for causal_conv1d
seqlens = query_start_loc_p.diff().to('cpu')
nums_dict = {} # type: ignore
batch_ptr = None
token_chunk_offset_ptr = None
for BLOCK_M in [8]: # enumerate every BLOCK_M tile size the causal_conv1d kernel uses (currently just 8)
nums = -(-seqlens // BLOCK_M)
nums_dict[BLOCK_M] = {}
nums_dict[BLOCK_M]['nums'] = nums
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
nums_dict[BLOCK_M]['mlist'] = mlist
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
offsetlist = [] # type: ignore
for idx, num in enumerate(nums):
offsetlist.extend(range(num))
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
if batch_ptr is None:
# Allocate the padded pointer buffers lazily on the first iteration
batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
else:
if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
token_chunk_offset_ptr.resize_( # type: ignore
MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
batch_ptr[0:mlist_len].copy_(mlist)
token_chunk_offset_ptr[ # type: ignore
0:mlist_len].copy_(offsetlist)
nums_dict[BLOCK_M]['batch_ptr'] = batch_ptr
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (token_chunk_offset_ptr
) # type: ignore
return nums_dict, batch_ptr, token_chunk_offset_ptr
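A CPU-only walk-through of the bookkeeping above for a toy prefill batch. The real helper additionally pads batch_ptr / token_chunk_offset_ptr to MAX_NUM_PROGRAMS with PAD_SLOT_ID and keeps them on the GPU; this sketch only reproduces the per-BLOCK_M decomposition:

    import numpy as np
    import torch

    # Toy prefill batch: three sequences of 5, 20 and 3 tokens.
    query_start_loc_p = torch.tensor([0, 5, 25, 28], dtype=torch.int32)
    seqlens = query_start_loc_p.diff()           # tensor([ 5, 20,  3])

    BLOCK_M = 8
    nums = -(-seqlens // BLOCK_M)                # ceil-div: tensor([1, 3, 1])
    mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums.numpy()))
    # mlist: tensor([0, 1, 1, 1, 2]) -> which prefill each program works on
    offsets = torch.tensor([off for n in nums.tolist() for off in range(n)],
                           dtype=torch.int32)
    # offsets: tensor([0, 0, 1, 2, 0]) -> which BLOCK_M chunk within that prefill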