[Bugfix] Fix gpt_oss packed_modules_mapping (#28536)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-11-12 21:02:06 +08:00 committed by GitHub
parent edb59a9470
commit a9d18b5107
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -92,7 +92,7 @@ class OAIAttention(nn.Module):
self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta
self.qkv = QKVParallelLinear(
self.qkv_proj = QKVParallelLinear(
hidden_size=self.hidden_size,
head_size=self.head_dim,
total_num_heads=self.num_attention_heads,
@ -129,7 +129,7 @@ class OAIAttention(nn.Module):
def forward(
self, hidden_states: torch.Tensor, positions: torch.Tensor
) -> torch.Tensor:
qkv, _ = self.qkv(hidden_states)
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
v = v.contiguous()
@ -606,9 +606,9 @@ class GptOssModel(nn.Module):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv", ".q_proj", "q"),
(".qkv", ".k_proj", "k"),
(".qkv", ".v_proj", "v"),
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
]
tp_rank = get_tensor_model_parallel_rank()