[Model] Optimize nemotron_h implementation (#19249)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-06-06 18:05:14 +08:00 committed by GitHub
parent f168b85725
commit 7661e92ef8


@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # Adapted from https://github.com/vllm-project/vllm/blob/94d8ec8d2bcb4ec55e33022b313c7e978edf05e1/vllm/model_executor/models/bamba.py
 # Copyright 2024 HuggingFace Inc. team. All rights reserved.
@@ -29,7 +30,7 @@ from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import ReLUSquaredActivation
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
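For context on the import swap: a merged column-parallel layer only pays off when several projections are fused into one matmul. NemotronH's MLP has a single up projection, so a plain ColumnParallelLinear with output_size=intermediate_size is equivalent to the old MergedColumnParallelLinear with a one-element output_sizes list. A rough sketch of that equivalence using toy stand-ins (illustrative classes only, not vLLM's tensor-parallel implementations):

    import torch.nn as nn

    class ToyColumnParallelLinear(nn.Module):
        # Single projection: one weight of shape (output_size, input_size).
        def __init__(self, input_size: int, output_size: int, bias: bool = False):
            super().__init__()
            self.linear = nn.Linear(input_size, output_size, bias=bias)

        def forward(self, x):
            return self.linear(x)

    class ToyMergedColumnParallelLinear(nn.Module):
        # "Merged" variant: concatenates several projections into one weight.
        def __init__(self, input_size: int, output_sizes: list, bias: bool = False):
            super().__init__()
            self.linear = nn.Linear(input_size, sum(output_sizes), bias=bias)

        def forward(self, x):
            return self.linear(x)

    # With a single output there is nothing to merge: both toy layers end up
    # with an identical (64, 16) weight, which is why the plain variant
    # suffices for NemotronH's up_proj.
    merged = ToyMergedColumnParallelLinear(16, [64])
    plain = ToyColumnParallelLinear(16, 64)
    assert merged.linear.weight.shape == plain.linear.weight.shape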
@@ -63,19 +64,22 @@ class NemotronHMLP(nn.Module):
         config: NemotronHConfig,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
-        self.up_proj = MergedColumnParallelLinear(
+        self.up_proj = ColumnParallelLinear(
             input_size=config.hidden_size,
-            output_sizes=[config.intermediate_size],
+            output_size=config.intermediate_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.up_proj",
         )
         self.down_proj = RowParallelLinear(
             input_size=config.intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         self.act_fn = ReLUSquaredActivation()
@@ -99,9 +103,12 @@ class NemotronHMLPDecoderLayer(nn.Module):
         super().__init__()
         self.config = config
-        self.mixer = NemotronHMLP(config,
-                                  quant_config=quant_config,
-                                  bias=config.mlp_bias)
+        self.mixer = NemotronHMLP(
+            config,
+            quant_config=quant_config,
+            bias=config.mlp_bias,
+            prefix=f"{prefix}.mixer",
+        )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
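The prefix arguments threaded through these constructors give every linear layer a fully qualified dotted name (the decoder layer passes "{prefix}.mixer", the MLP appends ".up_proj" / ".down_proj"); quantization configs in vLLM typically use such names to apply per-layer settings. A minimal sketch of the naming pattern, with hypothetical module names standing in for the real model structure:

    # Sketch: dotted prefixes compose as they are passed down the constructors.
    class Linear:
        def __init__(self, prefix: str = ""):
            self.prefix = prefix  # the layer's fully qualified name

    class MLP:
        def __init__(self, prefix: str = ""):
            self.up_proj = Linear(prefix=f"{prefix}.up_proj")
            self.down_proj = Linear(prefix=f"{prefix}.down_proj")

    class DecoderLayer:
        def __init__(self, prefix: str = ""):
            self.mixer = MLP(prefix=f"{prefix}.mixer")

    layer = DecoderLayer(prefix="model.layers.0")
    print(layer.mixer.up_proj.prefix)  # -> model.layers.0.mixer.up_proj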
@@ -207,12 +214,14 @@ class NemotronHAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.attn = Attention(
@@ -253,7 +262,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
             layer_idx,
             cache_config,
             quant_config,
-            prefix,
+            prefix=f"{prefix}.mixer",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -435,7 +444,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "k_proj",
             "v_proj",
         ],
-        "gate_up_proj": ["up_proj", "down_proj"]
     }
     # LoRA specific attributes
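On the last hunk: packed_modules_mapping records which original checkpoint sub-modules are packed into one fused module (used, among other things, for LoRA). With up_proj now an ordinary ColumnParallelLinear there is no MLP fusion left to describe, and the removed "gate_up_proj" entry did not match any real fusion in this model anyway (down_proj is the output projection, not part of an up/gate pair), so only the QKV fusion remains. A rough illustration of how such a mapping can be consumed (hypothetical loader logic, not vLLM's actual code):

    # Remaining mapping after this change: only the attention QKV fusion.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    def fused_module_for(weight_name: str) -> str:
        # Route a checkpoint weight named after an original sub-module to the
        # corresponding fused module (hypothetical helper for illustration).
        for fused, originals in packed_modules_mapping.items():
            for orig in originals:
                if weight_name.endswith(orig):
                    return weight_name.replace(orig, fused)
        return weight_name

    print(fused_module_for("model.layers.0.mixer.q_proj"))
    # -> model.layers.0.mixer.qkv_proj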