diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index eba14e64553e..e4294512338b 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -104,7 +104,6 @@ def test_models(
             m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
         with vllm_runner(model,
                          max_num_seqs=MAX_NUM_SEQS,
-                         enforce_eager=True,
                          enable_prefix_caching=False) as vllm_model:
             vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, num_logprobs)
diff --git a/vllm/config.py b/vllm/config.py
index 384cb584fa9a..a9720fa3142c 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -4312,6 +4312,7 @@ class CompilationConfig:
         self.splitting_ops = [] if self.full_cuda_graph else [
             "vllm.unified_attention",
             "vllm.unified_attention_with_output",
+            "vllm.mamba_mixer2",
         ]
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index f3850d31c829..e32b2be4d40e 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -13,7 +13,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -33,6 +33,8 @@ from vllm.model_executor.model_loader.weight_utils import (
     LoaderFunction, composed_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata
 
 # Added by the IBM Team, 2024
@@ -424,14 +426,36 @@ class MambaMixer2(MambaBase, CustomOp):
     def forward_native(
         self,
         hidden_states: torch.Tensor,
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        output: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
+        mup_vector: Optional[torch.Tensor] = None,
     ):
         pass
 
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
+        mup_vector: Optional[torch.Tensor] = None,
+    ):
+        if not envs.VLLM_USE_V1:
+            CustomOp.forward(self, hidden_states, output, mamba_cache_params,
+                             mamba2_metadata, mup_vector)
+        else:
+            torch.ops.vllm.mamba_mixer2(
+                hidden_states,
+                output,
+                self.prefix,
+                mup_vector,
+            )
+
     def forward_cuda(
         self,
         hidden_states: torch.Tensor,
+        output: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
         mamba2_metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
     ):
@@ -517,6 +541,7 @@ class MambaMixer2(MambaBase, CustomOp):
         num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
         has_prefill = num_prefills > 0
         has_decode = num_decodes > 0
+        num_actual_tokens = num_prefill_tokens + num_decodes
 
         # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         # Separate prefill and decode by splitting varlen input
@@ -524,18 +549,18 @@ class MambaMixer2(MambaBase, CustomOp):
         # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         if envs.VLLM_USE_V1:
             hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
-                hidden_states_B_C,
+                hidden_states_B_C[:num_actual_tokens],
                 [num_decodes, num_prefill_tokens],
                 dim=0,
             )
             dt_d, dt_p = torch.split(
-                dt,
+                dt[:num_actual_tokens],
                 [num_decodes, num_prefill_tokens],
                 dim=0,
             )
             # Split along batch dimension
             state_indices_tensor_d, state_indices_tensor_p = torch.split(
-                state_indices_tensor,
+                state_indices_tensor[:num_actual_tokens],
                 [num_decodes, num_prefills],
                 dim=0,
             )
@@ -696,11 +721,10 @@ class MambaMixer2(MambaBase, CustomOp):
         # GatedRMSNorm internally applying SiLU to the gate
         # SiLU is applied internally before normalization, unlike standard
         # norm usage
-        hidden_states = self.norm(hidden_states, gate)
+        hidden_states = self.norm(hidden_states, gate[:num_actual_tokens])
 
         # 5. Final linear projection
-        out, _ = self.out_proj(hidden_states)
-        return out
+        output[:num_actual_tokens], _ = self.out_proj(hidden_states)
 
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return get_mamba_state_shape(
@@ -712,3 +736,36 @@ class MambaMixer2(MambaBase, CustomOp):
             state_size=self.ssm_state_size,
             conv_kernel=self.conv_kernel_size,
         )
+
+
+def mamba_mixer2(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+    mup_vector: Optional[torch.Tensor] = None,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    self.forward_cuda(hidden_states=hidden_states,
+                      output=output,
+                      mamba_cache_params=None,
+                      mamba2_metadata=None,
+                      mup_vector=mup_vector)
+
+
+def mamba_mixer2_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+    mup_vector: Optional[torch.Tensor] = None,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="mamba_mixer2",
+    op_func=mamba_mixer2,
+    mutates_args=["output"],
+    fake_impl=mamba_mixer2_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
index e93d4294a62c..0f5494427634 100644
--- a/vllm/model_executor/models/bamba.py
+++ b/vllm/model_executor/models/bamba.py
@@ -11,6 +11,7 @@ from transformers import BambaConfig
 
 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -122,11 +123,10 @@ class BambaMixerDecoderLayer(nn.Module):
 
         hidden_states, residual = self.input_layernorm(
             hidden_states, residual)
-        hidden_states = self.mamba(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
+        output = torch.empty_like(hidden_states)
+        self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
         # Fully Connected
-        hidden_states, residual = self.pre_ff_layernorm(
-            hidden_states, residual)
+        hidden_states, residual = self.pre_ff_layernorm(output, residual)
         hidden_states = self.feed_forward(hidden_states)
 
         return hidden_states, residual
@@ -169,7 +169,7 @@ class BambaAttentionDecoderLayer(nn.Module):
         self.max_position_embeddings = max_position_embeddings
 
         if hasattr(config, "partial_rotary_factor"):
-            rotary_dim = self.head_dim * config.partial_rotary_factor
+            rotary_dim = int(self.head_dim * config.partial_rotary_factor)
         elif hasattr(config, "attn_rotary_emb"):
             rotary_dim = config.attn_rotary_emb  # for backward compatibility
         else:
@@ -258,6 +258,7 @@ ALL_DECODER_LAYER_TYPES = {
 }
 
 
+@support_torch_compile
 class BambaModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py
index 7761de224c9d..6a58b1501fe6 100644
--- a/vllm/model_executor/models/falcon_h1.py
+++ b/vllm/model_executor/models/falcon_h1.py
@@ -10,6 +10,7 @@ from transformers import FalconH1Config
 
 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -179,13 +180,15 @@ class FalconH1SSMDecoderLayer(nn.Module):
         mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
-        hidden_states = self.mamba(
+        output = torch.empty_like(hidden_states)
+        self.mamba(
             hidden_states,
+            output,
             mamba_cache_params,
             mamba2_metadata=mamba2_metadata,
             mup_vector=self.mup_vector,
         )
-        return hidden_states, residual
+        return output, residual
 
 
 class FalconH1AttentionDecoderLayer(nn.Module):
@@ -398,6 +401,7 @@ class FalconH1ParallelHybrid(nn.Module):
         return hidden_states
 
 
+@support_torch_compile
 class FalconH1Model(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index 1c93e90737ad..59c1dce48ee7 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -11,6 +11,7 @@ from transformers import GraniteMoeHybridConfig
 
 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -104,9 +105,9 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
     ):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-        hidden_states = self.mamba(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
-        hidden_states = residual + hidden_states * self.residual_multiplier
+        output = torch.empty_like(hidden_states)
+        self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
+        hidden_states = residual + output * self.residual_multiplier
 
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
@@ -307,6 +308,7 @@ ALL_DECODER_LAYER_TYPES = {
 }
 
 
+@support_torch_compile
 class GraniteMoeHybridModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py
index d812d8cc0a39..adad181617e6 100644
--- a/vllm/model_executor/models/mamba2.py
+++ b/vllm/model_executor/models/mamba2.py
@@ -10,6 +10,7 @@ from transformers import MambaConfig
 
 from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
@@ -79,11 +80,12 @@ class Mamba2DecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.norm(hidden_states, residual)
 
-        hidden_states = self.mixer(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
-        return hidden_states, residual
+        output = torch.empty_like(hidden_states)
+        self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
+        return output, residual
 
 
+@support_torch_compile
 class Mamba2Model(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index cf7b39db1fe3..6a999e2254e7 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -25,6 +25,7 @@ from torch import nn
 
 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -172,9 +173,9 @@ class NemotronHMambaDecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.norm(hidden_states, residual)
 
-        hidden_states = self.mixer(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
-        return hidden_states, residual
+        output = torch.empty_like(hidden_states)
+        self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
+        return output, residual
 
 
 class NemotronHAttention(nn.Module):
@@ -292,6 +293,7 @@ ALL_DECODER_LAYER_TYPES = {
 }
 
 
+@support_torch_compile
 class NemotronHModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py
index ebf8dd497f67..7764fd9b9e08 100644
--- a/vllm/model_executor/models/zamba2.py
+++ b/vllm/model_executor/models/zamba2.py
@@ -17,6 +17,7 @@ from transformers import Zamba2Config
 
 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.forward_context import get_forward_context
@@ -548,14 +549,16 @@ class Zamba2MambaDecoderLayer(nn.Module):
         hidden_states = self.input_layernorm(hidden_states)
 
         # Process through Mamba mixer
-        hidden_states = self.mamba(
+        output = torch.empty_like(hidden_states)
+        self.mamba(
             hidden_states,
+            output,
             mamba_cache_params=mamba_cache_params,
             mamba2_metadata=mamba2_metadata,
         )
 
         # residual connection after mamba
-        hidden_states = residual + hidden_states
+        hidden_states = residual + output
 
         return hidden_states
 
@@ -646,6 +649,7 @@ class Zamba2HybridLayer(nn.Module):
         return layer_outputs
 
 
+@support_torch_compile
 class Zamba2Model(nn.Module):
     """Core Zamba2 model combining transformer and Mamba architectures.
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index d5449a68bc28..1ee9c070226c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2753,9 +2753,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if self.vllm_config.speculative_config is not None:
            raise NotImplementedError(
                 "Mamba with speculative decoding is not supported yet.")
-        if not self.vllm_config.model_config.enforce_eager:
-            raise NotImplementedError(
-                "Mamba with cuda graph is not supported yet.")
         if self.vllm_config.cache_config.enable_prefix_caching:
             raise NotImplementedError(
                 "Prefix caching is not supported for Mamba yet.")
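
Note on the custom-op pattern in mamba_mixer2.py above: wrapping MambaMixer2.forward_cuda behind torch.ops.vllm.mamba_mixer2 (registered via direct_register_custom_op with mutates_args=["output"] and a no-op fake impl) is what allows "vllm.mamba_mixer2" to be listed in CompilationConfig.splitting_ops. torch.compile then sees one opaque op with a fixed schema, the caller owns the output buffer it mutates, and the actual layer object is recovered at runtime through get_forward_context().no_compile_layers[layer_name]. The sketch below is a minimal, self-contained illustration of the same register-a-mutating-op-with-a-fake-impl idea using plain torch.library; the toy::mamba_like_mixer name, the doubling "kernel", and the _LAYER_SCALES registry are made up for the example and are not part of this change.

import torch

# Stand-in for vLLM's forward-context lookup: map a layer name to some state.
_LAYER_SCALES = {"layers.0.mixer": 2.0}


@torch.library.custom_op("toy::mamba_like_mixer", mutates_args=("output",))
def mamba_like_mixer(hidden_states: torch.Tensor, output: torch.Tensor,
                     layer_name: str) -> None:
    # Real implementation: look up per-layer state by name and write the
    # result into the caller-provided buffer (no return value).
    scale = _LAYER_SCALES[layer_name]
    output.copy_(hidden_states * scale)


@mamba_like_mixer.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor,
      layer_name: str) -> None:
    # Fake (meta) implementation: no computation; it only tells the compiler
    # that the op returns nothing and mutates `output` in place.
    return None


def layer_forward(x: torch.Tensor) -> torch.Tensor:
    # Caller allocates the output buffer, mirroring torch.empty_like in the
    # decoder layers above, then invokes the op through the dispatcher.
    out = torch.empty_like(x)
    torch.ops.toy.mamba_like_mixer(x, out, "layers.0.mixer")
    return out


if __name__ == "__main__":
    compiled = torch.compile(layer_forward, fullgraph=True)
    x = torch.randn(4, 8)
    torch.testing.assert_close(compiled(x), x * 2.0)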