From de98252f497b8cde5b9f18a8dac53302f5c72db7 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 5 Aug 2025 23:26:00 -0700 Subject: [PATCH] Add GPT-OSS model code and config [1/N] (#22327) Signed-off-by: Woosuk Kwon --- tests/models/registry.py | 1 + vllm/model_executor/models/config.py | 29 ++ vllm/model_executor/models/gpt_oss.py | 472 +++++++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 4 files changed, 503 insertions(+) create mode 100644 vllm/model_executor/models/gpt_oss.py diff --git a/tests/models/registry.py b/tests/models/registry.py index 92a719d7a92d..69961d738518 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -197,6 +197,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { {"6b": "EleutherAI/gpt-j-6b"}), "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m", {"1b": "EleutherAI/pythia-1.4b"}), + "GptOssForCausalLM": _HfExamplesInfo("openai/gpt-oss-20b"), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501 diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 6f09be7a5941..908d4e628bd6 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -247,6 +247,34 @@ class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig): config.max_model_len) +class GptOssConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + decoding_config = vllm_config.decoding_config + if decoding_config.reasoning_backend == "": + decoding_config.reasoning_backend = "openai" + + # Increase the max capture size from 512 to 1024 for performance. + # NOTE(woosuk): This will increase the number of CUDA graphs + # from 67 to 83. + scheduler_config = vllm_config.scheduler_config + if len(scheduler_config.cuda_graph_sizes) == 1: + max_capture_size = scheduler_config.cuda_graph_sizes[0] + # FIXME(woosuk): When using full cuda graph with FA3, the max + # supported size is 992. + if max_capture_size < 1024: + cuda_graph_sizes = [1, 2, 4] + # Step size 8 for small batch sizes + cuda_graph_sizes += [i for i in range(8, 256, 8)] + # Step size 16 for larger batch sizes + cuda_graph_sizes += [i for i in range(256, 1025, 16)] + scheduler_config.cuda_graph_sizes = cuda_graph_sizes + logger.info( + "Overriding max cuda graph capture size to " + "%d for performance.", 1024) + + class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): @classmethod @@ -345,4 +373,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "JinaVLForRanking": JinaVLForSequenceClassificationConfig, "JambaForSequenceClassification": JambaForSequenceClassificationConfig, "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig, + "GptOssForCausalLM": GptOssConfig, } diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py new file mode 100644 index 000000000000..896560fa24ca --- /dev/null +++ b/vllm/model_executor/models/gpt_oss.py @@ -0,0 +1,472 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.distributed as dist +from torch import nn +from transformers import GptOssConfig + +from vllm import envs +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import cdiv + +from .utils import extract_layer_index, maybe_prefix + + +class OAIAttention(nn.Module): + + def __init__( + self, + config: GptOssConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ): + super().__init__() + self.layer_idx = extract_layer_index(prefix) + self.head_dim = config.head_dim + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.hidden_size = config.hidden_size + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + dtype=torch.float32, + rope_scaling={ + "rope_type": + "yarn", + "factor": + config.rope_scaling["factor"], + "original_max_position_embeddings": + config.rope_scaling["original_max_position_embeddings"], + "beta_fast": + config.rope_ntk_beta, + "beta_slow": + config.rope_ntk_alpha, + }, + is_neox_style=True, + ) + + tp_size = get_tensor_model_parallel_world_size() + + attention_sink_dtype = ( + torch.float32 if envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION + or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION else torch.bfloat16) + self.sinks = torch.nn.Parameter( + torch.empty(config.num_attention_heads // tp_size, + dtype=attention_sink_dtype, + requires_grad=False)) + + self.norm = RMSNorm(config.hidden_size, eps=1e-5) + + self.q_size = self.num_attention_heads * self.head_dim // tp_size + self.kv_size = self.num_key_value_heads * self.head_dim // tp_size + self.scaling = self.head_dim**-0.5 + self.rope_theta = config.rope_theta + + self.qkv = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.num_attention_heads, + total_num_kv_heads=self.num_key_value_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.num_attention_heads * self.head_dim, + output_size=self.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.num_local_attention_heads = config.num_attention_heads // tp_size + self.num_local_key_value_heads = config.num_key_value_heads // tp_size + + # Only apply sliding window to every other layer + sliding_window = (config.sliding_window if self.layer_idx % + 2 == 0 else None) + self.attn = Attention( + self.num_local_attention_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_local_key_value_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=AttentionType.DECODER, + prefix=f"{prefix}.attn", + sinks=self.sinks, + ) + + def forward(self, hidden_states: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: + t = self.norm(hidden_states) + + qkv, _ = self.qkv(t) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + v = v.contiguous() + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + + return output + hidden_states + + +class MLPBlock(torch.nn.Module): + + def __init__( + self, + config: GptOssConfig, + layer_idx: int, + quant_config: QuantizationConfig, + prefix: str = "", + ): + super().__init__() + self.layer_idx = layer_idx + self.num_experts = config.num_local_experts + self.experts_per_token = config.num_experts_per_tok + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.norm = RMSNorm(config.hidden_size, eps=1e-5) + self.router = torch.nn.Linear(config.hidden_size, + config.num_local_experts, + dtype=torch.bfloat16) + assert config.intermediate_size % self.world_size == 0 + self.experts = FusedMoE(num_experts=config.num_local_experts, + top_k=config.num_experts_per_token, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + prefix=f"{prefix}.experts", + apply_router_weight_on_input=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + t = self.norm(x) + g = self.router(t) + t = self.experts(hidden_states=t, router_logits=g) + return x + t + + +class TransformerBlock(torch.nn.Module): + + def __init__( + self, + config: GptOssConfig, + quant_config: QuantizationConfig, + prefix: str = "", + ): + super().__init__() + self.layer_idx = extract_layer_index(prefix) + self.attn = OAIAttention(config, prefix=f"{prefix}.attn") + self.mlp = MLPBlock(config, + self.layer_idx, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward(self, hidden_states: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: + attn_output = self.attn(hidden_states, positions) + output = self.mlp(attn_output) + return output + + +@support_torch_compile +class GptOssModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config + self.config.hidden_size = self.config.hidden_size + self.embedding = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + ) + self.layers = torch.nn.ModuleList([ + TransformerBlock( + self.config, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, f"block.{layer_idx}"), + ) for layer_idx in range(self.config.num_hidden_layers) + ]) + self.norm = RMSNorm(self.config.hidden_size, eps=1e-5) + + def forward(self, input_ids: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: + x = self.embedding(input_ids) + for layer in self.layers: + x = layer(x, positions) + x = self.norm(x) + return x + + +class GptOssForCausalLM(nn.Module): + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config.hf_config + self.model = GptOssModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + self.lm_head = ParallelLMHead( + self.model_config.vocab_size, + self.model_config.hidden_size, + ) + self.logits_processor = LogitsProcessor(self.model_config.vocab_size) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: + assert intermediate_tensors is None + assert inputs_embeds is None + return self.model(input_ids, positions) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + rename_mapping = { + "self_attn": "attn", + "input_layernorm.weight": "attn.norm.weight", + "post_attention_layernorm.weight": "mlp.norm.weight", + "embed_tokens": "embedding", + } + + def maybe_rename(name: str) -> str: + for remap_name, new_name in rename_mapping.items(): + if remap_name in name: + return name.replace(remap_name, new_name) + return name + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + mxfp4_block = 32 + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + intermediate_size = self.model_config.intermediate_size + intermediate_size_block = intermediate_size // mxfp4_block + per_rank_intermediate_size_block = cdiv(intermediate_size_block, + tp_size) + per_rank_intermediate_size = (per_rank_intermediate_size_block * + mxfp4_block) + + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, + intermediate_size) + + # Attention heads per rank + heads_per_rank = self.model_config.num_attention_heads // tp_size + head_start = tp_rank * heads_per_rank + + use_ep = self.vllm_config.parallel_config.enable_expert_parallel + ep_size = get_ep_group().world_size + ep_rank = get_ep_group().rank + num_experts = self.model_config.num_local_experts + experts_per_rank = num_experts // ep_size + ep_rank_start = ep_rank * experts_per_rank + ep_rank_end = (ep_rank + 1) * experts_per_rank + + for name, weight in weights: + # FIXME(woosuk): Remove this after testing. + weight = weight.cuda() + + if "gate_up_proj_blocks" in name: + # Handle MLP gate and up projection weights + new_name = name.replace("gate_up_proj_blocks", "w13_weight") + + # flat weight from (E, 2 * N, block_size, entry_per_block) + # to (E, 2 * N, -1), shouldn't trigger copy for contiguous + weight = weight.view(num_experts, 2 * intermediate_size, + -1).contiguous() + + # Extract gate and up projection parts + # since the weight is shuffled, we can slice directly + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "down_proj_blocks" in name: + # Handle MLP down projection weights + new_name = name.replace("down_proj_blocks", "w2_weight") + # same flatten here, but since 2 mx4 value are packed in 1 + # uint8, divide by 2 + weight = weight.view(num_experts, -1, + intermediate_size // 2).contiguous() + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[..., + tp_rank_start // 2:tp_rank_end // 2] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "gate_up_proj_scales" in name: + # Handle MLP gate and up projection weights scale + new_name = name.replace("gate_up_proj_scales", + "w13_weight_scale") + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "down_proj_scales" in name: + # Handle MLP down projection weights + new_name = name.replace("down_proj_scales", "w2_weight_scale") + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[..., tp_rank_start // + mxfp4_block:tp_rank_end // + mxfp4_block] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + elif "gate_up_proj_bias" in name: + # Handle MLP gate and up projection biases + new_name = name.replace("gate_up_proj_bias", "w13_bias") + + # Extract gate and up projection bias parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "down_proj_bias" in name: + # Handle MLP down projection bias + new_name = name.replace("down_proj_bias", "w2_bias") + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if use_ep: + weight = weight[ep_rank_start:ep_rank_end, ...] + else: + # (only load on rank 0 to avoid duplication) + if tp_rank != 0: + weight.zero_() + weight_loader(param, + weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + elif "sinks" in name: + # Handle attention sinks (distributed across ranks) + name = name.replace("self_attn", "attn") + param = params_dict[name] + narrow_weight = weight.narrow(0, head_start, heads_per_rank) + param.data.copy_(narrow_weight) + loaded_params.add(name) + elif "q_proj" in name or "k_proj" in name or "v_proj" in name: + shard_id = ("q" if "q_proj" in name else + "k" if "k_proj" in name else "v") + name = name.replace("self_attn", "attn") + param_name = name.replace(f"{shard_id}_proj", "qkv") + param = params_dict[param_name] + weight_loader = param.weight_loader + weight_loader(param, weight, loaded_shard_id=shard_id) + loaded_params.add(param_name) + else: + # Handle all other weights with potential renaming + renamed_name = maybe_rename(name) + if renamed_name not in params_dict: + continue + param = params_dict[renamed_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + loaded_params.add(renamed_name) + + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 9b6ab52d8680..c746e8ec3f29 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -74,6 +74,7 @@ _TEXT_GENERATION_MODELS = { "GlmForCausalLM": ("glm", "GlmForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"), + "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),