# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable

import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
    get_dp_group,
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv

from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)


class OAIAttention(nn.Module):
    def __init__(
        self,
        config: GptOssConfig,
        quant_config: QuantizationConfig | None = None,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            dtype=torch.float32,
            rope_scaling={
                "rope_type": "yarn",
                "factor": config.rope_scaling["factor"],
                "original_max_position_embeddings": config.rope_scaling[
                    "original_max_position_embeddings"
                ],
                "beta_fast": config.rope_scaling["beta_fast"],
                "beta_slow": config.rope_scaling["beta_slow"],
            },
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size, requires_grad=False)
        )

        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta

        self.qkv = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer
        sliding_window = config.sliding_window if self.layer_idx % 2 == 0 else None
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(
        self, hidden_states: torch.Tensor, positions: torch.Tensor
    ) -> torch.Tensor:
        qkv, _ = self.qkv(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class MLPBlock(torch.nn.Module):
    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_idx: int,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.experts_per_token = config.num_experts_per_tok
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
        assert config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            apply_router_weight_on_input=False,
            has_bias=True,
            activation="swigluoai",
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        num_tokens = x.shape[0]
        if self.is_sequence_parallel:
            x = sequence_parallel_chunk(x)

        if current_platform.is_rocm():
            g = rocm_unquantized_gemm(
                self, x[:, : self.hidden_size], self.router.weight, self.router.bias
            )
        else:
            g = self.router(x)
        x = self.experts(hidden_states=x, router_logits=g)

        if self.is_sequence_parallel:
            x = tensor_model_parallel_all_gather(x.contiguous(), 0)
            x = x[:num_tokens]
        return x


class TransformerBlock(torch.nn.Module):
    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        self.layer_idx = extract_layer_index(prefix)
        self.attn = OAIAttention(
            config, prefix=f"{prefix}.attn", cache_config=cache_config
        )
        self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> torch.Tensor:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.attn(hidden_states, positions)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual
        )
        output = self.mlp(hidden_states)
        return output, residual


@support_torch_compile
class GptOssModel(nn.Module):
    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.parallel_config = vllm_config.parallel_config
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: TransformerBlock(
                vllm_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], self.config.hidden_size
        )
        self.aux_hidden_state_layers = tuple[int, ...]()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                x = inputs_embeds
            else:
                x = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            x = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            if i in self.aux_hidden_state_layers:
                aux_hidden_states.append(x if residual is None else x + residual)
            x, residual = layer(x, positions, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": x, "residual": residual})

        x, _ = self.norm(x, residual)

        if len(aux_hidden_states) > 0:
            return x, aux_hidden_states
        return x

    def _load_weights_mxfp4(
        self,
        ep_rank_start: int,
        ep_rank_end: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        mxfp4_block = 32

        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts

        # In MoE, we need to flatten the tensor parallel size across the data
        # parallel size when EP is disabled.
        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
            tp_size=get_tensor_model_parallel_world_size(),
            dp_size=get_dp_group().world_size,
            dp_rank=get_dp_group().rank_in_group,
        )

        intermediate_size = self.config.intermediate_size
        intermediate_size_block = intermediate_size // mxfp4_block
        per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
        per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min(
            (tp_rank + 1) * per_rank_intermediate_size, intermediate_size
        )
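
        # Worked example of the bounds above, with illustrative numbers (not
        # taken from any particular checkpoint): if intermediate_size = 2880,
        # tp_size = 4 and mxfp4_block = 32, then intermediate_size_block = 90,
        # per_rank_intermediate_size_block = cdiv(90, 4) = 23, and
        # per_rank_intermediate_size = 23 * 32 = 736. Rank 0 therefore loads
        # rows [0, 736), while the last rank is clamped to [2208, 2880).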

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            if ".w13_weight_scale" in name:
                # Handle MLP gate and up projection weights scale
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=name,
                    shard_id=None,
                    expert_id=None,
                )
                loaded_params.add(name)
                continue
            elif ".w2_weight_scale" in name:
                # Handle MLP down projection weights scale
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[
                        ..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block
                    ]
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=name,
                    shard_id=None,
                    expert_id=None,
                )
                loaded_params.add(name)
                continue
            elif ".w13_weight" in name:
                # Handle MLP gate and up projection weights
                # Flatten the weight from (E, 2 * N, block_size, entries_per_block)
                # to (E, 2 * N, -1); this shouldn't trigger a copy for a
                # contiguous tensor.
                weight = weight.view(
                    num_experts, 2 * intermediate_size, -1
                ).contiguous()

                # Extract gate and up projection parts
                # since the weight is shuffled, we can slice directly
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=name,
                    shard_id=None,
                    expert_id=None,
                )
                loaded_params.add(name)
                continue
            elif ".w2_weight" in name:
                # Handle MLP down projection weights
                # Same flattening here, but since two MXFP4 values are packed
                # into one uint8, divide the last dimension by 2.
                weight = weight.view(
                    num_experts, -1, intermediate_size // 2
                ).contiguous()
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2]

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=name,
                    shard_id=None,
                    expert_id=None,
                )
                loaded_params.add(name)
                continue
            elif ".w13_bias" in name:
                # Handle MLP gate and up projection biases
                # Extract gate and up projection bias parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=name,
                    shard_id=None,
                    expert_id=None,
                )
                loaded_params.add(name)
                continue
            elif ".w2_bias" in name:
                # Handle MLP down projection bias
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # (only load on rank 0 to avoid duplication)
                    if tp_rank != 0:
                        weight.zero_()
                weight_loader(
                    param, weight, weight_name=name, shard_id=None, expert_id=None
                )
                loaded_params.add(name)
                continue
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
            else:
                # Handle all other weights with potential renaming
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
            loaded_params.add(name)

        return loaded_params

    def _load_weights_other(
        self,
        ep_rank_start: int,
        ep_rank_end: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel

        # In MoE, we need to flatten the tensor parallel size across the data
        # parallel size when EP is disabled.
        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
            tp_size=get_tensor_model_parallel_world_size(),
            dp_size=get_dp_group().world_size,
            dp_rank=get_dp_group().rank_in_group,
        )

        intermediate_size = self.config.intermediate_size
        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min(
            (tp_rank + 1) * per_rank_intermediate_size, intermediate_size
        )

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            if ".w13_weight" in name:
                # Handle MLP gate and up projection weights
                # Extract gate and up projection parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                param = params_dict[name]
                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w2_weight" in name:
                # Handle MLP down projection weights
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                param = params_dict[name]
                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w13_bias" in name:
                # Handle MLP gate and up projection biases
                # Extract gate and up projection bias parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
                param = params_dict[name]
                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w2_bias" in name:
                # Handle MLP down projection bias
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # (only load on rank 0 to avoid duplication)
                    if tp_rank != 0:
                        weight.zero_()
                param = params_dict[name]
                param.copy_(weight)
                loaded_params.add(name)
                continue
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
            else:
                # Handle all other weights with potential renaming
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
            loaded_params.add(name)

        return loaded_params

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv", ".q_proj", "q"),
            (".qkv", ".k_proj", "k"),
            (".qkv", ".v_proj", "v"),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Attention heads per rank
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        ep_size = get_ep_group().world_size
        ep_rank = get_ep_group().rank
        num_experts = self.config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        quant_method = (
            self.config.quantization_config["quant_method"]
            if hasattr(self.config, "quantization_config")
            else None
        )
        if quant_method == "mxfp4":
            return self._load_weights_mxfp4(
                ep_rank_start,
                ep_rank_end,
                heads_per_rank,
                head_start,
                weights,
                stacked_params_mapping,
            )
        else:
            return self._load_weights_other(
                ep_rank_start,
                ep_rank_end,
                heads_per_rank,
                head_start,
                weights,
                stacked_params_mapping,
            )


class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
    packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",
            # MoE MXFP4 weights
            ".gate_up_proj_blocks": ".w13_weight",
            ".down_proj_blocks": ".w2_weight",
            ".gate_up_proj_scales": ".w13_weight_scale",
            ".down_proj_scales": ".w2_weight_scale",
            # MoE other weights
            ".gate_up_proj": ".w13_weight",
            ".down_proj": ".w2_weight",
            # MoE Bias
            ".gate_up_proj_bias": ".w13_bias",
            ".down_proj_bias": ".w2_bias",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.vllm_config = vllm_config
        self.config = vllm_config.model_config.hf_config
        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(self.config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, weight scales, activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_local_experts,
            num_redundant_experts=0,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
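

# Minimal usage sketch (illustrative only, not executed on import): how the
# model defined above is typically served through vLLM's public API. The
# checkpoint id and sampling values below are assumptions chosen for
# illustration; any gpt-oss checkpoint compatible with GptOssConfig works the
# same way.
if __name__ == "__main__":
    from vllm import LLM, SamplingParams

    llm = LLM(model="openai/gpt-oss-20b")
    outputs = llm.generate(
        ["Explain attention sinks in one sentence."],
        SamplingParams(temperature=0.0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)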