mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 10:46:05 +08:00
[Model] Add PP support and VLM backbone compatability for GPT-OSS (#23680)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
11a7fafaa8
commit
c5d004aaaf
@ -358,7 +358,7 @@ th {
|
||||
| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ |
|
||||
| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | | ✅︎ |
|
||||
| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | ✅︎ | ✅︎ |
|
||||
| `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
|
||||
@ -11,7 +11,8 @@ from transformers import GptOssConfig
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank,
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -27,7 +28,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
@ -75,8 +79,6 @@ class OAIAttention(nn.Module):
|
||||
dtype=torch.bfloat16,
|
||||
requires_grad=False))
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
|
||||
self.q_size = self.num_attention_heads * self.head_dim // tp_size
|
||||
self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
|
||||
self.scaling = self.head_dim**-0.5
|
||||
@ -119,16 +121,13 @@ class OAIAttention(nn.Module):
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor) -> torch.Tensor:
|
||||
t = self.norm(hidden_states)
|
||||
|
||||
qkv, _ = self.qkv(t)
|
||||
qkv, _ = self.qkv(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
v = v.contiguous()
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
|
||||
return output + hidden_states
|
||||
return output
|
||||
|
||||
|
||||
class MLPBlock(torch.nn.Module):
|
||||
@ -145,7 +144,6 @@ class MLPBlock(torch.nn.Module):
|
||||
self.num_experts = config.num_local_experts
|
||||
self.experts_per_token = config.num_experts_per_tok
|
||||
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
self.router = torch.nn.Linear(config.hidden_size,
|
||||
config.num_local_experts,
|
||||
dtype=torch.bfloat16)
|
||||
@ -163,10 +161,9 @@ class MLPBlock(torch.nn.Module):
|
||||
activation="swigluoai")
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
t = self.norm(x)
|
||||
g = self.router(t)
|
||||
t = self.experts(hidden_states=t, router_logits=g)
|
||||
return x + t
|
||||
g = self.router(x)
|
||||
x = self.experts(hidden_states=x, router_logits=g)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerBlock(torch.nn.Module):
|
||||
@ -187,12 +184,28 @@ class TransformerBlock(torch.nn.Module):
|
||||
self.layer_idx,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor) -> torch.Tensor:
|
||||
attn_output = self.attn(hidden_states, positions)
|
||||
output = self.mlp(attn_output)
|
||||
return output
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.attn(hidden_states, positions)
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
output = self.mlp(hidden_states)
|
||||
return output, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
@ -214,22 +227,52 @@ class GptOssModel(nn.Module):
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
)
|
||||
self.layers = torch.nn.ModuleList([
|
||||
TransformerBlock(
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
self.config.num_hidden_layers,
|
||||
lambda prefix: TransformerBlock(
|
||||
self.config,
|
||||
cache_config=self.cache_config,
|
||||
quant_config=self.quant_config,
|
||||
prefix=maybe_prefix(prefix, f"block.{layer_idx}"),
|
||||
) for layer_idx in range(self.config.num_hidden_layers)
|
||||
])
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], self.config.hidden_size))
|
||||
|
||||
def forward(self, input_ids: torch.Tensor,
|
||||
positions: torch.Tensor) -> torch.Tensor:
|
||||
x = self.embedding(input_ids)
|
||||
for layer in self.layers:
|
||||
x = layer(x, positions)
|
||||
x = self.norm(x)
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embedding(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
x = inputs_embeds
|
||||
else:
|
||||
x = self.get_input_embeddings(input_ids)
|
||||
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
x = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
x, residual = layer(x, positions, residual)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": x,
|
||||
"residual": residual
|
||||
})
|
||||
x, _ = self.norm(x, residual)
|
||||
return x
|
||||
|
||||
def _load_weights_mxfp4(
|
||||
@ -264,6 +307,10 @@ class GptOssModel(nn.Module):
|
||||
intermediate_size)
|
||||
|
||||
for name, weight in weights:
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
# FIXME(woosuk): Remove this after testing.
|
||||
weight = weight.cuda()
|
||||
|
||||
@ -445,6 +492,10 @@ class GptOssModel(nn.Module):
|
||||
intermediate_size)
|
||||
|
||||
for name, weight in weights:
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
if ".w13_weight" in name:
|
||||
# Handle MLP gate and up projection weights
|
||||
# Extract gate and up projection parts
|
||||
@ -562,18 +613,15 @@ class GptOssModel(nn.Module):
|
||||
weights, stacked_params_mapping)
|
||||
|
||||
|
||||
class GptOssForCausalLM(nn.Module):
|
||||
class GptOssForCausalLM(nn.Module, SupportsPP):
|
||||
packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
".self_attn.": ".attn.",
|
||||
".post_attention_layernorm.": ".mlp.norm.",
|
||||
},
|
||||
orig_to_new_suffix={
|
||||
".embed_tokens.weight": ".embedding.weight",
|
||||
".input_layernorm.weight": ".attn.norm.weight",
|
||||
".post_attention_layernorm.weight": ".mlp.norm.weight",
|
||||
|
||||
# MoE MXFP4 weights
|
||||
".gate_up_proj_blocks": ".w13_weight",
|
||||
@ -609,6 +657,11 @@ class GptOssForCausalLM(nn.Module):
|
||||
self.config.hidden_size,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user