diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 4b4cebb6a31c2..7f54d98527686 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -395,7 +395,7 @@ th { | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | | | `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ | -| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | | +| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | ✅︎ | | `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index f71805436a6ae..525f740d12a7f 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -110,7 +110,7 @@ Models using selective state-space mechanisms instead of standard transformer at Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`,`FalconMambaForCausalLM`) are supported. Hybrid models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, -`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). +`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`, `Plamo2ForCausalLM`). Hybrid models with mechanisms different to Mamba are also supported (e.g, `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`, `Lfm2ForCausalLM`). 
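For context on what the documentation updates above enable: after this change PLaMo2 is expected to run on the V1 engine like the other hybrid Mamba models. A minimal usage sketch, not part of the patch, assuming the `pfnet/plamo-2-1b` checkpoint and `trust_remote_code=True` as used in the test registry below; the prompt and sampling settings are only illustrative.

```python
# Illustrative only (not part of the patch): exercise PLaMo2 on the V1 engine
# that this change documents. Assumes the `pfnet/plamo-2-1b` checkpoint and
# trust_remote_code=True, matching tests/models/registry.py below.
# VLLM_USE_V1=1 can be set to force the V1 engine on builds where it is not
# already the default.
from vllm import LLM, SamplingParams

llm = LLM(model="pfnet/plamo-2-1b", trust_remote_code=True, enforce_eager=True)
sampling = SamplingParams(temperature=0.0, max_tokens=32)
out = llm.generate(["The capital of Japan is"], sampling)
print(out[0].outputs[0].text)
```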
diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 9e97e3fa65775..b44ddc61b6c8c 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -25,8 +25,7 @@ SSM_MODELS = [ HYBRID_MODELS = [ "ai21labs/Jamba-tiny-dev", - # skipping until vLLM implementation issues are resolved - # "pfnet/plamo-2-1b", + "pfnet/plamo-2-1b", "Zyphra/Zamba2-1.2B-instruct", "hmellor/tiny-random-BambaForCausalLM", "ibm-granite/granite-4.0-tiny-preview", @@ -37,6 +36,7 @@ HYBRID_MODELS = [ V1_SUPPORTED_MODELS = [ "state-spaces/mamba-130m-hf", "ai21labs/Jamba-tiny-dev", + "pfnet/plamo-2-1b", "yujiepan/mamba2-codestral-v0.1-tiny-random", "Zyphra/Zamba2-1.2B-instruct", "hmellor/tiny-random-BambaForCausalLM", @@ -47,6 +47,7 @@ V1_SUPPORTED_MODELS = [ FULL_CUDA_GRAPH_MODELS = [ "ai21labs/Jamba-tiny-dev", + "pfnet/plamo-2-1b", "Zyphra/Zamba2-1.2B-instruct", ] diff --git a/tests/models/registry.py b/tests/models/registry.py index 4cf3dd6e08ced..f1f61c6151349 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -287,8 +287,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", - max_transformers_version="4.53", - transformers_version_reason="vLLM impl inherits PreTrainedModel and clashes with get_input_embeddings", # noqa: E501 trust_remote_code=True), "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", max_transformers_version="4.53", diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 28ad3d2f535d3..677fb069bc07a 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -340,6 +340,7 @@ class CompilationConfig: "vllm.mamba_mixer", "vllm.short_conv", "vllm.linear_attention", + "vllm.plamo2_mamba_mixer", ] def compute_hash(self) -> str: diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 7f70e44b10a6d..b9869f5e58800 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -3,19 +3,24 @@ """Inference-only PLaMo2 model.""" from collections.abc import Iterable from itertools import islice -from typing import Optional +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn -from transformers import PretrainedConfig, PreTrainedModel +from transformers import PretrainedConfig +from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -23,8 +28,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import 
LogitsProcessor +from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) + Mamba2Metadata, prepare_mamba2_metadata, update_metadata) +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( @@ -39,7 +47,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, - SupportsPP, SupportsV0Only) + SupportsPP) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.models.utils import ( @@ -47,8 +55,10 @@ from vllm.model_executor.models.utils import ( make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType +from vllm.utils import LayerBlockType, direct_register_custom_op +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata # Only used for type hinting. @@ -73,20 +83,6 @@ class Plamo2Config(PretrainedConfig): # type: ignore vocab_size: int -class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore - - def _init_weights(self, module: torch.nn.Module) -> None: - std = 0.02 - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def is_mamba(config: Plamo2Config, i: int) -> bool: assert config.mamba_step > 1 @@ -99,7 +95,8 @@ def is_mamba(config: Plamo2Config, i: int) -> bool: # Adapted from: # vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2 # transformers.models.mamba.modeling_mamba.MambaMixer -class Plamo2MambaMixer(nn.Module): +@CustomOp.register(name="plamo2_mamba_mixer") +class Plamo2MambaMixer(MambaBase, CustomOp): def __init__(self, vllm_config: VllmConfig, @@ -108,6 +105,8 @@ class Plamo2MambaMixer(nn.Module): **kwargs) -> None: super().__init__() self.config = vllm_config.model_config.hf_config + self.cache_config = vllm_config.cache_config + self.model_config = vllm_config.model_config self.quant_config = vllm_config.quant_config self.hidden_size = self.config.hidden_size self.ssm_state_size = self.config.mamba_d_state @@ -115,8 +114,6 @@ class Plamo2MambaMixer(nn.Module): self.intermediate_size = (self.config.mamba_num_heads * self.config.hidden_size_per_head) self.tp_size = get_tensor_model_parallel_world_size() - self.intermediate_size_per_tp_worker = \ - self.intermediate_size // self.tp_size self.head_dim = self.config.hidden_size_per_head self.num_heads = self.config.mamba_num_heads self.time_step_rank = max(64, self.hidden_size // 16) @@ -197,6 +194,22 @@ class Plamo2MambaMixer(nn.Module): self.C_norm = RMSNorm(self.ssm_state_size, eps=self.config.rms_norm_eps) + self.chunk_size = self.config.mamba_chunk_size + + if envs.VLLM_USE_V1: + compilation_config = 
get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. + # The inner tuple is (conv_state, ssm_state) + self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + assert self.chunk_size != -1, "chunk_size must be set for v1" + + self.prefix = prefix + def _project_ssm_parameters(self, hidden_states): ssm_parameters = self.bcdt_proj(hidden_states) B, C, time_step = torch.split( @@ -212,25 +225,76 @@ class Plamo2MambaMixer(nn.Module): dt = self.dt_proj(time_step) return B, C, dt - def forward( + def forward_native( self, hidden_states: torch.Tensor, + output: torch.Tensor, mamba_cache_params: MambaCacheParams, mamba2_metadata: Mamba2Metadata, **kwargs, - ) -> torch.Tensor: + ): + pass + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + if not envs.VLLM_USE_V1: + CustomOp.forward(self, hidden_states, output, mamba_cache_params, + mamba2_metadata) + else: + torch.ops.vllm.plamo2_mamba_mixer( + hidden_states, + output, + self.prefix, + ) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + + forward_context = get_forward_context() # mamba2_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - - num_prefills = attn_metadata.num_prefills # request count - num_decodes = attn_metadata.num_decode_tokens # token count (=request) - num_prefill_tokens = attn_metadata.num_prefill_tokens # token count - has_prefill = num_prefills > 0 - has_decode = num_decodes > 0 + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if envs.VLLM_USE_V1: + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + mamba2_metadata = attn_metadata + assert isinstance(attn_metadata, Mamba2AttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + # conv_state = (..., dim, width-1) yet contiguous along 'dim' + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states_p + prep_initial_states = attn_metadata.prep_initial_states + chunk_size = attn_metadata.chunk_size + seq_idx_p = attn_metadata.seq_idx_p + chunk_indices_p = attn_metadata.chunk_indices_p + chunk_offsets_p = attn_metadata.chunk_offsets_p + else: + conv_state = mamba_cache_params.conv_state + ssm_state = mamba_cache_params.ssm_state + state_indices_tensor = mamba_cache_params.state_indices_tensor + has_initial_states_p = mamba2_metadata.has_initial_states + prep_initial_states = mamba2_metadata.prep_initial_states + chunk_size = mamba2_metadata.chunk_size + seq_idx_p = mamba2_metadata.seq_idx + chunk_indices_p = mamba2_metadata.chunk_indices + chunk_offsets_p = mamba2_metadata.chunk_offsets # 
1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states) @@ -240,23 +304,59 @@ class Plamo2MambaMixer(nn.Module): conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if envs.VLLM_USE_V1 and attn_metadata is None: + # V1 profile run + hidden_states = (hidden_states.transpose(0, 1).clone().transpose( + 0, 1)).contiguous() + output[:] = self.out_proj(hidden_states) + return + + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + num_actual_tokens = num_prefill_tokens + num_decodes + + # NOTE: V0 put prefill before decode, v1 puts decode before prefill # Separate prefill and decode by splitting varlen input # Split along token dimension - hidden_states_p, hidden_states_d = torch.split( - hidden_states, - [num_prefill_tokens, num_decodes], - dim=0, - ) - gate_p, gate_d = torch.split(gate, [num_prefill_tokens, num_decodes], - dim=0) - # Split along batch dimension - state_indices_tensor_p, state_indices_tensor_d = torch.split( - mamba_cache_params.state_indices_tensor, - [num_prefills, num_decodes], - dim=0, - ) - query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1] - if has_prefill else None) + if envs.VLLM_USE_V1: + hidden_states_d, hidden_states_p = torch.split( + hidden_states[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + gate_d, gate_p = torch.split(gate[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) + else: + hidden_states_p, hidden_states_d = torch.split( + hidden_states, + [num_prefill_tokens, num_decodes], + dim=0, + ) + gate_p, gate_d = torch.split(gate, + [num_prefill_tokens, num_decodes], + dim=0) + # Split along batch dimension + state_indices_tensor_p, state_indices_tensor_d = torch.split( + state_indices_tensor, + [num_prefills, num_decodes], + dim=0, + ) + query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + + 1] + if has_prefill else None) # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs @@ -268,25 +368,38 @@ class Plamo2MambaMixer(nn.Module): dtype=hidden_states.dtype, device=hidden_states.device, ) - preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( - preallocated_ssm_out, - [num_prefill_tokens, num_decodes], - dim=0, - ) + if envs.VLLM_USE_V1: + preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( + preallocated_ssm_out, + [num_decodes, num_prefill_tokens], + dim=0, + ) + else: + preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( + preallocated_ssm_out, + [num_prefill_tokens, num_decodes], + dim=0, + ) # Process prefill requests if has_prefill: # 2. 
Convolution sequence transformation # - "cache_indices" updates the conv_state cache in positions - # pointed to by "mamba_cache_params.state_indices_tensor" + # pointed to by "state_indices_tensor" + x = hidden_states_p.transpose( + 0, 1) # this is the form that causal-conv see + if mamba2_metadata.cu_seqlen is None: + mamba2_metadata = update_metadata(x, query_start_loc_p, + mamba2_metadata) hidden_states_p = causal_conv1d_fn( - hidden_states_p.transpose(0, 1), + x, conv_weights, self.conv1d.bias, activation=self.activation, - conv_states=mamba_cache_params.conv_state, - has_initial_state=mamba2_metadata.has_initial_states, + conv_states=conv_state, + has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, + metadata=mamba2_metadata, query_start_loc=query_start_loc_p) hidden_states_p = hidden_states_p.transpose(0, 1) hidden_states_p = hidden_states_p[:num_prefill_tokens] @@ -299,12 +412,16 @@ class Plamo2MambaMixer(nn.Module): # 3. State Space Model sequence transformation initial_states = None - if (mamba2_metadata.has_initial_states is not None - and mamba2_metadata.prep_initial_states): + if has_initial_states_p is not None and prep_initial_states: # making a copy of the states - initial_states = torch.where( - mamba2_metadata.has_initial_states[:, None, None, None], - mamba_cache_params.ssm_state[state_indices_tensor_p], 0) + if envs.VLLM_USE_V1: + initial_states = torch.where( + has_initial_states_p[:, None, None, None], + ssm_state[state_indices_tensor_p], 0) + else: + initial_states = torch.where( + has_initial_states_p[:num_prefills, None, None, None], + ssm_state[state_indices_tensor_p], 0) varlen_state = mamba_chunk_scan_combined( hidden_states_p.view(1, num_prefill_tokens, self.num_heads // self.tp_size, @@ -313,15 +430,15 @@ class Plamo2MambaMixer(nn.Module): self.A, B.view(1, num_prefill_tokens, 1, -1), C.view(1, num_prefill_tokens, 1, -1), - chunk_size=mamba2_metadata.chunk_size, + chunk_size=chunk_size, D=self.D, z=gate_p.view(1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim), dt_bias=self.dt_bias, - seq_idx=mamba2_metadata.seq_idx, - chunk_indices=mamba2_metadata.chunk_indices, - chunk_offsets=mamba2_metadata.chunk_offsets, - cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1], + seq_idx=seq_idx_p, + chunk_indices=chunk_indices_p, + chunk_offsets=chunk_offsets_p, + cu_seqlens=query_start_loc_p, initial_states=initial_states, return_varlen_states=True, return_final_states=False, @@ -329,18 +446,19 @@ class Plamo2MambaMixer(nn.Module): dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, self.head_dim), + state_dtype=ssm_state.dtype, ) # update ssm states # - varlen state is a (batch, nheads, headdim, dstate) tensor - mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state + ssm_state[state_indices_tensor_p] = varlen_state # Process decode requests if has_decode: # 2. 
Convolution sequence transformation hidden_states_d = causal_conv1d_update( hidden_states_d, - mamba_cache_params.conv_state, + conv_state, conv_weights, self.conv1d.bias, self.activation, @@ -363,8 +481,10 @@ class Plamo2MambaMixer(nn.Module): # - the hidden is reshaped into (bs, num_heads, head_dim) # - mamba_cache_params.ssm_state's slots will be selected # using state_indices_tensor_d + + # NOTE: final output is an in-place update of out tensor selective_state_update( - mamba_cache_params.ssm_state, + ssm_state, hidden_states_d, dt, A, @@ -378,11 +498,68 @@ class Plamo2MambaMixer(nn.Module): out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), ) - assert self.num_heads % self.tp_size == 0 # 4. Final linear projection - out = self.out_proj(preallocated_ssm_out) - return out + output[:num_actual_tokens] = self.out_proj(preallocated_ssm_out) + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.mamba2_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + self.cache_config.mamba_ssm_cache_dtype, + ) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.mamba2_state_shape( + intermediate_size=self.intermediate_size, + tp_world_size=get_tensor_model_parallel_world_size(), + n_groups=0, + num_heads=self.num_heads, + head_dim=self.head_dim, + state_size=self.ssm_state_size, + conv_kernel=self.conv_kernel_size, + ) + + @property + def mamba_type(self) -> str: + return "mamba2" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.mamba2_attn import ( + Mamba2AttentionBackend) + return Mamba2AttentionBackend + + +def plamo2_mamba_mixer( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.forward_cuda(hidden_states=hidden_states, + output=output, + mamba_cache_params=None, + mamba2_metadata=None) + + +def plamo2_mamba_mixer_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="plamo2_mamba_mixer", + op_func=plamo2_mamba_mixer, + mutates_args=["output"], + fake_impl=plamo2_mamba_mixer_fake, + dispatch_key=current_platform.dispatch_key, +) class DenseMLP(nn.Module): @@ -418,7 +595,6 @@ class DenseMLP(nn.Module): return self.down_proj(h) -@support_torch_compile class Plamo2AttentionMixer(nn.Module): def __init__(self, @@ -575,12 +751,24 @@ class Plamo2DecoderLayer(nn.Module): hidden_states, residual = self.pre_mixer_norm( hidden_states, residual) + if self.is_mamba: + # Plamo2MambaMixer writes output to this tensor + output = torch.empty_like(hidden_states) + mixer_kwargs = { + "output": output, + "mamba_cache_params": mamba_cache_params, + "mamba2_metadata": mamba2_metadata, + } + else: + mixer_kwargs = { + "positions": positions, + } hidden_states = self.mixer( - positions=positions, hidden_states=hidden_states, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, + **mixer_kwargs, ) + if self.is_mamba: + hidden_states = output hidden_states = self.post_mixer_norm(hidden_states) # Fully Connected hidden_states, residual = self.pre_mlp_norm(hidden_states, residual) @@ -591,7 +779,7 @@ class Plamo2DecoderLayer(nn.Module): class Plamo2Decoder(torch.nn.Module): - def __init__(self, 
vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)} @@ -617,7 +805,7 @@ class Plamo2Decoder(torch.nn.Module): mamba_cache_index = 0 for layer in islice(self.layers, self.start_layer, self.end_layer): layer_mamba_cache_params = None - if layer.is_mamba: + if layer.is_mamba and mamba_cache_params is not None: layer_mamba_cache_params = mamba_cache_params.at_layer_idx( mamba_cache_index) mamba_cache_index += 1 @@ -632,10 +820,11 @@ class Plamo2Decoder(torch.nn.Module): return hidden_states, residual -class Plamo2Model(Plamo2PreTrainedModel): +@support_torch_compile +class Plamo2Model(torch.nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config.model_config.hf_config) + super().__init__() config = vllm_config.model_config.hf_config @@ -653,9 +842,9 @@ class Plamo2Model(Plamo2PreTrainedModel): self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers") + self.layers = Plamo2Decoder(vllm_config=vllm_config, + prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_init() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -679,11 +868,16 @@ class Plamo2Model(Plamo2PreTrainedModel): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) + if not envs.VLLM_USE_V1: + attn_metadata: AttentionMetadata = get_forward_context( + ).attn_metadata + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None hidden_states, residual = self.layers( positions=positions, @@ -701,8 +895,7 @@ class Plamo2Model(Plamo2PreTrainedModel): return hidden_states -class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, - IsHybrid, SupportsV0Only): +class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -712,12 +905,10 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, } def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() config = vllm_config.model_config.hf_config scheduler_config = vllm_config.scheduler_config - assert not vllm_config.cache_config.enable_prefix_caching, \ - "PLaMo2 currently does not support prefix caching" - super().__init__(config) self.config = config self.vllm_config = vllm_config self.model_config = vllm_config.model_config @@ -751,8 +942,6 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - # Initialize weights and apply final processing - self.post_init() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -763,19 +952,27 @@ class 
Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_mamba_layers = ( + self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, + LayerBlockType.mamba)) - self.mamba_cache = MambaCacheManager( - self.vllm_config, - num_mamba_layers, - *self._get_mamba_cache_shape(), - self.lm_head.weight.dtype, - self.lm_head.weight.dtype, - ) + mamba_state_shape = self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) + self.mamba_cache = MambaCacheManager(self.vllm_config, + num_mamba_layers, + *mamba_state_shape, + *mamba_state_dtype) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + else: + # NOTE: mamba_cache_params is not needed for v1 + mamba_cache_params = None hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) @@ -788,21 +985,48 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = (self.config.mamba_num_heads * - self.config.hidden_size_per_head) - conv_state_shape = ( - hidden_size // world_size, - self.config.mamba_d_conv - 1, + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, ) - temporal_state_shape = ( - divide(self.config.mamba_num_heads, world_size), - self.config.hidden_size_per_head, - self.config.mamba_d_state, + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + intermediate_size =\ + hf_config.mamba_num_heads * hf_config.hidden_size_per_head + + return MambaStateShapeCalculator.mamba2_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=0, + num_heads=hf_config.mamba_num_heads, + head_dim=hf_config.hidden_size_per_head, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, + use_v1=use_v1, ) - return conv_state_shape, temporal_state_shape def compute_logits( self,
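To make the batch-layout handling in `Plamo2MambaMixer.forward_cuda` easier to follow: V0 packs prefill tokens before decode tokens, while V1 puts decode tokens first and may pad the batch, which is why the V1 branch slices to `num_actual_tokens` and swaps the split order. The standalone sketch below mirrors those `torch.split` calls; the helper `split_prefill_decode` is hypothetical and is not part of the patch.

```python
# Standalone illustration of the layout difference handled in forward_cuda.
# V0: [prefill tokens | decode tokens]; V1: [decode tokens | prefill tokens | padding].
import torch


def split_prefill_decode(hidden_states: torch.Tensor, num_prefill_tokens: int,
                         num_decode_tokens: int, use_v1: bool):
    num_actual = num_prefill_tokens + num_decode_tokens
    if use_v1:
        # V1: drop any trailing padding, then decode tokens come first.
        decode, prefill = torch.split(hidden_states[:num_actual],
                                      [num_decode_tokens, num_prefill_tokens],
                                      dim=0)
    else:
        # V0: prefill tokens come first.
        prefill, decode = torch.split(hidden_states,
                                      [num_prefill_tokens, num_decode_tokens],
                                      dim=0)
    return prefill, decode


# Example: 5 prefill tokens and 2 decode requests (one token each).
tokens = torch.arange(7).unsqueeze(-1).float()
p, d = split_prefill_decode(tokens, num_prefill_tokens=5, num_decode_tokens=2,
                            use_v1=True)
assert p.shape[0] == 5 and d.shape[0] == 2
```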