# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2025 The Swiss AI Initiative.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate the architectural differences made by
# the Swiss AI Initiative that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Apertus model compatible with HuggingFace weights."""

from collections.abc import Iterable
from itertools import islice
from typing import Any

import torch
from torch import nn
from transformers import ApertusConfig

from vllm.attention import Attention, AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import XIELU
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)


class ApertusMLP(nn.Module):
    """Feed-forward block: ``up_proj`` -> xIELU activation -> ``down_proj``.

    There is no separate gate projection; the input goes through a single
    up projection, the activation, and the down projection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        """Build the two projections and the activation.

        Args:
            hidden_size: Input size of ``up_proj`` and output size of
                ``down_proj``.
            intermediate_size: Output size of ``up_proj`` and input size of
                ``down_proj``.
            hidden_act: Activation name; only ``"xielu"`` is accepted.
            quant_config: Quantization config forwarded to both projections.
            bias: Whether both projections carry a bias term.
            prefix: Module-name prefix used to name the sub-layers.
            reduce_results: Forwarded to ``down_proj``.

        Raises:
            ValueError: If ``hidden_act`` is not ``"xielu"``.
        """
        super().__init__()
        # Validate first so that no projection layers are constructed for a
        # configuration that is going to be rejected anyway.
        if hidden_act != "xielu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only xIELU is supported for now."
            )
        self.up_proj = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = XIELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up projection, activation and down projection to ``x``."""
        # Each projection returns a pair; only the first element is used.
        x, _ = self.up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


class ApertusAttention(nn.Module):
    """Self-attention with RMS-normalized queries/keys and rotary embeddings."""

    def __init__(
        self,
        config: ApertusConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build the fused QKV projection, output projection, rotary
        embedding, attention backend layer and the q/k norms.

        Args:
            config: HF config; read for ``head_dim``,
                ``partial_rotary_factor``, ``layer_types``,
                ``sliding_window``, ``rms_norm_eps`` and ``model_type``.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before TP split).
            num_kv_heads: Total number of key/value heads (before TP split).
            rope_theta: Base passed to the rotary embedding.
            rope_scaling: Optional rope-scaling dict passed to ``get_rope``.
            max_position_embeddings: Maximum position for the rotary embedding.
            quant_config: Quantization config forwarded to sub-layers.
            bias: Whether ``qkv_proj`` carries a bias.
            bias_o_proj: Whether ``o_proj`` carries a bias.
            cache_config: Cache config forwarded to the attention layer.
            prefix: Module-name prefix; the layer index is extracted from it.
            attn_type: ``AttentionType`` value; ``ENCODER_ONLY`` selects
                ``EncoderOnlyAttention``, anything else ``Attention``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # Use an explicit ``head_dim`` from the config when present;
        # otherwise derive it from the hidden size and the head count.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        # ``partial_rotary_factor`` is optional in the config; 1 means the
        # rotary dimension equals the full head dimension.
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(
            config, rope_scaling=rope_scaling, quant_config=quant_config
        )

        # A sliding window is only applied when the config lists per-layer
        # types and this layer is marked as "sliding_attention".
        sliding_window = None
        if layer_types := getattr(config, "layer_types", None):
            is_sliding = layer_types[layer_idx] == "sliding_attention"
            if is_sliding:
                sliding_window = config.sliding_window

        attn_cls = (
            EncoderOnlyAttention
            if attn_type == AttentionType.ENCODER_ONLY
            else Attention
        )

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` at the given ``positions``.

        The fused projection output is split into q/k/v along the last
        dimension, q and k are RMS-normalized per head, rotary embeddings are
        applied, and the attention output goes through ``o_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Normalize over ``head_dim``: flatten to rows of one head each, apply
        # the norm, then restore the original layout.
        q = self.q_norm(q.contiguous().view(-1, self.head_dim)).view_as(q)
        k = self.k_norm(k.contiguous().view(-1, self.head_dim)).view_as(k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(
        self,
        config: ApertusConfig,
        rope_scaling: dict[str, Any] | None,
        quant_config: QuantizationConfig | None,
    ) -> None:
        """Create ``self.rotary_emb`` via ``get_rope``.

        NeoX-style rotary is used unless the quantization config reports the
        name ``"gguf"`` and the model type is ``"apertus"``.
        """
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "apertus":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=int(self.partial_rotary_factor * self.head_dim),
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )


class ApertusDecoderLayer(nn.Module):
    """One decoder layer: pre-norm self-attention followed by pre-norm MLP."""

    def __init__(
        self,
        config: ApertusConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MLP and the two RMSNorm layers from ``config``.

        Args:
            config: HF config providing sizes, rope settings and bias flags.
            cache_config: Cache config forwarded to the attention module.
            quant_config: Quantization config forwarded to sub-layers.
            prefix: Module-name prefix for this layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        # NOTE: this writes into the dict held by ``config`` (no copy is made).
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # The attention bias flag may be exposed as either ``attention_bias``
        # or ``bias`` on the config.
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        bias_o_proj = attention_bias
        # A dedicated ``qkv_bias`` overrides the QKV bias only; ``o_proj``
        # keeps the value computed above.
        if hasattr(config, "qkv_bias"):
            attention_bias = config.qkv_bias

        # Apertus defaults to causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        if getattr(config, "is_causal", True):
            attn_type = AttentionType.DECODER
        else:
            attn_type = AttentionType.ENCODER_ONLY

        self.self_attn = ApertusAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = ApertusMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.feedforward_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming ``hidden_states`` become
        the residual; otherwise the norm is called with both tensors and
        returns the updated pair.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_layernorm(hidden_states)
        else:
            hidden_states, residual = self.attention_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.feedforward_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class ApertusModel(nn.Module):
    """Token embedding, decoder-layer stack and final norm.

    Components that do not belong to the current pipeline-parallel rank are
    replaced by ``PPMissingLayer`` placeholders.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = ApertusDecoderLayer,
    ):
        """Build the model from ``vllm_config``.

        Args:
            vllm_config: Supplies the HF config plus cache, quantization and
                LoRA configs.
            prefix: Module-name prefix for the layer stack.
            layer_type: Class instantiated for every decoder layer.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        # Extra vocabulary rows reserved for LoRA adapters, if LoRA is on.
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        # The embedding is also needed on the last rank when it is tied to
        # the output head.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Layer indices whose input hidden state is additionally returned
        # from ``forward``; empty by default.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first rank when
                ``inputs_embeds`` is not given.
            positions: Positions forwarded to every layer.
            intermediate_tensors: Output of the previous pipeline rank;
                required on every rank except the first.
            inputs_embeds: Precomputed embeddings that bypass the lookup.

        Returns:
            On a non-last rank, ``IntermediateTensors`` holding
            ``hidden_states`` and ``residual``. On the last rank, the normed
            hidden states, or a ``(hidden_states, aux_hidden_states)`` tuple
            when any auxiliary hidden states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        # ``idx`` counts from 0 within this rank's slice of layers.
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            if idx in self.aux_hidden_state_layers:
                # Before the first layer of the first rank there is no
                # residual yet, so the hidden state is recorded as-is.
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states + residual
                )
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module.

        Separate ``q_proj``/``k_proj``/``v_proj`` checkpoint entries are
        routed into the fused ``qkv_proj`` parameter with a shard id.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())

        # we need to load the buffers for beta and eps (XIELU)
        for name, buffer in self.named_buffers():
            if name.endswith((".beta", ".eps")):
                params_dict[name] = buffer

        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Scalars are used directly; otherwise the first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached when no stacked mapping loaded the tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Apertus decoder with a language-model head and logits processor."""

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = ApertusDecoderLayer,
    ):
        """Build the backbone and, on the last pipeline rank, the LM head.

        Args:
            vllm_config: Supplies the HF config plus quantization and LoRA
                configs.
            prefix: Module-name prefix for the sub-modules.
            layer_type: Decoder-layer class forwarded to the backbone.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.model = self._init_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            layer_type=layer_type,
        )

        # The head and the logits processor only exist on the last rank.
        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config
                    else lora_config.lora_vocab_padding_size
                ),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size, config.vocab_size, logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select the layer indices whose input hidden state is returned."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default aux layer indices: ``2``, the middle layer and
        the third from last, based on ``len(self.model.layers)``."""
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = ApertusDecoderLayer,
    ):
        """Construct the ``ApertusModel`` backbone."""
        return ApertusModel(
            vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for ``input_ids`` via the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        model_output = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits from ``hidden_states`` using the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors through ``AutoWeightsLoader``.

        ``lm_head.`` entries are skipped when the word embeddings are tied.

        Returns:
            The set of names reported as loaded by the loader.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)