[Bugfix] Fix gpt_oss packed_modules_mapping (#28536)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

Parent: edb59a9470
Commit: a9d18b5107
@@ -92,7 +92,7 @@ class OAIAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = config.rope_theta

-        self.qkv = QKVParallelLinear(
+        self.qkv_proj = QKVParallelLinear(
             hidden_size=self.hidden_size,
             head_size=self.head_dim,
             total_num_heads=self.num_attention_heads,
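Context for the rename: vLLM keys packed-module handling on the submodule attribute name, so the fused projection has to be called qkv_proj to line up with the mapping the commit title refers to. A minimal illustrative sketch, assuming the convention used by other vLLM decoder models (the gpt_oss class attribute itself is not part of this hunk):

# Illustrative only: the fused QKV module must be registered under the same key
# that packed_modules_mapping uses, otherwise lookups keyed on "qkv_proj"
# (e.g. for LoRA or quantization) cannot find the fused layer.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
}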
@@ -129,7 +129,7 @@ class OAIAttention(nn.Module):
     def forward(
         self, hidden_states: torch.Tensor, positions: torch.Tensor
     ) -> torch.Tensor:
-        qkv, _ = self.qkv(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         v = v.contiguous()
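The forward path splits the fused projection output back into query, key, and value chunks. A small standalone sketch of that split arithmetic, using toy sizes that are assumptions rather than the gpt_oss configuration:

import torch

# Toy sizes for illustration only (not the gpt_oss config).
head_dim, num_heads, num_kv_heads = 64, 8, 2
q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim

qkv = torch.randn(1, q_size + 2 * kv_size)   # stands in for the qkv_proj output
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
assert q.shape == (1, 512) and k.shape == (1, 128) and v.shape == (1, 128)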
@@ -606,9 +606,9 @@ class GptOssModel(nn.Module):
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            (".qkv", ".q_proj", "q"),
-            (".qkv", ".k_proj", "k"),
-            (".qkv", ".v_proj", "v"),
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
         ]

         tp_rank = get_tensor_model_parallel_rank()
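These tuples drive the checkpoint-to-fused-parameter renaming during weight loading. A simplified, self-contained sketch of that remapping (the checkpoint key below is hypothetical, and the real loader additionally dispatches shard_id to the parameter's weight_loader):

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]

# Hypothetical checkpoint key, for illustration only.
checkpoint_key = "model.layers.0.attn.q_proj.weight"
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in checkpoint_key:
        fused_key = checkpoint_key.replace(weight_name, param_name)
        print(fused_key, shard_id)  # model.layers.0.attn.qkv_proj.weight q
        break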