mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 19:44:28 +08:00
[Fix] Weight loading for GPTBigCode (#313)
This commit is contained in:
parent
85de093472
commit
598dc4b79a
@ -236,6 +236,17 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
|
|
||||||
param = state_dict[name]
|
param = state_dict[name]
|
||||||
|
|
||||||
|
if not name.startswith("transformer."):
|
||||||
|
name = "transformer." + name
|
||||||
|
|
||||||
|
if name == "transformer.wte.weight":
|
||||||
|
# Consider padding in the vocab size.
|
||||||
|
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
|
||||||
|
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||||
|
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
|
||||||
|
extra_rows = extra_rows.to(loaded_weight)
|
||||||
|
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||||
|
|
||||||
def _expand_mqa_mha(qkv_array, n_head, head_dim):
|
def _expand_mqa_mha(qkv_array, n_head, head_dim):
|
||||||
"""manipulates along axis=0 from MQA to MHA
|
"""manipulates along axis=0 from MQA to MHA
|
||||||
inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
|
inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
|
||||||
@ -285,6 +296,7 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
loaded_weight = loaded_weight.reshape(-1)
|
loaded_weight = loaded_weight.reshape(-1)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected parameter name {name}")
|
raise ValueError(f"Unexpected parameter name {name}")
|
||||||
|
|
||||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||||
self._column_parallel_weights,
|
self._column_parallel_weights,
|
||||||
self._row_parallel_weights,
|
self._row_parallel_weights,
|
||||||
|
|||||||
@ -98,7 +98,9 @@ def load_tensor_parallel_weights(
|
|||||||
shard_size * tensor_model_parallel_rank
|
shard_size * tensor_model_parallel_rank
|
||||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||||
break
|
break
|
||||||
assert param.shape == loaded_weight.shape
|
assert param.shape == loaded_weight.shape, (
|
||||||
|
f"{param_name} shape mismatch between model and checkpoint: "
|
||||||
|
f"{param.shape} != {loaded_weight.shape}")
|
||||||
param.data.copy_(loaded_weight)
|
param.data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user