mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:25:00 +08:00
[Model] Refactor JambaForCausalLM (#21394)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
37efc63b64
commit
61a6905ab0
@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama import LlamaMLP as JambaMLP
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
@ -33,7 +34,7 @@ from vllm.utils import LayerBlockType
|
||||
|
||||
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
|
||||
SupportsV0Only)
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
@ -87,23 +88,6 @@ class JambaMoE(nn.Module):
|
||||
return hidden_states.view(orig_shape)
|
||||
|
||||
|
||||
class JambaMLP(JambaMoE):
|
||||
|
||||
def __init__(self,
|
||||
config: JambaConfig,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__(config,
|
||||
num_experts=1,
|
||||
top_k=1,
|
||||
params_dtype=params_dtype,
|
||||
tp_size=tp_size,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
|
||||
class JambaMambaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
@ -132,10 +116,20 @@ class JambaMambaDecoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
num_experts = config.layers_num_experts[layer_idx]
|
||||
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
|
||||
self.feed_forward = ffn_layer_class(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.feed_forward")
|
||||
if num_experts > 1:
|
||||
self.feed_forward = JambaMoE(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
else:
|
||||
self.feed_forward = JambaMLP(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
|
||||
@ -216,10 +210,20 @@ class JambaAttentionDecoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
num_experts = config.layers_num_experts[layer_idx]
|
||||
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
|
||||
self.feed_forward = ffn_layer_class(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.feed_forward")
|
||||
if num_experts > 1:
|
||||
self.feed_forward = JambaMoE(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
else:
|
||||
self.feed_forward = JambaMLP(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
|
||||
@ -359,15 +363,97 @@ class JambaModel(nn.Module):
|
||||
hidden_states, _ = self.final_layernorm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
return FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.num_experts)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
if 'experts' in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for (
|
||||
param_name,
|
||||
weight_name,
|
||||
expert_id,
|
||||
shard_id,
|
||||
) in expert_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
IsHybrid, SupportsV0Only):
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
|
||||
".self_attn.": ".",
|
||||
".A_log": ".A"
|
||||
}, )
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"in_proj": ["in_proj"],
|
||||
}
|
||||
|
||||
@ -468,96 +554,11 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.num_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if "A_log" in name:
|
||||
name = name.replace("A_log", "A")
|
||||
|
||||
if ".self_attn." in name:
|
||||
name = name.replace(".self_attn", "")
|
||||
|
||||
if "feed_forward" in name and not _is_moe_layer(name):
|
||||
## map MLP layers to expert with ID=0
|
||||
name = name.replace("feed_forward", "feed_forward.experts.0")
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
if 'experts' in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for (
|
||||
param_name,
|
||||
weight_name,
|
||||
expert_id,
|
||||
shard_id,
|
||||
) in expert_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
def _is_moe_layer(name: str):
|
||||
return any(
|
||||
[experts_name in name for experts_name in [
|
||||
"experts",
|
||||
"router",
|
||||
]])
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return self.model.get_expert_mapping()
|
||||
|
||||
|
||||
class JambaForSequenceClassification(JambaForCausalLM):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user