diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 831bfb1e939e..ad3db1cf2100 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -373,6 +373,7 @@ th { | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ | | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ | diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index f8c0eaa8cf3a..2055c44c83cd 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -31,6 +31,7 @@ HYBRID_MODELS = [ "hmellor/tiny-random-BambaForCausalLM", "ibm-granite/granite-4.0-tiny-preview", "tiiuae/Falcon-H1-0.5B-Base", + "LiquidAI/LFM2-1.2B", ] HF_UNSUPPORTED_MODELS = [ @@ -52,6 +53,7 @@ V1_SUPPORTED_MODELS = [ "hmellor/tiny-random-BambaForCausalLM", "ibm-granite/granite-4.0-tiny-preview", "tiiuae/Falcon-H1-0.5B-Base", + "LiquidAI/LFM2-1.2B", ] FULL_CUDA_GRAPH_MODELS = [ @@ -59,6 +61,10 @@ FULL_CUDA_GRAPH_MODELS = [ "Zyphra/Zamba2-1.2B-instruct", ] +V0_UNSUPPORTED_MODELS = [ + "LiquidAI/LFM2-1.2B", +] + # Avoid OOM MAX_NUM_SEQS = 4 @@ -94,9 +100,12 @@ def test_models( else: hf_outputs = None - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + if model not in V0_UNSUPPORTED_MODELS: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + else: + vllm_v0_outputs = None if model in V1_SUPPORTED_MODELS: with monkeypatch.context() as m: @@ -112,7 +121,7 @@ def test_models( else: vllm_v1_outputs = None - if hf_outputs is not None: + if hf_outputs is not None and vllm_v0_outputs is not None: check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v0_outputs, @@ -122,6 +131,7 @@ def test_models( if model in V1_SUPPORTED_MODELS: ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs + assert ref_outputs is not None check_logprobs_close( outputs_0_lst=ref_outputs, outputs_1_lst=vllm_v1_outputs, @@ -140,6 +150,9 @@ def test_batching( max_tokens: int, num_logprobs: int, ) -> None: + if model in V0_UNSUPPORTED_MODELS: + pytest.skip( + f"Unsupported V0 Engine. Skipping `test_batching` on {model}.") try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) @@ -392,9 +405,12 @@ def test_full_cuda_graph( else: hf_outputs = None - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + if model not in V0_UNSUPPORTED_MODELS: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + else: + vllm_v0_outputs = None with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -408,7 +424,7 @@ def test_full_cuda_graph( vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - if hf_outputs is not None: + if hf_outputs is not None and vllm_v0_outputs is not None: check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v0_outputs, @@ -417,6 +433,7 @@ def test_full_cuda_graph( ) ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs + assert ref_outputs is not None check_logprobs_close( outputs_0_lst=ref_outputs, outputs_1_lst=vllm_v1_outputs, diff --git a/tests/models/registry.py b/tests/models/registry.py index 4f69f90b6aae..a6d5c305f799 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -230,6 +230,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "tiny": "ai21labs/Jamba-tiny-dev", "random": "ai21labs/Jamba-tiny-random", # noqa: E501 }), + "Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B", + min_transformers_version="4.54"), "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct", extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501 "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501 diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index f06b34285eae..bbd3da982af8 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -95,6 +95,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): + if model_arch == "Lfm2ForCausalLM": + pytest.skip("Skipping until test supports V1-only models") can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c654485f4fe9..e2785e7602e4 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -337,6 +337,7 @@ class CompilationConfig: "vllm.unified_attention_with_output", "vllm.mamba_mixer2", "vllm.mamba_mixer", + "vllm.short_conv", ] def compute_hash(self) -> str: diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 66674d1a6f25..280a9e45e662 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -54,6 +54,16 @@ class MambaStateDtypeCalculator: return (conv_state_dtype, temporal_state_dtype) + @classmethod + def short_conv_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, + model_dtype) + return (conv_state_dtype, ) + class MambaStateShapeCalculator: @@ -122,6 +132,20 @@ class MambaStateShapeCalculator: tp_world_size), head_dim, state_size) return conv_state_shape, temporal_state_shape + @classmethod + def short_conv_state_shape( + cls, + tp_world_size: int, + intermediate_size: int, + conv_kernel: int, + use_v1: bool = True, + ) -> tuple[tuple[int, int]]: + conv_dim = divide(intermediate_size, tp_world_size) + conv_state_shape = (conv_kernel - 1, conv_dim) + if not use_v1: + conv_state_shape = conv_state_shape[1], conv_state_shape[0] + return (conv_state_shape, ) + @classmethod def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int): """Compute the increase in group numbers to account for diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py new file mode 100644 index 000000000000..fead1e73e345 --- /dev/null +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -0,0 +1,262 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm import envs +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.short_conv_attn import ( + ShortConvAttentionMetadata) + + +@CustomOp.register("short_conv") +class ShortConv(MambaBase, CustomOp): + + def __init__(self, + config, + dim: int, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.conv_dim = dim + self.L_cache = config.conv_L_cache + self.bias = config.conv_bias + + self.conv = ColumnParallelLinear( + input_size=self.L_cache, + output_size=dim, + bias=self.bias, + prefix=f"{prefix}.conv1d", + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv.weight.data = self.conv.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear( + input_size=dim, + output_sizes=[dim] * 3, + bias=self.bias, + prefix=f"{prefix}.in_proj", + ) + self.out_proj = RowParallelLinear( + input_size=dim, + output_size=dim, + bias=self.bias, + prefix=f"{prefix}.out_proj", + ) + + assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1") + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. + self.kv_cache = [(torch.tensor([]), )] + + self.model_config = model_config + self.cache_config = cache_config + self.prefix = prefix + + def forward_native( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + conv_metadata: ShortConvAttentionMetadata, + ): + return + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + conv_metadata: ShortConvAttentionMetadata, + ): + torch.ops.vllm.short_conv( + hidden_states, + output, + self.prefix, + ) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + conv_metadata: ShortConvAttentionMetadata, + ): + forward_context = get_forward_context() + # ShortConvAttentionMetadata contains metadata necessary for the + # short_conv triton kernels to operate in continuous batching and in + # chunked prefill modes; they are computed at top-level model forward + # since they stay the same and reused for all mamba layers in the same + # iteration. + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + conv_metadata = attn_metadata + assert isinstance(attn_metadata, ShortConvAttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states + + BCx, _ = self.in_proj(hidden_states) + + B, C, x = BCx.chunk(3, dim=-1) + + conv_weights = self.conv.weight.view(self.conv.weight.size(0), + self.conv.weight.size(2)) + + if attn_metadata is None: + # V1 profile run + Bx = (B * x).contiguous() + hidden_states = C * Bx + contextualized_states, _ = self.out_proj(hidden_states) + return contextualized_states + + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + num_actual_tokens = num_decodes + num_prefill_tokens + + # NOTE: V1 puts decode before prefill + # Separate prefill and decode by splitting varlen input + # Split along token dimension + B_d, B_p = torch.split( + B[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + C_d, C_p = torch.split( + C[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + x_d, x_p = torch.split( + x[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) + + conv_output_list = [] + + if has_prefill: + Bx_p = (B_p * x_p).transpose(0, 1) + if conv_metadata.cu_seqlen is None: + conv_metadata = update_metadata(Bx_p, query_start_loc_p, + conv_metadata) + Bx = causal_conv1d_fn(Bx_p, + conv_weights, + self.conv.bias, + activation=None, + conv_states=conv_state, + has_initial_state=has_initial_states_p, + cache_indices=state_indices_tensor_p, + metadata=conv_metadata, + query_start_loc=query_start_loc_p).transpose( + 0, 1)[:num_prefill_tokens] + + y = C_p * Bx + conv_output_list.append(y) + + if has_decode: + Bx_d = (B_d * x_d).contiguous() + Bx = causal_conv1d_update( + Bx_d, + conv_state, + conv_weights, + self.conv.bias, + activation=None, + conv_state_indices=state_indices_tensor_d) + y = C_d * Bx + conv_output_list.insert(0, y) + + # Merge prefill and decode outputs before passing to gated MLP + hidden_states = torch.vstack(conv_output_list) + + # Final linear projection + output[:num_actual_tokens], _ = self.out_proj(hidden_states) + + def get_state_dtype(self) -> tuple[torch.dtype, ...]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.short_conv_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + ) + + def get_state_shape(self) -> tuple[tuple[int, ...]]: + return MambaStateShapeCalculator.short_conv_state_shape( + tp_world_size=get_tensor_model_parallel_world_size(), + intermediate_size=self.conv_dim, + conv_kernel=self.L_cache, + ) + + @property + def mamba_type(self) -> str: + return "short_conv" + + +def short_conv( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.forward_cuda(hidden_states=hidden_states, + output=output, + conv_metadata=None) + + +def short_conv_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="short_conv", + op_func=short_conv, + mutates_args=["output"], + fake_impl=short_conv_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py new file mode 100644 index 000000000000..5f3148b47ead --- /dev/null +++ b/vllm/model_executor/models/lfm2.py @@ -0,0 +1,557 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Any, Optional + +import torch +import torch.nn as nn +from transformers import Lfm2Config + +from vllm import envs +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.mamba.short_conv import ShortConv +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsQuant) +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class Lfm2MLP(nn.Module): + + def __init__( + self, + dim: int, + ff_dim: int, + multiple_of: int, + auto_adjust_ff_dim: bool, + ffn_dim_multiplier: Optional[float], + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + if auto_adjust_ff_dim: + ff_dim = int(2 * ff_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + ff_dim = int(ffn_dim_multiplier * ff_dim) + ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of) + + self.w1 = MergedColumnParallelLinear( + input_size=dim, + output_sizes=[ff_dim] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.w2 = RowParallelLinear( + input_size=ff_dim, + output_size=dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.w1(x) + x = self.act_fn(gate_up) + x, _ = self.w2(x) + return x + + +class Lfm2Attention(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = hidden_size + self.num_kv_heads = num_kv_heads + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + self.k_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + n_tokens, _ = hidden_states.shape + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(n_tokens, self.num_heads, self.head_dim).contiguous() + k = k.view(n_tokens, self.num_kv_heads, self.head_dim).contiguous() + q = self.q_layernorm(q) + k = self.k_layernorm(k) + q, k = self.rotary_emb(positions, q, k) + q = q.view(n_tokens, self.num_heads * self.head_dim) + k = k.view(n_tokens, self.num_kv_heads * self.head_dim) + attn_output = self.attn(q, k, v) + output, _ = self.out_proj(attn_output) + return output + + +class Lfm2AttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.prefix = prefix + self.config = config + self.layer_idx = layer_idx + + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + + self.self_attn = Lfm2Attention( + config=config, + layer_idx=layer_idx, + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + self.feed_forward = Lfm2MLP( + dim=config.block_dim, + ff_dim=config.block_ff_dim, + multiple_of=config.block_multiple_of, + auto_adjust_ff_dim=config.block_auto_adjust_ff_dim, + ffn_dim_multiplier=config.block_ffn_dim_multiplier, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + hidden_states, residual = self.ffn_norm(hidden_states, residual) + return self.feed_forward(hidden_states), residual + + +class Lfm2ShortConvDecoderLayer(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.conv = ShortConv( + config=config, + dim=config.conv_dim, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, + prefix=f"{prefix}.conv", + ) + + self.feed_forward = Lfm2MLP( + dim=config.block_dim, + ff_dim=config.block_ff_dim, + multiple_of=config.block_multiple_of, + auto_adjust_ff_dim=config.block_auto_adjust_ff_dim, + ffn_dim_multiplier=config.block_ffn_dim_multiplier, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm( + hidden_states, residual) + output = torch.empty_like(hidden_states) + self.conv( + hidden_states, + output, + conv_metadata=None, + ) + hidden_states, residual = self.ffn_norm(output, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Lfm2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size) + + def get_layer(prefix: str): + layer_idx = extract_layer_index(prefix) + is_attn = self.config.layer_types[layer_idx] == "full_attention" + layer_class = (Lfm2AttentionDecoderLayer + if is_attn else Lfm2ShortConvDecoderLayer) + return layer_class( + config, + layer_idx, + model_config, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + if get_pp_group().is_last_rank: + self.embedding_norm = RMSNorm(config.hidden_size, + eps=config.norm_eps) + else: + self.embedding_norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.embedding_norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".w1", ".w1", 0), + (".w1", ".w3", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "w1": [ + "w1", + "w3", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, ...]: + + return MambaStateDtypeCalculator.short_conv_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int]]: + """ Calculate shapes for LFM2's convolutional cache. + + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + return MambaStateShapeCalculator.short_conv_state_shape( + tp_world_size=parallel_config.tensor_parallel_size, + intermediate_size=hf_config.conv_dim, + conv_kernel=hf_config.conv_L_cache, + use_v1=use_v1, + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert (not cache_config.enable_prefix_caching + ), "Lfm2 currently does not support prefix caching" + assert envs.VLLM_USE_V1, ( + "Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1") + + super().__init__() + self.config = config + self.vllm_config = vllm_config + self.scheduler_config = scheduler_config + self.model_config = vllm_config.model_config + + self.model = Lfm2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = self.config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 39a3e425a46d..28d7e93af91a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -93,6 +93,7 @@ _TEXT_GENERATION_MODELS = { "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501 # For decapoda-research/llama-* diff --git a/vllm/v1/attention/backends/mamba_selectors.py b/vllm/v1/attention/backends/mamba_selectors.py index d3a0c63c5e96..fb1844508211 100644 --- a/vllm/v1/attention/backends/mamba_selectors.py +++ b/vllm/v1/attention/backends/mamba_selectors.py @@ -4,6 +4,8 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend +from vllm.v1.attention.backends.short_conv_attn import ( + ShortConvAttentionBackend) def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: @@ -13,6 +15,8 @@ def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: return Mamba2AttentionBackend if mamba_type == "linear_attention": return LinearAttentionBackend + if mamba_type == "short_conv": + return ShortConvAttentionBackend raise NotImplementedError(f"Mamba Attention type {mamba_type} is not " "supported yet.") diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py new file mode 100644 index 000000000000..d80ced8ec876 --- /dev/null +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import ClassVar, Optional + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + + +class ShortConvAttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]: + return ShortConvAttentionMetadataBuilder + + +@dataclass +class ShortConvAttentionMetadata: + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + + query_start_loc: torch.Tensor + has_initial_states: torch.Tensor + state_indices_tensor: torch.Tensor # shape: [batch,] + + # For causal_conv1d + nums_dict: Optional[dict] = None + cu_seqlen: Optional[int] = None + batch_ptr: Optional[torch.tensor] = None + token_chunk_offset_ptr: Optional[torch.tensor] = None + + +class ShortConvAttentionMetadataBuilder( + AttentionMetadataBuilder[ShortConvAttentionMetadata]): + + reorder_batch_threshold: ClassVar[int] = 1 + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + assert isinstance(kv_cache_spec, MambaSpec) + self.kv_cache_spec = kv_cache_spec + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> ShortConvAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc + + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=1)) + has_initial_states = None + if num_prefills > 0: + #[batch,] + has_initial_states_cpu = ( + common_attn_metadata. + num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) + has_initial_states = has_initial_states_cpu.to( + query_start_loc.device) + + attn_metadata = ShortConvAttentionMetadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + query_start_loc=query_start_loc, + has_initial_states=has_initial_states, + state_indices_tensor=state_indices_tensor, + ) + return attn_metadata \ No newline at end of file