Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 08:45:00 +08:00)
[Fix] Weight loading for GPTBigCode (#313)
commit 598dc4b79a
parent 85de093472
@@ -235,17 +235,28 @@ class GPTBigCodeForCausalLM(nn.Module):
                 continue
 
             param = state_dict[name]
 
+            if not name.startswith("transformer."):
+                name = "transformer." + name
+
+            if name == "transformer.wte.weight":
+                # Consider padding in the vocab size.
+                padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
+                num_extra_rows = padded_vocab_size - self.config.vocab_size
+                extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
+                extra_rows = extra_rows.to(loaded_weight)
+                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+
             def _expand_mqa_mha(qkv_array, n_head, head_dim):
                 """manipulates along axis=0 from MQA to MHA
                 inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
                     with n_heads for q, then 1 for k, 1 for v, times head dim
                 return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim)
 
                 TODO: this function is no longer needed once vllm supports MQA.
                 """
                 qkv_array = qkv_array.numpy()
 
                 dims_q = n_head * head_dim
                 q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0)
                 # q is fine, but k & v have not replicated shape along the first axis
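
The padding arithmetic above extends the checkpoint embedding so the vocab splits evenly across tensor-parallel ranks. A minimal standalone sketch of that step with toy sizes (all names and numbers below are illustrative, not from the commit):

import torch

# Toy values, for illustration only; in the diff these come from the model
# config and the tensor-parallel group.
vocab_size = 50257      # self.config.vocab_size
world_size = 2          # tensor_model_parallel_world_size
per_rank_rows = 25152   # param.shape[0]: this rank's shard of the padded vocab

loaded_weight = torch.randn(vocab_size, 8)  # checkpoint embedding (toy hidden dim)

# Pad the checkpoint rows up to the sharded total before splitting.
padded_vocab_size = per_rank_rows * world_size
num_extra_rows = padded_vocab_size - vocab_size
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)  # match dtype and device
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

assert loaded_weight.shape == (padded_vocab_size, 8)  # 50304 = 25152 * 2

The extra rows are left uninitialized (torch.empty), which is fine because token ids never index past the real vocab size.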
@@ -285,6 +296,7 @@ class GPTBigCodeForCausalLM(nn.Module):
                 loaded_weight = loaded_weight.reshape(-1)
             else:
                 raise ValueError(f"Unexpected parameter name {name}")
+
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
                                          self._row_parallel_weights,
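
The _expand_mqa_mha helper in the first hunk is shown truncated right after the np.split. A minimal sketch of how the k/v replication described in its docstring can be completed (the np.tile-based body below is an assumption consistent with the docstring, not necessarily the commit's exact code):

import numpy as np

def expand_mqa_mha(qkv_array, n_head, head_dim):
    """Sketch: expand a fused MQA qkv weight to MHA layout along axis=0.

    input:  ((n_head + 2) * head_dim, hidden_dim) -- n_head q heads,
            one shared k head, one shared v head
    output: (3 * n_head * head_dim, hidden_dim)   -- k and v replicated
    """
    dims_q = n_head * head_dim
    q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0)
    # q already has one slice per head; replicate the single shared k and v
    # heads so every q head gets its own copy.
    k = np.tile(k, (n_head, 1))
    v = np.tile(v, (n_head, 1))
    return np.concatenate([q, k, v], axis=0)

# Tiny smoke test with made-up sizes.
n_head, head_dim, hidden = 4, 3, 5
qkv = np.arange((n_head + 2) * head_dim * hidden, dtype=np.float32)
qkv = qkv.reshape((n_head + 2) * head_dim, hidden)
assert expand_mqa_mha(qkv, n_head, head_dim).shape == (3 * n_head * head_dim, hidden)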
@@ -98,7 +98,9 @@ def load_tensor_parallel_weights(
                 shard_size * tensor_model_parallel_rank
                 :shard_size * (tensor_model_parallel_rank + 1)]
             break
-    assert param.shape == loaded_weight.shape
+    assert param.shape == loaded_weight.shape, (
+        f"{param_name} shape mismatch between model and checkpoint: "
+        f"{param.shape} != {loaded_weight.shape}")
     param.data.copy_(loaded_weight)
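
The reworked assertion turns a bare AssertionError into a message naming the offending parameter and both shapes. A toy sketch of the column-parallel slicing it guards (the tensors, rank, and parameter name below are made up for illustration):

import torch

full_weight = torch.randn(8, 4)  # checkpoint tensor for a column-parallel weight
param = torch.empty(4, 4)        # this rank's shard of the model parameter
tensor_model_parallel_rank = 1
shard_size = param.shape[0]

# Take this rank's slice of the full checkpoint tensor along dim 0.
loaded_weight = full_weight[
    shard_size * tensor_model_parallel_rank
    :shard_size * (tensor_model_parallel_rank + 1)]

# With the new message, a mismatch reports which parameter failed and why.
assert param.shape == loaded_weight.shape, (
    f"toy.param.name shape mismatch between model and checkpoint: "
    f"{param.shape} != {loaded_weight.shape}")
param.data.copy_(loaded_weight)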