[Model] Add PP support and VLM backbone compatibility for GPT-OSS (#23680)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py 2025-08-28 16:03:28 +08:00 committed by GitHub
parent 11a7fafaa8
commit c5d004aaaf
2 changed files with 87 additions and 34 deletions

docs/models/supported_models.md

@@ -358,7 +358,7 @@ th {
| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ |
| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ |
| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | | ✅︎ |
| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | ✅︎ | ✅︎ |
| `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ |
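The row change above marks pipeline parallelism as supported for `GptOssForCausalLM` in the supported-models table. A minimal offline-inference sketch of using it (model name and sizes are illustrative; `pipeline_parallel_size` is the standard `vllm.LLM` argument, and a multi-GPU environment is assumed):

```python
# Illustrative only: run GPT-OSS with two pipeline stages (needs >= 2 GPUs).
from vllm import LLM, SamplingParams

llm = LLM(
    model="openai/gpt-oss-20b",   # example checkpoint from the table
    pipeline_parallel_size=2,     # split transformer layers across 2 ranks
)
outputs = llm.generate(["Hello!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```

The equivalent server flag is `--pipeline-parallel-size` on `vllm serve`.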

vllm/model_executor/models/gpt_oss.py

@@ -11,7 +11,8 @@ from transformers import GptOssConfig
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank,
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -27,7 +28,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
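The added `get_pp_group` / `make_layers` / `is_pp_missing_parameter` imports are vLLM's usual pipeline-parallel plumbing: `make_layers` instantiates only the layers owned by the current pipeline rank and returns `(start_layer, end_layer, layers)`. A simplified sketch of the underlying contiguous-slice partitioning (illustrative, not vLLM's actual helper):

```python
# Simplified layer partitioning across pipeline ranks (illustrative only).
def partition_layers(num_layers: int, pp_rank: int, pp_size: int) -> tuple[int, int]:
    """Return the [start, end) layer slice owned by pipeline rank `pp_rank`."""
    per_rank, remainder = divmod(num_layers, pp_size)
    # Earlier ranks absorb one extra layer each until the remainder is used up.
    start = pp_rank * per_rank + min(pp_rank, remainder)
    end = start + per_rank + (1 if pp_rank < remainder else 0)
    return start, end

# e.g. 24 layers over 4 ranks -> (0, 6), (6, 12), (12, 18), (18, 24)
assert partition_layers(24, 1, 4) == (6, 12)
```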
@@ -75,8 +79,6 @@ class OAIAttention(nn.Module):
dtype=torch.bfloat16,
requires_grad=False))
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
self.q_size = self.num_attention_heads * self.head_dim // tp_size
self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
self.scaling = self.head_dim**-0.5
@@ -119,16 +121,13 @@ class OAIAttention(nn.Module):
def forward(self, hidden_states: torch.Tensor,
positions: torch.Tensor) -> torch.Tensor:
t = self.norm(hidden_states)
qkv, _ = self.qkv(t)
qkv, _ = self.qkv(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
v = v.contiguous()
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output + hidden_states
return output
class MLPBlock(torch.nn.Module):
@@ -145,7 +144,6 @@ class MLPBlock(torch.nn.Module):
self.num_experts = config.num_local_experts
self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
self.router = torch.nn.Linear(config.hidden_size,
config.num_local_experts,
dtype=torch.bfloat16)
@@ -163,10 +161,9 @@ class MLPBlock(torch.nn.Module):
activation="swigluoai")
def forward(self, x: torch.Tensor) -> torch.Tensor:
t = self.norm(x)
g = self.router(t)
t = self.experts(hidden_states=t, router_logits=g)
return x + t
g = self.router(x)
x = self.experts(hidden_states=x, router_logits=g)
return x
class TransformerBlock(torch.nn.Module):
@@ -187,12 +184,28 @@ class TransformerBlock(torch.nn.Module):
self.layer_idx,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
def forward(self, hidden_states: torch.Tensor,
positions: torch.Tensor) -> torch.Tensor:
attn_output = self.attn(hidden_states, positions)
output = self.mlp(attn_output)
return output
def forward(
self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.attn(hidden_states, positions)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
output = self.mlp(hidden_states)
return output, residual
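The rewritten block adopts the residual-threading pattern used by most decoder models in vLLM: the norms that previously lived inside `OAIAttention` and `MLPBlock` become `input_layernorm` / `post_attention_layernorm` on the block, and each norm is called with the running residual so it can fuse the residual add and return both tensors. A minimal sketch of that two-output convention (assumed behaviour, not vLLM's `RMSNorm` implementation):

```python
# Sketch of a residual-fused RMSNorm matching the (normed, residual) calling
# convention used in TransformerBlock.forward above. Assumed behaviour only.
from typing import Optional
import torch

class SketchRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if residual is not None:
            x = x + residual      # fuse the residual add into the norm
            residual = x          # the pre-norm sum becomes the new residual
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        normed = normed * self.weight
        # One output without a residual, two outputs with one.
        return normed if residual is None else (normed, residual)
```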
@support_torch_compile
@@ -214,22 +227,52 @@ class GptOssModel(nn.Module):
self.config.vocab_size,
self.config.hidden_size,
)
self.layers = torch.nn.ModuleList([
TransformerBlock(
self.start_layer, self.end_layer, self.layers = make_layers(
self.config.num_hidden_layers,
lambda prefix: TransformerBlock(
self.config,
cache_config=self.cache_config,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, f"block.{layer_idx}"),
) for layer_idx in range(self.config.num_hidden_layers)
])
prefix=prefix,
),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], self.config.hidden_size))
def forward(self, input_ids: torch.Tensor,
positions: torch.Tensor) -> torch.Tensor:
x = self.embedding(input_ids)
for layer in self.layers:
x = layer(x, positions)
x = self.norm(x)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embedding(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
x = inputs_embeds
else:
x = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
x = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
x, residual = layer(x, positions, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": x,
"residual": residual
})
x, _ = self.norm(x, residual)
return x
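Only the hidden states and the running residual cross pipeline-stage boundaries; non-final ranks return them wrapped in `IntermediateTensors`, and `make_empty_intermediate_tensors_factory` (wired up in `__init__` above) gives non-first ranks correctly shaped receive buffers. Roughly, the factory amounts to something like this (shapes and dtype handling are illustrative):

```python
# Rough sketch of what the intermediate-tensor factory produces for
# non-first pipeline ranks; illustrative, not vLLM's exact code.
import torch
from vllm.sequence import IntermediateTensors

def empty_intermediates(num_tokens: int, hidden_size: int,
                        dtype: torch.dtype,
                        device: torch.device) -> IntermediateTensors:
    return IntermediateTensors({
        "hidden_states": torch.zeros(num_tokens, hidden_size,
                                     dtype=dtype, device=device),
        "residual": torch.zeros(num_tokens, hidden_size,
                                dtype=dtype, device=device),
    })
```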
def _load_weights_mxfp4(
@@ -264,6 +307,10 @@ class GptOssModel(nn.Module):
intermediate_size)
for name, weight in weights:
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# FIXME(woosuk): Remove this after testing.
weight = weight.cuda()
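Both weight-loading paths now skip parameters that belong to layers outside this rank's `[start_layer, end_layer)` slice. Conceptually, the check behind `is_pp_missing_parameter` is close to the following simplification (the real helper also handles placeholder layers):

```python
# Simplified version of the pipeline-parallel weight-skip check; the real
# is_pp_missing_parameter in vLLM is more involved.
import re

def is_on_another_pp_rank(name: str, start_layer: int, end_layer: int) -> bool:
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
        return False  # embeddings / final norm are handled by first/last-rank logic
    layer_idx = int(match.group(1))
    return not (start_layer <= layer_idx < end_layer)

# A rank owning layers [12, 24) skips weights for layer 3:
assert is_on_another_pp_rank("model.layers.3.attn.qkv.weight", 12, 24)
```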
@@ -445,6 +492,10 @@ class GptOssModel(nn.Module):
intermediate_size)
for name, weight in weights:
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if ".w13_weight" in name:
# Handle MLP gate and up projection weights
# Extract gate and up projection parts
@@ -562,18 +613,15 @@ class GptOssModel(nn.Module):
weights, stacked_params_mapping)
class GptOssForCausalLM(nn.Module):
class GptOssForCausalLM(nn.Module, SupportsPP):
packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
".self_attn.": ".attn.",
".post_attention_layernorm.": ".mlp.norm.",
},
orig_to_new_suffix={
".embed_tokens.weight": ".embedding.weight",
".input_layernorm.weight": ".attn.norm.weight",
".post_attention_layernorm.weight": ".mlp.norm.weight",
# MoE MXFP4 weights
".gate_up_proj_blocks": ".w13_weight",
@@ -609,6 +657,11 @@ class GptOssForCausalLM(nn.Module):
self.config.hidden_size,
)
self.logits_processor = LogitsProcessor(self.config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
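Exposing `get_input_embeddings` here (together with the new `inputs_embeds` argument on `GptOssModel.forward`) is what makes GPT-OSS usable as a VLM language backbone: a multimodal wrapper can embed the text tokens itself, splice projected image features into the placeholder positions, and feed the merged embeddings to the backbone. A hypothetical sketch of that pattern (the placeholder id and feature tensor are made up for illustration):

```python
# Hypothetical VLM-side usage of the backbone's embedding hook.
# IMAGE_TOKEN_ID and image_features are illustrative, not part of this commit.
import torch

IMAGE_TOKEN_ID = 200_000  # made-up placeholder token id

def build_inputs_embeds(language_model, input_ids: torch.Tensor,
                        image_features: torch.Tensor) -> torch.Tensor:
    # Embed text tokens with the backbone's own embedding table.
    inputs_embeds = language_model.get_input_embeddings(input_ids)
    # Overwrite placeholder positions with projected image features
    # (image_features: [num_image_tokens, hidden_size]).
    mask = input_ids == IMAGE_TOKEN_ID
    inputs_embeds[mask] = image_features.to(dtype=inputs_embeds.dtype)
    return inputs_embeds
```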
def forward(self,
input_ids: torch.Tensor,