From 1e44ffc3ff5be0f7bd4c4e7efa888a80d2681743 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Thu, 10 Apr 2025 09:19:42 +0800 Subject: [PATCH] Add GLM-4-0414 support (#16338) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: lvfei.lv Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com> Signed-off-by: DarkLight1337 Signed-off-by: yihong0618 Signed-off-by: Lu Fang Signed-off-by: Ajay Vohra Signed-off-by: NickLucche Signed-off-by: Guillaume Calmettes Co-authored-by: Accelerator1996 Co-authored-by: Cyrus Leung Co-authored-by: Michael Goin Co-authored-by: yihong Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: ajayvohra2005 Co-authored-by: Nicolò Lucchesi Co-authored-by: Guillaume Calmettes --- docs/source/models/supported_models.md | 5 + tests/models/registry.py | 5 + vllm/model_executor/models/glm4.py | 313 +++++++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 4 files changed, 324 insertions(+) create mode 100644 vllm/model_executor/models/glm4.py diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 94a9b039a61d..2ebec2ea968a 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -303,6 +303,11 @@ See [this page](#generative-models) for more information on how to use generativ * `THUDM/glm-4-9b-chat-hf`, etc. * ✅︎ * ✅︎ +- * `Glm4ForCausalLM` + * GLM-4-0414 + * `THUDM/GLM-4-32B-Chat-0414`, etc. + * ✅︎ + * ✅︎ - * `GPT2LMHeadModel` * GPT-2 * `gpt2`, `gpt2-xl`, etc. diff --git a/tests/models/registry.py b/tests/models/registry.py index 73b7c0fa9745..40479fb8a5b0 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -146,6 +146,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it", min_transformers_version="4.50"), "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"), + "Glm4ForCausalLM": _HfExamplesInfo( + "THUDM/GLM-4-32B-Chat-0414", + is_available_online=False, + min_transformers_version="4.52.dev0" + ), "GPT2LMHeadModel": _HfExamplesInfo("gpt2"), "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"), "GPTJForCausalLM": _HfExamplesInfo("EleutherAI/gpt-j-6b"), diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py new file mode 100644 index 000000000000..cba093cbfef7 --- /dev/null +++ b/vllm/model_executor/models/glm4.py @@ -0,0 +1,313 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 The Zhipu AI team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GLM-4-0414 model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import Glm4Config + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .llama import LlamaMLP as Glm4MLP +from .llama import LlamaModel +from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix + + +class Glm4Attention(nn.Module): + + def __init__(self, + config: Glm4Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + head_dim: Optional[int] = None, + qkv_bias: bool = False, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[Tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.rotary_dim = int(partial_rotary_factor * self.head_dim) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + partial_rotary_factor=partial_rotary_factor, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=attn_type) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Glm4DecoderLayer(nn.Module): + + def __init__( + self, + config: Glm4Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + + self.self_attn = Glm4Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=AttentionType.DECODER, + ) + self.mlp = Glm4MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_self_attn_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_mlp_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + hidden_states = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_mlp_layernorm(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": Glm4DecoderLayer, +} + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class Glm4Model(LlamaModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=Glm4DecoderLayer) + + +class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Glm4Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 0de24b578c17..6a70f6bb7236 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -58,6 +58,7 @@ _TEXT_GENERATION_MODELS = { "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), + "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),