From 7661e92ef85e552936195ae4b803e292b9a96776 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Fri, 6 Jun 2025 18:05:14 +0800
Subject: [PATCH] [Model] Optimize nemotron_h implementation (#19249)

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/models/nemotron_h.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index 2ef8d31150d5e..3424efa80d48f 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 # Adapted from https://github.com/vllm-project/vllm/blob/94d8ec8d2bcb4ec55e33022b313c7e978edf05e1/vllm/model_executor/models/bamba.py
 # Copyright 2024 HuggingFace Inc. team. All rights reserved.
@@ -29,7 +30,7 @@ from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import ReLUSquaredActivation
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -63,19 +64,22 @@ class NemotronHMLP(nn.Module):
         config: NemotronHConfig,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
-        self.up_proj = MergedColumnParallelLinear(
+        self.up_proj = ColumnParallelLinear(
             input_size=config.hidden_size,
-            output_sizes=[config.intermediate_size],
+            output_size=config.intermediate_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.up_proj",
         )
         self.down_proj = RowParallelLinear(
             input_size=config.intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         self.act_fn = ReLUSquaredActivation()
 
@@ -99,9 +103,12 @@ class NemotronHMLPDecoderLayer(nn.Module):
         super().__init__()
         self.config = config
 
-        self.mixer = NemotronHMLP(config,
-                                  quant_config=quant_config,
-                                  bias=config.mlp_bias)
+        self.mixer = NemotronHMLP(
+            config,
+            quant_config=quant_config,
+            bias=config.mlp_bias,
+            prefix=f"{prefix}.mixer",
+        )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -207,12 +214,14 @@ class NemotronHAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
         self.attn = Attention(
@@ -253,7 +262,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
             layer_idx,
             cache_config,
             quant_config,
-            prefix,
+            prefix=f"{prefix}.mixer",
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -435,7 +444,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "k_proj",
             "v_proj",
         ],
-        "gate_up_proj": ["up_proj", "down_proj"]
     }
 
     # LoRA specific attributes
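
Note (illustrative, not part of the patch): the hunks above show that NemotronHMLP has a single up-projection and no gate projection, so a plain ColumnParallelLinear (one output_size) replaces MergedColumnParallelLinear (a list of output_sizes), and the "gate_up_proj" packed-module mapping is dropped; the added prefix arguments only thread the checkpoint weight names through to the linear layers. The sketch below mirrors that computation with plain torch.nn.Linear layers. ToyNemotronHMLP and the sizes are made up, and the up_proj -> ReLU^2 -> down_proj ordering is assumed from the usual structure of this MLP; the real layers are vLLM's tensor-parallel ColumnParallelLinear / RowParallelLinear.

# Minimal sketch, assuming only the structure visible in the hunks above:
# one up-projection, ReLUSquaredActivation (relu(x)**2), one down-projection.
import torch
import torch.nn as nn


class ToyNemotronHMLP(nn.Module):
    """Hypothetical, tensor-parallel-free stand-in for NemotronHMLP."""

    def __init__(self, hidden_size: int = 8, intermediate_size: int = 16):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = torch.relu(self.up_proj(x)) ** 2  # ReLUSquaredActivation
        return self.down_proj(h)


print(ToyNemotronHMLP()(torch.randn(2, 8)).shape)  # torch.Size([2, 8])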