[Bugfix] Clean up some cruft in mamba.py (#9343)

This commit is contained in:
parent f0fe4fe86d
commit 169b530607
@@ -155,7 +155,7 @@ Text Generation
   * - :code:`MambaForCausalLM`
     - Mamba
     - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
-    - ✅︎
+    -
     -
   * - :code:`MiniCPMForCausalLM`
     - MiniCPM
@@ -1,6 +1,5 @@
 # coding=utf-8
 """PyTorch MAMBA model."""
-from dataclasses import dataclass
 from typing import Iterable, List, Optional, Tuple

 import torch
@@ -10,7 +9,6 @@ from transformers import MambaConfig
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -39,13 +37,6 @@ from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
 KVCache = Tuple[torch.Tensor, torch.Tensor]


-@dataclass
-class MambaCacheParams:
-    is_prompt: bool = False
-    conv_state: torch.Tensor = torch.Tensor()
-    ssm_state: torch.Tensor = torch.Tensor()
-
-
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 class MambaMixer(nn.Module):
     """
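The MambaCacheParams holder removed above appears to be unused within this file: the conv_state and ssm_state tensors it bundled are already passed around as plain tensors in the hunks below. A minimal sketch of those two state tensors, with placeholder shapes rather than vLLM's actual cache layout:

import torch

# Sketch only (placeholder shapes, not vLLM's real cache layout): the removed
# MambaCacheParams did nothing more than bundle these two per-layer state
# tensors, which the forward() methods in the hunks below already receive as
# separate conv_state / ssm_state arguments.
batch_size, channels, conv_kernel, ssm_state_size = 2, 32, 4, 16
conv_state = torch.zeros(batch_size, channels, conv_kernel - 1)
ssm_state = torch.zeros(batch_size, channels, ssm_state_size)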
@@ -209,37 +200,6 @@ class MambaMixer(nn.Module):
         return contextualized_states


-class MambaMLP(nn.Module):
-
-    def __init__(
-        self,
-        config: MambaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-        hidden_size = config.hidden_size
-        intermediate_size = config.intermediate_size
-        hidden_act = config.hidden_act
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
-
-    def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
-
-
 class MambaDecoderLayer(nn.Module):

     def __init__(self,
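The MambaMLP class removed above is a standard gated (SwiGLU-style) feed-forward block that nothing in the model uses after this change. A plain-PyTorch sketch of the computation it performed, without vLLM's tensor-parallel linear layers (class name and sizes below are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUBlock(nn.Module):
    """Plain-PyTorch equivalent of the removed MambaMLP (no tensor parallelism)."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # gate_up_proj fuses the gate and up projections into one matmul,
        # mirroring MergedColumnParallelLinear(hidden_size, [intermediate_size] * 2)
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        # SiluAndMul: silu on the gate half, elementwise product with the up half
        return self.down_proj(F.silu(gate) * up)

# Example: mlp = SwiGLUBlock(768, 1536); y = mlp(torch.randn(2, 4, 768))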
@@ -252,7 +212,6 @@ class MambaDecoderLayer(nn.Module):
         self.config = config
         self.mixer = MambaMixer(config, layer_idx)

-        self.feed_forward = MambaMLP(config, quant_config=quant_config)
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                         eps=config.layer_norm_epsilon)
@@ -274,10 +233,6 @@ class MambaDecoderLayer(nn.Module):

         hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
                                    ssm_state)
-        # Fully Connected
-        hidden_states, residual = self.pre_ff_layernorm(
-            hidden_states, residual)
-        hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual
@@ -319,7 +274,6 @@ class MambaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         conv_state: torch.Tensor,
         ssm_state: torch.Tensor,
@@ -346,26 +300,6 @@


 class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-    }
-
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "embed_tokens",
-        "lm_head",
-    ]
-    embedding_modules = {
-        "embeddings": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]

     def __init__(
         self,
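The packed_modules_mapping and LoRA attributes removed above refer to attention projections (qkv_proj, o_proj) that a Mamba backbone does not contain, so the mapping never matched anything. For contrast, a rough sketch of what such a mapping describes in attention models, where per-shard q/k/v checkpoint weights fill one fused parameter (sizes made up):

import torch

# Illustration only: in attention models, separate q/k/v checkpoint weights are
# loaded into one fused "qkv_proj" parameter, which is what an entry like
# ("qkv_proj", "q_proj", "q") in packed_modules_mapping describes. Mamba has no
# q/k/v projections, so the removed mapping was inert.
hidden = 8
shards = {
    "q_proj": torch.randn(hidden, hidden),
    "k_proj": torch.randn(hidden, hidden),
    "v_proj": torch.randn(hidden, hidden),
}
qkv_proj = torch.empty(3 * hidden, hidden)
for i, name in enumerate(("q_proj", "k_proj", "v_proj")):
    qkv_proj[i * hidden:(i + 1) * hidden].copy_(shards[name])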
@@ -416,8 +350,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
         mamba_cache_tensors = self.mamba_cache.current_run_tensors(
             input_ids, attn_metadata, **kwargs)

-        hidden_states = self.backbone(input_ids, positions, kv_caches,
-                                      attn_metadata, mamba_cache_tensors[0],
+        hidden_states = self.backbone(input_ids, positions, attn_metadata,
+                                      mamba_cache_tensors[0],
                                       mamba_cache_tensors[1])

         return hidden_states
@@ -457,43 +391,16 @@
         return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
             if "A_log" in name:
                 name = name.replace("A_log", "A")

-            if ".self_attn." in name:
-                name = name.replace(".self_attn", "")
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue

-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
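The simplified load_weights keeps only the A_log-to-A rename and the GPTQ bias skip. A small self-contained sketch of that renaming step; the parameter names below are illustrative of a Mamba checkpoint, not read from one:

# Sketch of the name handling the simplified load_weights keeps: Mamba
# checkpoints store the state-space matrix in log form as "A_log", while the
# module registers the parameter as "A", so the name is rewritten before the
# params_dict lookup.
checkpoint_names = [
    "backbone.layers.0.mixer.A_log",
    "backbone.layers.0.mixer.conv1d.weight",
    "lm_head.weight",
]

params_dict = {  # stand-in for dict(self.named_parameters())
    "backbone.layers.0.mixer.A": None,
    "backbone.layers.0.mixer.conv1d.weight": None,
    "lm_head.weight": None,
}

for name in checkpoint_names:
    if "A_log" in name:
        name = name.replace("A_log", "A")
    assert name in params_dict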