mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-18 03:07:03 +08:00
[Model] Optimize nemotron_h implementation (#19249)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
f168b85725
commit
7661e92ef8
@ -1,4 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/94d8ec8d2bcb4ec55e33022b313c7e978edf05e1/vllm/model_executor/models/bamba.py
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
@ -29,7 +30,7 @@ from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@ -63,19 +64,22 @@ class NemotronHMLP(nn.Module):
|
||||
config: NemotronHConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.up_proj = MergedColumnParallelLinear(
|
||||
self.up_proj = ColumnParallelLinear(
|
||||
input_size=config.hidden_size,
|
||||
output_sizes=[config.intermediate_size],
|
||||
output_size=config.intermediate_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
input_size=config.intermediate_size,
|
||||
output_size=config.hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
self.act_fn = ReLUSquaredActivation()
|
||||
|
||||
@ -99,9 +103,12 @@ class NemotronHMLPDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.mixer = NemotronHMLP(config,
|
||||
quant_config=quant_config,
|
||||
bias=config.mlp_bias)
|
||||
self.mixer = NemotronHMLP(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
bias=config.mlp_bias,
|
||||
prefix=f"{prefix}.mixer",
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@ -207,12 +214,14 @@ class NemotronHAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self.attn = Attention(
|
||||
@ -253,7 +262,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
|
||||
layer_idx,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix,
|
||||
prefix=f"{prefix}.mixer",
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -435,7 +444,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": ["up_proj", "down_proj"]
|
||||
}
|
||||
|
||||
# LoRA specific attributes
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user