Add GPT-OSS model code and config [1/N] (#22327)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Authored by Woosuk Kwon on 2025-08-05 23:26:00 -07:00, committed by GitHub
parent 796bae07c5
commit de98252f49
4 changed files with 503 additions and 0 deletions


@@ -197,6 +197,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
{"6b": "EleutherAI/gpt-j-6b"}),
"GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m",
{"1b": "EleutherAI/pythia-1.4b"}),
"GptOssForCausalLM": _HfExamplesInfo("openai/gpt-oss-20b"),
"GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
"GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
"GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501


@@ -247,6 +247,34 @@ class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
config.max_model_len)


class GptOssConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
decoding_config = vllm_config.decoding_config
if decoding_config.reasoning_backend == "":
decoding_config.reasoning_backend = "openai"
# Increase the max capture size from 512 to 1024 for performance.
# NOTE(woosuk): This will increase the number of CUDA graphs
# from 67 to 83.
scheduler_config = vllm_config.scheduler_config
if len(scheduler_config.cuda_graph_sizes) == 1:
max_capture_size = scheduler_config.cuda_graph_sizes[0]
# FIXME(woosuk): When using full cuda graph with FA3, the max
# supported size is 992.
if max_capture_size < 1024:
cuda_graph_sizes = [1, 2, 4]
# Step size 8 for small batch sizes
cuda_graph_sizes += [i for i in range(8, 256, 8)]
# Step size 16 for larger batch sizes
cuda_graph_sizes += [i for i in range(256, 1025, 16)]
scheduler_config.cuda_graph_sizes = cuda_graph_sizes
logger.info(
"Overriding max cuda graph capture size to "
"%d for performance.", 1024)


class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
@classmethod
@ -345,4 +373,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,
"JambaForSequenceClassification": JambaForSequenceClassificationConfig,
"GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
"GptOssForCausalLM": GptOssConfig,
}
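For context, a standalone sketch (not part of this diff) that rebuilds both capture-size schedules and confirms the graph counts cited in the NOTE above. The "default" schedule of [1, 2, 4] plus multiples of 8 up to a 512 cap is inferred from that comment, not quoted from vLLM's defaults; the override mirrors the code above exactly.

# Sketch only: reproduce the capture-size lists referenced in GptOssConfig.
default_sizes = [1, 2, 4] + list(range(8, 513, 8))   # assumed default: cap 512, step 8

override_sizes = [1, 2, 4]
override_sizes += list(range(8, 256, 8))      # step 8 for small batch sizes
override_sizes += list(range(256, 1025, 16))  # step 16 for larger batch sizes

print(len(default_sizes), len(override_sizes))  # 67 83

The override captures 16 additional CUDA graphs in exchange for covering batch sizes up to 1024.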


@@ -0,0 +1,472 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional
import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig
from vllm import envs
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
from .utils import extract_layer_index, maybe_prefix


class OAIAttention(nn.Module):
def __init__(
self,
config: GptOssConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
):
super().__init__()
self.layer_idx = extract_layer_index(prefix)
self.head_dim = config.head_dim
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.hidden_size = config.hidden_size
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
base=config.rope_theta,
dtype=torch.float32,
rope_scaling={
"rope_type":
"yarn",
"factor":
config.rope_scaling["factor"],
"original_max_position_embeddings":
config.rope_scaling["original_max_position_embeddings"],
"beta_fast":
config.rope_ntk_beta,
"beta_slow":
config.rope_ntk_alpha,
},
is_neox_style=True,
)
tp_size = get_tensor_model_parallel_world_size()
attention_sink_dtype = (
torch.float32 if envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION else torch.bfloat16)
self.sinks = torch.nn.Parameter(
torch.empty(config.num_attention_heads // tp_size,
dtype=attention_sink_dtype,
requires_grad=False))
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
self.q_size = self.num_attention_heads * self.head_dim // tp_size
self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta
self.qkv = QKVParallelLinear(
hidden_size=self.hidden_size,
head_size=self.head_dim,
total_num_heads=self.num_attention_heads,
total_num_kv_heads=self.num_key_value_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.num_attention_heads * self.head_dim,
output_size=self.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.num_local_attention_heads = config.num_attention_heads // tp_size
self.num_local_key_value_heads = config.num_key_value_heads // tp_size
# Only apply sliding window to every other layer
sliding_window = (config.sliding_window if self.layer_idx %
2 == 0 else None)
self.attn = Attention(
self.num_local_attention_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_local_key_value_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=sliding_window,
attn_type=AttentionType.DECODER,
prefix=f"{prefix}.attn",
sinks=self.sinks,
)
def forward(self, hidden_states: torch.Tensor,
positions: torch.Tensor) -> torch.Tensor:
t = self.norm(hidden_states)
qkv, _ = self.qkv(t)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
v = v.contiguous()
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output + hidden_states


class MLPBlock(torch.nn.Module):
def __init__(
self,
config: GptOssConfig,
layer_idx: int,
quant_config: QuantizationConfig,
prefix: str = "",
):
super().__init__()
self.layer_idx = layer_idx
self.num_experts = config.num_local_experts
self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
self.router = torch.nn.Linear(config.hidden_size,
config.num_local_experts,
dtype=torch.bfloat16)
assert config.intermediate_size % self.world_size == 0
self.experts = FusedMoE(num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts",
apply_router_weight_on_input=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
t = self.norm(x)
g = self.router(t)
t = self.experts(hidden_states=t, router_logits=g)
return x + t


class TransformerBlock(torch.nn.Module):
def __init__(
self,
config: GptOssConfig,
quant_config: QuantizationConfig,
prefix: str = "",
):
super().__init__()
self.layer_idx = extract_layer_index(prefix)
self.attn = OAIAttention(config, prefix=f"{prefix}.attn")
self.mlp = MLPBlock(config,
self.layer_idx,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
def forward(self, hidden_states: torch.Tensor,
positions: torch.Tensor) -> torch.Tensor:
attn_output = self.attn(hidden_states, positions)
output = self.mlp(attn_output)
return output


@support_torch_compile
class GptOssModel(nn.Module):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.quant_config = vllm_config.quant_config
self.config.hidden_size = self.config.hidden_size
self.embedding = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
)
self.layers = torch.nn.ModuleList([
TransformerBlock(
self.config,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, f"block.{layer_idx}"),
) for layer_idx in range(self.config.num_hidden_layers)
])
self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
def forward(self, input_ids: torch.Tensor,
positions: torch.Tensor) -> torch.Tensor:
x = self.embedding(input_ids)
for layer in self.layers:
x = layer(x, positions)
x = self.norm(x)
return x


class GptOssForCausalLM(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config.hf_config
self.model = GptOssModel(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
self.lm_head = ParallelLMHead(
self.model_config.vocab_size,
self.model_config.hidden_size,
)
self.logits_processor = LogitsProcessor(self.model_config.vocab_size)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
assert intermediate_tensors is None
assert inputs_embeds is None
return self.model(input_ids, positions)
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
rename_mapping = {
"self_attn": "attn",
"input_layernorm.weight": "attn.norm.weight",
"post_attention_layernorm.weight": "mlp.norm.weight",
"embed_tokens": "embedding",
}
def maybe_rename(name: str) -> str:
for remap_name, new_name in rename_mapping.items():
if remap_name in name:
return name.replace(remap_name, new_name)
return name
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
mxfp4_block = 32
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
intermediate_size = self.model_config.intermediate_size
intermediate_size_block = intermediate_size // mxfp4_block
per_rank_intermediate_size_block = cdiv(intermediate_size_block,
tp_size)
per_rank_intermediate_size = (per_rank_intermediate_size_block *
mxfp4_block)
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
intermediate_size)
# Attention heads per rank
heads_per_rank = self.model_config.num_attention_heads // tp_size
head_start = tp_rank * heads_per_rank
use_ep = self.vllm_config.parallel_config.enable_expert_parallel
ep_size = get_ep_group().world_size
ep_rank = get_ep_group().rank
num_experts = self.model_config.num_local_experts
experts_per_rank = num_experts // ep_size
ep_rank_start = ep_rank * experts_per_rank
ep_rank_end = (ep_rank + 1) * experts_per_rank
for name, weight in weights:
# FIXME(woosuk): Remove this after testing.
weight = weight.cuda()
if "gate_up_proj_blocks" in name:
# Handle MLP gate and up projection weights
new_name = name.replace("gate_up_proj_blocks", "w13_weight")
# Flatten the weight from (E, 2 * N, block_size, entry_per_block)
# to (E, 2 * N, -1); shouldn't trigger a copy for a contiguous tensor
weight = weight.view(num_experts, 2 * intermediate_size,
-1).contiguous()
# Extract gate and up projection parts
# since the weight is shuffled, we can slice directly
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end,
...]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "down_proj_blocks" in name:
# Handle MLP down projection weights
new_name = name.replace("down_proj_blocks", "w2_weight")
# Same flattening as above, but since 2 mxfp4 values are packed
# into 1 uint8, the last dim is divided by 2
weight = weight.view(num_experts, -1,
intermediate_size // 2).contiguous()
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[...,
tp_rank_start // 2:tp_rank_end // 2]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "gate_up_proj_scales" in name:
# Handle MLP gate and up projection weight scales
new_name = name.replace("gate_up_proj_scales",
"w13_weight_scale")
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end,
...]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "down_proj_scales" in name:
# Handle MLP down projection weight scales
new_name = name.replace("down_proj_scales", "w2_weight_scale")
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[..., tp_rank_start //
mxfp4_block:tp_rank_end //
mxfp4_block]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "gate_up_proj_bias" in name:
# Handle MLP gate and up projection biases
new_name = name.replace("gate_up_proj_bias", "w13_bias")
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "down_proj_bias" in name:
# Handle MLP down projection bias
new_name = name.replace("down_proj_bias", "w2_bias")
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
else:
# Load the full bias only on rank 0 and zero it on the other
# ranks so the TP reduce adds it exactly once
if tp_rank != 0:
weight.zero_()
weight_loader(param,
weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
name = name.replace("self_attn", "attn")
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
shard_id = ("q" if "q_proj" in name else
"k" if "k_proj" in name else "v")
name = name.replace("self_attn", "attn")
param_name = name.replace(f"{shard_id}_proj", "qkv")
param = params_dict[param_name]
weight_loader = param.weight_loader
weight_loader(param, weight, loaded_shard_id=shard_id)
loaded_params.add(param_name)
else:
# Handle all other weights with potential renaming
renamed_name = maybe_rename(name)
if renamed_name not in params_dict:
continue
param = params_dict[renamed_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
loaded_params.add(renamed_name)
return loaded_params
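To make the TP slicing in load_weights easier to follow, here is a standalone sketch of the rank-bound arithmetic used for the MXFP4 expert weights. The intermediate_size value is illustrative only, and cdiv is re-implemented locally so the snippet runs on its own:

# Sketch only: per-rank slicing bounds for block-quantized (MXFP4) weights.
# intermediate_size here is an example value, not read from the checkpoint.
def cdiv(a: int, b: int) -> int:
    return -(a // -b)

mxfp4_block = 32
intermediate_size = 2880
tp_size = 4

blocks = intermediate_size // mxfp4_block        # 90 blocks of width 32
per_rank_blocks = cdiv(blocks, tp_size)          # 23 blocks per rank
per_rank_size = per_rank_blocks * mxfp4_block    # 736 columns per rank

for tp_rank in range(tp_size):
    start = tp_rank * per_rank_size
    end = min((tp_rank + 1) * per_rank_size, intermediate_size)
    print(tp_rank, start, end)  # boundaries stay multiples of 32; the last rank is clipped to 2880

Rounding the per-rank width up to a whole number of 32-wide blocks keeps every slice aligned with the MXFP4 packing, which is why the last rank can end up slightly narrower than the others.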


@@ -74,6 +74,7 @@ _TEXT_GENERATION_MODELS = {
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
"Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
"GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),