# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/zai-org/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
import json
from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from torch.nn import LayerNorm

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig

from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class GLMAttention(nn.Module):
    """ChatGLM self-attention.

    Consists of a fused QKV projection, rotary position embedding applied to
    the query and key, the attention op itself, and a dense output
    projection. Head counts are divided by the tensor-parallel world size.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_world_size = get_tensor_model_parallel_world_size()

        # Query heads are always split evenly across TP ranks.
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        # With multi-query attention the KV head count comes from
        # `multi_query_group_num`; otherwise it equals the query head count.
        self.multi_query_attention = config.multi_query_attention
        if config.multi_query_attention:
            self.total_num_kv_heads = config.multi_query_group_num
        else:
            self.total_num_kv_heads = config.num_attention_heads

        if self.total_num_kv_heads >= tp_world_size:
            # At least one KV head per rank: the KV heads are partitioned
            # across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Fewer KV heads than ranks: the KV heads are replicated
            # across the tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)

        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        # The QKV projection carries a bias if either flag requests one.
        qkv_bias = config.add_bias_linear or config.add_qkv_bias
        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        # https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        # NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False,
        # which is equivalent to is_neox_style=True
        self.rotary_emb = get_rope(
            self.head_dim,
            # Only half of each head's dimension is rotated.
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=not config.original_rope,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over `hidden_states` and return the projected output."""
        fused_qkv, _ = self.query_key_value(hidden_states)
        # Slice the fused projection along the last dim into Q, K and V.
        query, key, value = fused_qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        attn_context = self.attn(query, key, value)
        projected, _ = self.dense(attn_context)
        return projected


class GLMMLP(nn.Module):
    """Feed-forward sub-layer of a ChatGLM block.

    The input is projected by a merged column-parallel linear layer into two
    `ffn_hidden_size`-wide halves, passed through `SiluAndMul`, and projected
    back to `hidden_size` by a row-parallel linear layer.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.add_bias = config.add_bias_linear

        # Up projection: two merged outputs of width ffn_hidden_size each.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_h_to_4h",
        )

        self.activation_func = SiluAndMul()

        # Down projection back to the model's hidden size.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h",
        )

    def forward(self, hidden_states):
        """Apply up projection, activation and down projection in sequence."""
        up_projected, _ = self.dense_h_to_4h(hidden_states)
        activated = self.activation_func(up_projected)
        down_projected, _ = self.dense_4h_to_h(activated)
        return down_projected


class GLMBlock(nn.Module):
    """One ChatGLM transformer layer.

    Applies input norm, self-attention, a residual add, post-attention norm,
    the MLP and a second residual add. The output has the same size as the
    input.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        self.fp32_residual_connection = config.fp32_residual_connection

        # Both norms in the block share the class selected by the config.
        norm_cls = RMSNorm if config.rmsnorm else LayerNorm

        # Norm applied to the block input.
        self.input_layernorm = norm_cls(config.hidden_size,
                                        eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config,
                                           cache_config,
                                           quant_config,
                                           prefix=f"{prefix}.self_attention")
        self.hidden_dropout = config.hidden_dropout

        # Norm applied to the attention output (after the residual add).
        self.post_attention_layernorm = norm_cls(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run the layer on `hidden_states` ([num_tokens, h])."""
        # When set, each residual branch starts from the normalized tensor
        # instead of the tensor that was fed into the norm.
        residual_post_norm = self.apply_residual_connection_post_layernorm

        # Norm at the beginning of the transformer layer.
        normed_input = self.input_layernorm(hidden_states)

        # Self attention.
        attn_out = self.self_attention(
            hidden_states=normed_input,
            position_ids=position_ids,
        )

        # First residual connection.
        skip = normed_input if residual_post_norm else hidden_states
        attn_sum = skip + attn_out

        # Norm after the self attention.
        normed_attn = self.post_attention_layernorm(attn_sum)

        # Second residual connection.
        skip = normed_attn if residual_post_norm else attn_sum
        return self.mlp(normed_attn) + skip


class GLMTransformer(nn.Module):
    """Stack of `GLMBlock` layers with an optional final norm."""

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        def _build_block(prefix: str) -> GLMBlock:
            # Invoked by make_layers with each layer's own prefix.
            return GLMBlock(config, cache_config, quant_config, prefix=prefix)

        # Transformer layers.
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.num_layers,
            _build_block,
            prefix=f"{prefix}.layers",
        )

        if self.post_layer_norm:
            norm_cls = RMSNorm if config.rmsnorm else LayerNorm
            # Final norm before output.
            self.final_layernorm = norm_cls(config.hidden_size,
                                            eps=config.layernorm_epsilon)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the layers in [start_layer, end_layer).

        Returns `IntermediateTensors` holding "hidden_states" when this is not
        the last pipeline-parallel rank; otherwise returns the hidden states,
        after the final norm if `post_layer_norm` is set.
        """
        for block in self.layers[self.start_layer:self.end_layer]:
            hidden_states = block(hidden_states=hidden_states,
                                  position_ids=position_ids)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


@support_torch_compile
class ChatGLMModel(nn.Module, SupportsQuant):
    """ChatGLM backbone: token embedding, transformer encoder, output layer."""

    packed_modules_mapping = {
        "linear_proj.merged_proj":
        ["linear_proj.gate_proj", "linear_proj.dense_h_to_4h"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embedding = VocabParallelEmbedding(
            config.padded_vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embedding",
        )

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config,
                                      cache_config,
                                      quant_config,
                                      prefix=f"{prefix}.encoder")

        self.output_layer = ParallelLMHead(
            config.padded_vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.output_layer",
        )

        self.make_empty_intermediate_tensors = (
            self.encoder.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for `input_ids`."""
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Select the input hidden states and run the encoder.

        On the first pipeline-parallel rank the input is `inputs_embeds` when
        given, else the embedding of `input_ids`. On any other rank it is the
        "hidden_states" entry of `intermediate_tensors`, which must be given.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)

        # Run encoder.
        return self.encoder(
            hidden_states=hidden_states,
            position_ids=positions,
        )

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load `(name, tensor)` pairs into this module's parameters.

        Names containing one of the stacked shard names are rewritten to the
        merged parameter name and loaded with the matching shard id. All
        other names are loaded directly, except for those containing
        "rotary_pos_emb.inv_freq", ".bias" entries without a parameter, and
        parameters reported missing by `is_pp_missing_parameter`.

        Returns the set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
            ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for merged_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, merged_name)
                # Skip loading extra bias for GPTQ models, and parameters
                # that are missing under pipeline parallelism.
                if ((name.endswith(".bias") and name not in params_dict)
                        or is_pp_missing_parameter(name, self)):
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping consumed this weight.
                if ("rotary_pos_emb.inv_freq" in name
                        or (name.endswith(".bias")
                            and name not in params_dict)
                        or is_pp_missing_parameter(name, self)):
                    continue
                param = params_dict[name]
                load_fn = getattr(param, "weight_loader",
                                  default_weight_loader)
                load_fn(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class ChatGLMBaseModel(nn.Module):
    """Shared wrapper that adds the LM head and logits computation."""

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".word_embeddings": ""}, )

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[ChatGLMModel] = ChatGLMModel,
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        self.max_position_embeddings = getattr(config, "max_sequence_length",
                                               8192)
        self.transformer = transformer_type(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "transformer"),
        )

        # With tied embeddings the output layer reuses the embedding weight.
        if self.config.tie_word_embeddings:
            self.transformer.output_layer.weight = (
                self.transformer.embedding.weight)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Pass `hidden_states` and the LM head to the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load weights through `AutoWeightsLoader` using `hf_to_vllm_mapper`.

        Returns whatever `AutoWeightsLoader.load_weights` returns.
        """
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper)


class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                         SupportsQuant):
    """Text-only ChatGLM causal language model."""

    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model.

        Raises:
            RuntimeError: if the HF config has a `vision_config` attribute.
        """
        hf_config = vllm_config.model_config.hf_config
        if hasattr(hf_config, "vision_config"):
            hf_overrides = {"architectures": ["GLM4VForCausalLM"]}
            raise RuntimeError(
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides '{json.dumps(hf_overrides)}'`")

        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the wrapped transformer and return its output."""
        return self.transformer(input_ids, positions, intermediate_tensors,
                                inputs_embeds)