[Bugfix] Clean up some cruft in mamba.py (#9343)

Authored by Tyler Michael Smith on 2024-10-14 20:24:25 -04:00; committed by GitHub
parent f0fe4fe86d
commit 169b530607
2 changed files with 10 additions and 103 deletions

View File

@@ -155,7 +155,7 @@ Text Generation
* - :code:`MambaForCausalLM`
- Mamba
- :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
- ✅︎
-
-
* - :code:`MiniCPMForCausalLM`
- MiniCPM
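For reference, the checkpoints listed in this table row load through vLLM's offline LLM entry point. The snippet below is a minimal usage sketch, not part of this commit; it assumes a vLLM build that includes MambaForCausalLM support and simply reuses one model name from the table above.

```python
# Quick usage sketch for one of the Mamba checkpoints listed above.
from vllm import LLM, SamplingParams

llm = LLM(model="state-spaces/mamba-130m-hf")
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```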

View File

@@ -1,6 +1,5 @@
# coding=utf-8
"""PyTorch MAMBA model."""
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple
import torch
@@ -10,7 +9,6 @@ from transformers import MambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -39,13 +37,6 @@ from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
KVCache = Tuple[torch.Tensor, torch.Tensor]
@dataclass
class MambaCacheParams:
is_prompt: bool = False
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class MambaMixer(nn.Module):
"""
@@ -209,37 +200,6 @@ class MambaMixer(nn.Module):
return contextualized_states
class MambaMLP(nn.Module):
def __init__(
self,
config: MambaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
hidden_act = config.hidden_act
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class MambaDecoderLayer(nn.Module):
def __init__(self,
@@ -252,7 +212,6 @@ class MambaDecoderLayer(nn.Module):
self.config = config
self.mixer = MambaMixer(config, layer_idx)
self.feed_forward = MambaMLP(config, quant_config=quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
@@ -274,10 +233,6 @@ class MambaDecoderLayer(nn.Module):
hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
ssm_state)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
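With the "Fully Connected" block gone, a decoder layer in this hunk reduces to norm, mixer, and a residual handed back to the caller. The snippet below is a minimal, self-contained sketch of that shape under stated assumptions; `TinyMixer` and the local `RMSNorm` are stand-ins, not vLLM's `MambaMixer` or fused `RMSNorm`, and the real mixer also takes `attn_metadata`, `conv_state`, and `ssm_state`.

```python
# Minimal sketch of the post-cleanup layer shape: pre-norm, mixer, residual.
# Illustration under assumptions, not the vLLM code.
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)


class TinyMixer(nn.Module):
    """Stand-in for MambaMixer: any token-mixing module works here."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class TinyMambaLayer(nn.Module):
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.mixer = TinyMixer(hidden_size)
        self.norm = RMSNorm(hidden_size)

    def forward(self, hidden_states, residual=None):
        # No feed-forward block anymore: just norm, mix, and hand the
        # residual back to the caller.
        if residual is None:
            residual = hidden_states
        hidden_states = self.mixer(self.norm(hidden_states))
        return hidden_states, residual


if __name__ == "__main__":
    x = torch.randn(2, 4, 8)
    out, res = TinyMambaLayer(8)(x)
    print(out.shape, res.shape)  # torch.Size([2, 4, 8]) for both
```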
@@ -319,7 +274,6 @@ class MambaModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
@@ -346,26 +300,6 @@ class MambaModel(nn.Module):
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embeddings": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
@@ -416,8 +350,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
mamba_cache_tensors = self.mamba_cache.current_run_tensors(
input_ids, attn_metadata, **kwargs)
hidden_states = self.backbone(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_tensors[0],
hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_tensors[0],
mamba_cache_tensors[1])
return hidden_states
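The rewritten call above relies on the Mamba cache path: `current_run_tensors` hands back a pair of tensors that the backbone consumes as `conv_state` and `ssm_state`, and no `kv_caches` list is passed anymore. Below is a self-contained sketch of that pattern only; the class name, method shape, and tensor shapes are illustrative assumptions, not the vLLM `MambaCacheManager` API.

```python
# Illustrative sketch only: a tiny cache object that owns preallocated Mamba
# state and hands back (conv_state, ssm_state) for the current step.
from typing import Tuple

import torch


class TinyMambaCache:
    def __init__(self, max_batch: int, intermediate_size: int,
                 conv_width: int, ssm_state_size: int) -> None:
        # Assumed shapes for the example, not vLLM's actual layout.
        self.conv_state = torch.zeros(max_batch, intermediate_size, conv_width)
        self.ssm_state = torch.zeros(max_batch, intermediate_size,
                                     ssm_state_size)

    def current_run_tensors(self) -> Tuple[torch.Tensor, torch.Tensor]:
        # The real manager also maps request ids to rows of these buffers;
        # here we simply return the two state tensors.
        return self.conv_state, self.ssm_state


cache = TinyMambaCache(max_batch=8, intermediate_size=32, conv_width=4,
                       ssm_state_size=16)
conv_state, ssm_state = cache.current_run_tensors()
print(conv_state.shape, ssm_state.shape)
```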
@@ -457,43 +391,16 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "A_log" in name:
name = name.replace("A_log", "A")
if ".self_attn." in name:
name = name.replace(".self_attn", "")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
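The simplified loader in this hunk drops the attention-style stacked-parameter mapping and the rotary/`.self_attn` handling, leaving a plain name-to-parameter dispatch plus the `A_log` to `A` rename. The sketch below shows that pattern in standalone form; `default_weight_loader` is stubbed locally so the example runs by itself, and it is not claimed to match the final file line for line.

```python
# Standalone sketch of the simplified weight-loading pattern in this hunk.
from typing import Iterable, Tuple

import torch
import torch.nn as nn


def default_weight_loader(param: nn.Parameter,
                          loaded_weight: torch.Tensor) -> None:
    # Plain copy; vLLM parameters may instead attach a custom weight_loader
    # (e.g. for tensor-parallel sharding), which takes precedence below.
    param.data.copy_(loaded_weight)


def load_weights(model: nn.Module,
                 weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        # Mamba checkpoints store A_log; the module parameter is named "A".
        if "A_log" in name:
            name = name.replace("A_log", "A")
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)


if __name__ == "__main__":
    model = nn.Linear(4, 4, bias=False)
    load_weights(model, [("weight", torch.eye(4))])
    print(torch.equal(model.weight.data, torch.eye(4)))  # True
```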