Add GPT-OSS model code and config [1/N] (#22327)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

parent 796bae07c5
commit de98252f49
@@ -197,6 +197,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
                                          {"6b": "EleutherAI/gpt-j-6b"}),
     "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m",
                                           {"1b": "EleutherAI/pythia-1.4b"}),
+    "GptOssForCausalLM": _HfExamplesInfo("openai/gpt-oss-20b"),
     "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
     "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
     "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"),  # noqa: E501
@@ -247,6 +247,34 @@ class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
                                 config.max_model_len)
 
 
+class GptOssConfig(VerifyAndUpdateConfig):
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        decoding_config = vllm_config.decoding_config
+        if decoding_config.reasoning_backend == "":
+            decoding_config.reasoning_backend = "openai"
+
+        # Increase the max capture size from 512 to 1024 for performance.
+        # NOTE(woosuk): This will increase the number of CUDA graphs
+        # from 67 to 83.
+        scheduler_config = vllm_config.scheduler_config
+        if len(scheduler_config.cuda_graph_sizes) == 1:
+            max_capture_size = scheduler_config.cuda_graph_sizes[0]
+            # FIXME(woosuk): When using full cuda graph with FA3, the max
+            # supported size is 992.
+            if max_capture_size < 1024:
+                cuda_graph_sizes = [1, 2, 4]
+                # Step size 8 for small batch sizes
+                cuda_graph_sizes += [i for i in range(8, 256, 8)]
+                # Step size 16 for larger batch sizes
+                cuda_graph_sizes += [i for i in range(256, 1025, 16)]
+                scheduler_config.cuda_graph_sizes = cuda_graph_sizes
+                logger.info(
+                    "Overriding max cuda graph capture size to "
+                    "%d for performance.", 1024)
+
+
 class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
 
     @classmethod
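As a quick check on the NOTE in the hunk above, the following standalone sketch (not part of the diff) rebuilds the capture-size list and confirms it yields 83 sizes with a maximum of 1024:

    # Reproduce the capture-size list built by GptOssConfig.verify_and_update_config.
    cuda_graph_sizes = [1, 2, 4]
    cuda_graph_sizes += [i for i in range(8, 256, 8)]      # 8, 16, ..., 248  (31 sizes)
    cuda_graph_sizes += [i for i in range(256, 1025, 16)]  # 256, 272, ..., 1024  (49 sizes)
    assert len(cuda_graph_sizes) == 83    # matches "from 67 to 83" in the NOTE
    assert max(cuda_graph_sizes) == 1024  # the new max capture size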
@@ -345,4 +373,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
     "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
     "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
+    "GptOssForCausalLM": GptOssConfig,
 }
vllm/model_executor/models/gpt_oss.py (new file, 472 lines)
@@ -0,0 +1,472 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional

import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig

from vllm import envs
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv

from .utils import extract_layer_index, maybe_prefix

class OAIAttention(nn.Module):

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            dtype=torch.float32,
            rope_scaling={
                "rope_type": "yarn",
                "factor": config.rope_scaling["factor"],
                "original_max_position_embeddings":
                config.rope_scaling["original_max_position_embeddings"],
                "beta_fast": config.rope_ntk_beta,
                "beta_slow": config.rope_ntk_alpha,
            },
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        attention_sink_dtype = (
            torch.float32 if envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
            or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION else torch.bfloat16)
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size,
                        dtype=attention_sink_dtype,
                        requires_grad=False))

        self.norm = RMSNorm(config.hidden_size, eps=1e-5)

        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta

        self.qkv = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer
        sliding_window = (config.sliding_window
                          if self.layer_idx % 2 == 0 else None)
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        t = self.norm(hidden_states)

        qkv, _ = self.qkv(t)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)

        return output + hidden_states

class MLPBlock(torch.nn.Module):

    def __init__(
        self,
        config: GptOssConfig,
        layer_idx: int,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
        self.experts_per_token = config.num_experts_per_tok
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.norm = RMSNorm(config.hidden_size, eps=1e-5)
        self.router = torch.nn.Linear(config.hidden_size,
                                      config.num_local_experts,
                                      dtype=torch.bfloat16)
        assert config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(num_experts=config.num_local_experts,
                                top_k=config.num_experts_per_tok,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.intermediate_size,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts",
                                apply_router_weight_on_input=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = self.norm(x)
        g = self.router(t)
        t = self.experts(hidden_states=t, router_logits=g)
        return x + t

class TransformerBlock(torch.nn.Module):

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.attn = OAIAttention(config, prefix=f"{prefix}.attn")
        self.mlp = MLPBlock(config,
                            self.layer_idx,
                            quant_config=quant_config,
                            prefix=f"{prefix}.mlp")

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        attn_output = self.attn(hidden_states, positions)
        output = self.mlp(attn_output)
        return output

@support_torch_compile
class GptOssModel(nn.Module):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        self.layers = torch.nn.ModuleList([
            TransformerBlock(
                self.config,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, f"block.{layer_idx}"),
            ) for layer_idx in range(self.config.num_hidden_layers)
        ])
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x, positions)
        x = self.norm(x)
        return x

class GptOssForCausalLM(nn.Module):

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config.hf_config
        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            self.model_config.vocab_size,
            self.model_config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(self.model_config.vocab_size)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        assert intermediate_tensors is None
        assert inputs_embeds is None
        return self.model(input_ids, positions)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        rename_mapping = {
            "self_attn": "attn",
            "input_layernorm.weight": "attn.norm.weight",
            "post_attention_layernorm.weight": "mlp.norm.weight",
            "embed_tokens": "embedding",
        }

        def maybe_rename(name: str) -> str:
            for remap_name, new_name in rename_mapping.items():
                if remap_name in name:
                    return name.replace(remap_name, new_name)
            return name

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        mxfp4_block = 32

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        intermediate_size = self.model_config.intermediate_size
        intermediate_size_block = intermediate_size // mxfp4_block
        per_rank_intermediate_size_block = cdiv(intermediate_size_block,
                                                tp_size)
        per_rank_intermediate_size = (per_rank_intermediate_size_block *
                                      mxfp4_block)

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                          intermediate_size)

        # Attention heads per rank
        heads_per_rank = self.model_config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
        ep_size = get_ep_group().world_size
        ep_rank = get_ep_group().rank
        num_experts = self.model_config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        for name, weight in weights:
            # FIXME(woosuk): Remove this after testing.
            weight = weight.cuda()

            if "gate_up_proj_blocks" in name:
                # Handle MLP gate and up projection weights
                new_name = name.replace("gate_up_proj_blocks", "w13_weight")

                # flat weight from (E, 2 * N, block_size, entry_per_block)
                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                weight = weight.view(num_experts, 2 * intermediate_size,
                                     -1).contiguous()

                # Extract gate and up projection parts
                # since the weight is shuffled, we can slice directly
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "down_proj_blocks" in name:
                # Handle MLP down projection weights
                new_name = name.replace("down_proj_blocks", "w2_weight")
                # same flatten here, but since 2 mx4 value are packed in 1
                # uint8, divide by 2
                weight = weight.view(num_experts, -1,
                                     intermediate_size // 2).contiguous()
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[...,
                                           tp_rank_start // 2:tp_rank_end // 2]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "gate_up_proj_scales" in name:
                # Handle MLP gate and up projection weight scales
                new_name = name.replace("gate_up_proj_scales",
                                        "w13_weight_scale")
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "down_proj_scales" in name:
                # Handle MLP down projection weight scales
                new_name = name.replace("down_proj_scales", "w2_weight_scale")
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[..., tp_rank_start //
                                           mxfp4_block:tp_rank_end //
                                           mxfp4_block]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)
            elif "gate_up_proj_bias" in name:
                # Handle MLP gate and up projection biases
                new_name = name.replace("gate_up_proj_bias", "w13_bias")

                # Extract gate and up projection bias parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "down_proj_bias" in name:
                # Handle MLP down projection bias
                new_name = name.replace("down_proj_bias", "w2_bias")
                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # (only load on rank 0 to avoid duplication)
                    if tp_rank != 0:
                        weight.zero_()
                weight_loader(param,
                              weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                name = name.replace("self_attn", "attn")
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
            elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
                shard_id = ("q" if "q_proj" in name else
                            "k" if "k_proj" in name else "v")
                name = name.replace("self_attn", "attn")
                param_name = name.replace(f"{shard_id}_proj", "qkv")
                param = params_dict[param_name]
                weight_loader = param.weight_loader
                weight_loader(param, weight, loaded_shard_id=shard_id)
                loaded_params.add(param_name)
            else:
                # Handle all other weights with potential renaming
                renamed_name = maybe_rename(name)
                if renamed_name not in params_dict:
                    continue
                param = params_dict[renamed_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, weight)
                loaded_params.add(renamed_name)

        return loaded_params
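A note on the MXFP4 slicing arithmetic in load_weights above: the intermediate dimension is carved into 32-wide MXFP4 blocks, rounded up to a whole number of blocks per tensor-parallel rank, and the resulting [tp_rank_start, tp_rank_end) bounds are then rescaled per tensor (x2 for the shuffled gate/up rows, /2 for the uint8-packed down_proj blocks, /32 for the per-block scales). A standalone sketch with hypothetical example values (2880 and tp_size=2 are illustrative, not values asserted by this commit):

    def cdiv(a: int, b: int) -> int:
        # same rounding-up division as vllm.utils.cdiv
        return -(-a // b)

    mxfp4_block = 32          # one MXFP4 scale per 32 values
    intermediate_size = 2880  # hypothetical example value
    tp_size, tp_rank = 2, 1   # hypothetical example values

    per_rank = cdiv(intermediate_size // mxfp4_block, tp_size) * mxfp4_block  # 1440
    tp_rank_start = tp_rank * per_rank                                        # 1440
    tp_rank_end = min((tp_rank + 1) * per_rank, intermediate_size)            # 2880

    w13_cols = (2 * tp_rank_start, 2 * tp_rank_end)         # shuffled gate/up rows
    w2_block_cols = (tp_rank_start // 2, tp_rank_end // 2)  # 2 MXFP4 values per uint8
    w2_scale_cols = (tp_rank_start // mxfp4_block,
                     tp_rank_end // mxfp4_block)            # 1 scale per 32-value block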
@@ -74,6 +74,7 @@ _TEXT_GENERATION_MODELS = {
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
     "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
     "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
+    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
     "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
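With the registry entries above, the new architecture resolves through the normal vLLM entry point. A minimal usage sketch, assuming a build that contains this commit and access to the openai/gpt-oss-20b checkpoint:

    from vllm import LLM, SamplingParams

    # "openai/gpt-oss-20b" reports GptOssForCausalLM, which the registry now
    # maps to vllm/model_executor/models/gpt_oss.py added in this commit.
    llm = LLM(model="openai/gpt-oss-20b")
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)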