[Model] Add LFM2 architecture (#22845)

Signed-off-by: Paul Pak <paulpak58@gmail.com>
Paul Pak 2025-08-21 01:35:07 -06:00 committed by GitHub
parent 31282401b6
commit 2e2000f352
11 changed files with 960 additions and 8 deletions

View File

@@ -373,6 +373,7 @@ th {
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ |
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
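For reference, the new entry can be exercised with vLLM's offline LLM API; a minimal sketch (prompt and sampling settings are illustrative, and LFM2 is V1-only in this PR, so the V1 engine must be enabled):

import os
os.environ.setdefault("VLLM_USE_V1", "1")  # LFM2 has no V0 path (see the model code below)

from vllm import LLM, SamplingParams

llm = LLM(model="LiquidAI/LFM2-350M")  # any checkpoint from the table above
params = SamplingParams(temperature=0.0, max_tokens=32)
out = llm.generate(["The capital of France is"], params)
print(out[0].outputs[0].text)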

View File

@@ -31,6 +31,7 @@ HYBRID_MODELS = [
"hmellor/tiny-random-BambaForCausalLM",
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B",
]
HF_UNSUPPORTED_MODELS = [
@@ -52,6 +53,7 @@ V1_SUPPORTED_MODELS = [
"hmellor/tiny-random-BambaForCausalLM",
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B",
]
FULL_CUDA_GRAPH_MODELS = [
@@ -59,6 +61,10 @@ FULL_CUDA_GRAPH_MODELS = [
"Zyphra/Zamba2-1.2B-instruct",
]
V0_UNSUPPORTED_MODELS = [
"LiquidAI/LFM2-1.2B",
]
# Avoid OOM
MAX_NUM_SEQS = 4
@@ -94,9 +100,12 @@ def test_models(
else:
hf_outputs = None
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
if model not in V0_UNSUPPORTED_MODELS:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
else:
vllm_v0_outputs = None
if model in V1_SUPPORTED_MODELS:
with monkeypatch.context() as m:
@@ -112,7 +121,7 @@
else:
vllm_v1_outputs = None
if hf_outputs is not None:
if hf_outputs is not None and vllm_v0_outputs is not None:
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v0_outputs,
@@ -122,6 +131,7 @@
if model in V1_SUPPORTED_MODELS:
ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
assert ref_outputs is not None
check_logprobs_close(
outputs_0_lst=ref_outputs,
outputs_1_lst=vllm_v1_outputs,
@@ -140,6 +150,9 @@ def test_batching(
max_tokens: int,
num_logprobs: int,
) -> None:
if model in V0_UNSUPPORTED_MODELS:
pytest.skip(
f"Unsupported V0 Engine. Skipping `test_batching` on {model}.")
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
@@ -392,9 +405,12 @@ def test_full_cuda_graph(
else:
hf_outputs = None
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
if model not in V0_UNSUPPORTED_MODELS:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
else:
vllm_v0_outputs = None
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
@@ -408,7 +424,7 @@
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
if hf_outputs is not None:
if hf_outputs is not None and vllm_v0_outputs is not None:
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v0_outputs,
@@ -417,6 +433,7 @@
)
ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
assert ref_outputs is not None
check_logprobs_close(
outputs_0_lst=ref_outputs,
outputs_1_lst=vllm_v1_outputs,
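To run only the new model through these hybrid tests, a keyword filter works; the test file path below is assumed from vLLM's usual layout and may differ:

import pytest

# Path assumed; adjust to where test_hybrid.py actually lives in the repo.
pytest.main(["tests/models/language/generation/test_hybrid.py", "-k", "LFM2", "-v"])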

View File

@@ -230,6 +230,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"tiny": "ai21labs/Jamba-tiny-dev",
"random": "ai21labs/Jamba-tiny-random", # noqa: E501
}),
"Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B",
min_transformers_version="4.54"),
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct",
extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501

View File

@@ -95,6 +95,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
if model_arch == "Lfm2ForCausalLM":
pytest.skip("Skipping until test supports V1-only models")
can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS)

View File

@@ -337,6 +337,7 @@ class CompilationConfig:
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
"vllm.mamba_mixer",
"vllm.short_conv",
]
def compute_hash(self) -> str:

View File

@@ -54,6 +54,16 @@ class MambaStateDtypeCalculator:
return (conv_state_dtype, temporal_state_dtype)
@classmethod
def short_conv_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
model_dtype)
return (conv_state_dtype, )
class MambaStateShapeCalculator:
@@ -122,6 +132,20 @@ class MambaStateShapeCalculator:
tp_world_size), head_dim, state_size)
return conv_state_shape, temporal_state_shape
@classmethod
def short_conv_state_shape(
cls,
tp_world_size: int,
intermediate_size: int,
conv_kernel: int,
use_v1: bool = True,
) -> tuple[tuple[int, int]]:
conv_dim = divide(intermediate_size, tp_world_size)
conv_state_shape = (conv_kernel - 1, conv_dim)
if not use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
return (conv_state_shape, )
@classmethod
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
"""Compute the increase in group numbers to account for

View File

@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionMetadata)
@CustomOp.register("short_conv")
class ShortConv(MambaBase, CustomOp):
def __init__(self,
config,
dim: int,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.conv_dim = dim
self.L_cache = config.conv_L_cache
self.bias = config.conv_bias
self.conv = ColumnParallelLinear(
input_size=self.L_cache,
output_size=dim,
bias=self.bias,
prefix=f"{prefix}.conv1d",
)
# Unsqueeze to fit the conv1d weight shape into the linear weight shape.
# This can't be done in `weight_loader` since one already exists in
# `ColumnParallelLinear`, and `set_weight_attrs` doesn't allow overriding it.
self.conv.weight.data = self.conv.weight.data.unsqueeze(1)
self.in_proj = MergedColumnParallelLinear(
input_size=dim,
output_sizes=[dim] * 3,
bias=self.bias,
prefix=f"{prefix}.in_proj",
)
self.out_proj = RowParallelLinear(
input_size=dim,
output_size=dim,
bias=self.bias,
prefix=f"{prefix}.out_proj",
)
assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1")
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
self.kv_cache = [(torch.tensor([]), )]
self.model_config = model_config
self.cache_config = cache_config
self.prefix = prefix
def forward_native(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
):
return
def forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
):
torch.ops.vllm.short_conv(
hidden_states,
output,
self.prefix,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
):
forward_context = get_forward_context()
# ShortConvAttentionMetadata contains the metadata necessary for the
# short_conv triton kernels to operate in continuous batching and
# chunked-prefill modes; it is computed in the top-level model forward
# since it stays the same and is reused by all mamba layers in the same
# iteration.
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
conv_metadata = attn_metadata
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states
BCx, _ = self.in_proj(hidden_states)
B, C, x = BCx.chunk(3, dim=-1)
conv_weights = self.conv.weight.view(self.conv.weight.size(0),
self.conv.weight.size(2))
if attn_metadata is None:
# V1 profile run
Bx = (B * x).contiguous()
hidden_states = C * Bx
contextualized_states, _ = self.out_proj(hidden_states)
return contextualized_states
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (= request count for decodes)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
num_actual_tokens = num_decodes + num_prefill_tokens
# NOTE: V1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
B_d, B_p = torch.split(
B[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
C_d, C_p = torch.split(
C[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
x_d, x_p = torch.split(
x[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
conv_output_list = []
if has_prefill:
Bx_p = (B_p * x_p).transpose(0, 1)
if conv_metadata.cu_seqlen is None:
conv_metadata = update_metadata(Bx_p, query_start_loc_p,
conv_metadata)
Bx = causal_conv1d_fn(Bx_p,
conv_weights,
self.conv.bias,
activation=None,
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
metadata=conv_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
y = C_p * Bx
conv_output_list.append(y)
if has_decode:
Bx_d = (B_d * x_d).contiguous()
Bx = causal_conv1d_update(
Bx_d,
conv_state,
conv_weights,
self.conv.bias,
activation=None,
conv_state_indices=state_indices_tensor_d)
y = C_d * Bx
conv_output_list.insert(0, y)
# Merge prefill and decode outputs before passing to gated MLP
hidden_states = torch.vstack(conv_output_list)
# Final linear projection
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
def get_state_dtype(self) -> tuple[torch.dtype, ...]:
assert self.model_config is not None
assert self.cache_config is not None
return MambaStateDtypeCalculator.short_conv_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, ...]]:
return MambaStateShapeCalculator.short_conv_state_shape(
tp_world_size=get_tensor_model_parallel_world_size(),
intermediate_size=self.conv_dim,
conv_kernel=self.L_cache,
)
@property
def mamba_type(self) -> str:
return "short_conv"
def short_conv(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
conv_metadata=None)
def short_conv_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="short_conv",
op_func=short_conv,
mutates_args=["output"],
fake_impl=short_conv_fake,
dispatch_key=current_platform.dispatch_key,
)
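For readers who do not want to trace the causal_conv1d kernels, here is a plain-PyTorch sketch of what a ShortConv layer computes for one sequence, as read from the code above (single GPU, no cache, conv bias omitted; an illustration of the math rather than the exact kernel semantics):

import torch
import torch.nn.functional as F

def short_conv_reference(hidden, in_proj_w, conv_w, out_proj_w, L_cache):
    # hidden: [seq, dim], in_proj_w: [3*dim, dim], conv_w: [dim, L_cache],
    # out_proj_w: [dim, dim]
    B, C, x = (hidden @ in_proj_w.t()).chunk(3, dim=-1)
    Bx = (B * x).t().unsqueeze(0)           # [1, dim, seq]
    Bx = F.pad(Bx, (L_cache - 1, 0))        # causal: pad on the left only
    conv = F.conv1d(Bx, conv_w.unsqueeze(1), groups=conv_w.size(0))
    y = C * conv.squeeze(0).t()             # gate with C, back to [seq, dim]
    return y @ out_proj_w.t()

The per-sequence cache is just the last L_cache - 1 columns of the conv input Bx, which matches the (conv_kernel - 1, conv_dim) shape that short_conv_state_shape sizes.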

View File

@@ -0,0 +1,557 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Any, Optional
import torch
import torch.nn as nn
from transformers import Lfm2Config
from vllm import envs
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.short_conv import ShortConv
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsQuant)
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class Lfm2MLP(nn.Module):
def __init__(
self,
dim: int,
ff_dim: int,
multiple_of: int,
auto_adjust_ff_dim: bool,
ffn_dim_multiplier: Optional[float],
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
if auto_adjust_ff_dim:
ff_dim = int(2 * ff_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
ff_dim = int(ffn_dim_multiplier * ff_dim)
ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
self.w1 = MergedColumnParallelLinear(
input_size=dim,
output_sizes=[ff_dim] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.w2 = RowParallelLinear(
input_size=ff_dim,
output_size=dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.w1(x)
x = self.act_fn(gate_up)
x, _ = self.w2(x)
return x
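The ff_dim bookkeeping above follows the original LLaMA feed-forward sizing; a worked example with hypothetical values (ff_dim=8192, multiple_of=256, no ffn_dim_multiplier, auto-adjust enabled):

ff_dim, multiple_of = 8192, 256
ff_dim = int(2 * ff_dim / 3)                                        # 5461
ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)  # 5632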
class Lfm2Attention(nn.Module):
def __init__(
self,
config: Lfm2Config,
layer_idx: int,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = hidden_size
self.num_kv_heads = num_kv_heads
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is at least TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = self.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size=self.hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
rope_scaling=rope_scaling,
is_neox_style=True,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
prefix=f"{prefix}.attn",
)
self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps)
self.k_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
n_tokens, _ = hidden_states.shape
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.view(n_tokens, self.num_heads, self.head_dim).contiguous()
k = k.view(n_tokens, self.num_kv_heads, self.head_dim).contiguous()
q = self.q_layernorm(q)
k = self.k_layernorm(k)
q, k = self.rotary_emb(positions, q, k)
q = q.view(n_tokens, self.num_heads * self.head_dim)
k = k.view(n_tokens, self.num_kv_heads * self.head_dim)
attn_output = self.attn(q, k, v)
output, _ = self.out_proj(attn_output)
return output
class Lfm2AttentionDecoderLayer(nn.Module):
def __init__(
self,
config: Lfm2Config,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.prefix = prefix
self.config = config
self.layer_idx = layer_idx
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = Lfm2Attention(
config=config,
layer_idx=layer_idx,
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.feed_forward = Lfm2MLP(
dim=config.block_dim,
ff_dim=config.block_ff_dim,
multiple_of=config.block_multiple_of,
auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
ffn_dim_multiplier=config.block_ffn_dim_multiplier,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.operator_norm(hidden_states)
else:
hidden_states, residual = self.operator_norm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states)
hidden_states, residual = self.ffn_norm(hidden_states, residual)
return self.feed_forward(hidden_states), residual
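Both decoder layers lean on vLLM's RMSNorm residual convention: when a residual tensor is passed in, the norm returns the normalized value together with the updated residual. A minimal eager-mode sketch of that convention (not the fused kernel):

import torch

def rms_norm_with_residual(x, residual, weight, eps=1e-5):
    # Roughly what "hidden_states, residual = norm(hidden_states, residual)" does above.
    residual = x + residual
    var = residual.pow(2).mean(dim=-1, keepdim=True)
    return residual * torch.rsqrt(var + eps) * weight, residual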
class Lfm2ShortConvDecoderLayer(nn.Module):
def __init__(
self,
config: Lfm2Config,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.conv = ShortConv(
config=config,
dim=config.conv_dim,
layer_idx=layer_idx,
model_config=model_config,
cache_config=cache_config,
prefix=f"{prefix}.conv",
)
self.feed_forward = Lfm2MLP(
dim=config.block_dim,
ff_dim=config.block_ff_dim,
multiple_of=config.block_multiple_of,
auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
ffn_dim_multiplier=config.block_ffn_dim_multiplier,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.operator_norm(hidden_states)
else:
hidden_states, residual = self.operator_norm(
hidden_states, residual)
output = torch.empty_like(hidden_states)
self.conv(
hidden_states,
output,
conv_metadata=None,
)
hidden_states, residual = self.ffn_norm(output, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
@support_torch_compile
class Lfm2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size)
def get_layer(prefix: str):
layer_idx = extract_layer_index(prefix)
is_attn = self.config.layer_types[layer_idx] == "full_attention"
layer_class = (Lfm2AttentionDecoderLayer
if is_attn else Lfm2ShortConvDecoderLayer)
return layer_class(
config,
layer_idx,
model_config,
cache_config,
quant_config=quant_config,
prefix=prefix,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
if get_pp_group().is_last_rank:
self.embedding_norm = RMSNorm(config.hidden_size,
eps=config.norm_eps)
else:
self.embedding_norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.embedding_norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".w1", ".w1", 0),
(".w1", ".w3", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid, SupportsQuant):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"w1": [
"w1",
"w3",
],
}
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, ...]:
return MambaStateDtypeCalculator.short_conv_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
)
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int]]:
""" Calculate shapes for LFM2's convolutional cache.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
return MambaStateShapeCalculator.short_conv_state_shape(
tp_world_size=parallel_config.tensor_parallel_size,
intermediate_size=hf_config.conv_dim,
conv_kernel=hf_config.conv_L_cache,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert (not cache_config.enable_prefix_caching
), "Lfm2 currently does not support prefix caching"
assert envs.VLLM_USE_V1, (
"Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1")
super().__init__()
self.config = config
self.vllm_config = vllm_config
self.scheduler_config = scheduler_config
self.model_config = vllm_config.model_config
self.model = Lfm2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = self.config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else
lora_config.lora_vocab_padding_size),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
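The stacked_params_mapping in Lfm2Model.load_weights is what folds the checkpoint's separate gate/up (w1/w3) and q/k/v tensors into the merged vLLM parameters. A small illustration of the name rewriting (the checkpoint name is made up; real names depend on the HF export):

stacked_params_mapping = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".w1", ".w1", 0),
    (".w1", ".w3", 1),
]

name = "model.layers.2.feed_forward.w3.weight"  # illustrative
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        print(name.replace(weight_name, param_name), shard_id)
        # -> model.layers.2.feed_forward.w1.weight 1
        break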

View File

@@ -93,6 +93,7 @@ _TEXT_GENERATION_MODELS = {
"InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501
# For decapoda-research/llama-*

View File

@@ -4,6 +4,8 @@ from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
@@ -13,6 +15,8 @@ def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
return Mamba2AttentionBackend
if mamba_type == "linear_attention":
return LinearAttentionBackend
if mamba_type == "short_conv":
return ShortConvAttentionBackend
raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
"supported yet.")

View File

@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class ShortConvAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
return ShortConvAttentionMetadataBuilder
@dataclass
class ShortConvAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc: torch.Tensor
has_initial_states: torch.Tensor
state_indices_tensor: torch.Tensor # shape: [batch,]
# For causal_conv1d
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None
class ShortConvAttentionMetadataBuilder(
AttentionMetadataBuilder[ShortConvAttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> ShortConvAttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
has_initial_states = None
if num_prefills > 0:
# [batch,]
has_initial_states_cpu = (
common_attn_metadata.
num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
has_initial_states = has_initial_states_cpu.to(
query_start_loc.device)
attn_metadata = ShortConvAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc=query_start_loc,
has_initial_states=has_initial_states,
state_indices_tensor=state_indices_tensor,
)
return attn_metadata
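To make the decode-first ordering concrete, here is a small worked example of the has_initial_states computation (numbers are made up): three requests, where request 0 is a single-token decode and requests 1-2 are prefills.

import torch

num_reqs, num_prefills = 3, 2
num_computed_tokens_cpu = torch.tensor([5, 0, 7])  # decode first, then prefills
has_initial_states = (
    num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
print(has_initial_states)  # tensor([False,  True])

Only the prefill that already has computed tokens (a chunked-prefill continuation) reads its existing conv state; fresh prefills start from an empty state.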