mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:35:00 +08:00
[Fix] Weight loading for GPTBigCode (#313)
This commit is contained in:
parent
85de093472
commit
598dc4b79a
@ -235,17 +235,28 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
|
||||
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
|
||||
if name == "transformer.wte.weight":
|
||||
# Consider padding in the vocab size.
|
||||
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
|
||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
|
||||
def _expand_mqa_mha(qkv_array, n_head, head_dim):
|
||||
"""manipulates along axis=0 from MQA to MHA
|
||||
inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
|
||||
with n_heads for q, then 1 for k, 1 for 1 v, times head dim
|
||||
return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim)
|
||||
|
||||
|
||||
TODO: this function is no longer needed once vllm supports MQA.
|
||||
"""
|
||||
qkv_array = qkv_array.numpy()
|
||||
|
||||
|
||||
dims_q = n_head * head_dim
|
||||
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0)
|
||||
# q is fine, but k & v have not replicated shape along the first axis
|
||||
@ -285,6 +296,7 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
else:
|
||||
raise ValueError(f"Unexpected parameter name {name}")
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
|
||||
@ -98,7 +98,9 @@ def load_tensor_parallel_weights(
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
assert param.shape == loaded_weight.shape
|
||||
assert param.shape == loaded_weight.shape, (
|
||||
f"{param_name} shape mismatch between model and checkpoint: "
|
||||
f"{param.shape} != {loaded_weight.shape}")
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user