Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 21:25:33 +08:00)
[Bugfix] Fix gpt_oss packed_modules_mapping (#28536)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Parent: edb59a9470
Commit: a9d18b5107
@ -92,7 +92,7 @@ class OAIAttention(nn.Module):
|
|||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
self.rope_theta = config.rope_theta
|
self.rope_theta = config.rope_theta
|
||||||
|
|
||||||
self.qkv = QKVParallelLinear(
|
self.qkv_proj = QKVParallelLinear(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
head_size=self.head_dim,
|
head_size=self.head_dim,
|
||||||
total_num_heads=self.num_attention_heads,
|
total_num_heads=self.num_attention_heads,
|
||||||
@ -129,7 +129,7 @@ class OAIAttention(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self, hidden_states: torch.Tensor, positions: torch.Tensor
|
self, hidden_states: torch.Tensor, positions: torch.Tensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.qkv(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
v = v.contiguous()
|
v = v.contiguous()
|
||||||
@ -606,9 +606,9 @@ class GptOssModel(nn.Module):
|
|||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
(".qkv", ".q_proj", "q"),
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
(".qkv", ".k_proj", "k"),
|
(".qkv_proj", ".k_proj", "k"),
|
||||||
(".qkv", ".v_proj", "v"),
|
(".qkv_proj", ".v_proj", "v"),
|
||||||
]
|
]
|
||||||
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user