[Fix] Weight loading for GPTBigCode (#313)

Zhuohan Li 2023-06-29 22:14:17 -07:00 committed by GitHub
parent 85de093472
commit 598dc4b79a
2 changed files with 18 additions and 4 deletions


@@ -235,17 +235,28 @@ class GPTBigCodeForCausalLM(nn.Module):
                continue
            param = state_dict[name]
            if not name.startswith("transformer."):
                name = "transformer." + name
            if name == "transformer.wte.weight":
                # Consider padding in the vocab size.
                padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
                num_extra_rows = padded_vocab_size - self.config.vocab_size
                extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
                extra_rows = extra_rows.to(loaded_weight)
                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

            def _expand_mqa_mha(qkv_array, n_head, head_dim):
                """Expands the fused QKV array along axis=0 from MQA to MHA.
                input: qkv_array.shape = ((n_heads + 2) * head_dim, hidden_dim),
                    with n_heads blocks for q, then 1 for k and 1 for v,
                    each of size head_dim.
                return: qkv_array.shape = (3 * n_heads * head_dim, hidden_dim)
                TODO: this function will no longer be needed once vLLM supports MQA.
                """
                qkv_array = qkv_array.numpy()
                dims_q = n_head * head_dim
                q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0)
                # q already has the right shape, but k & v hold a single head
                # each and still need to be replicated along the first axis.
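The hunk cuts off mid-function. As a reference, here is a minimal self-contained sketch of how the MQA-to-MHA expansion plausibly completes, with the k/v replication inferred from the docstring; the committed body past the hunk boundary may differ in detail, and the function name below is illustrative.

import numpy as np
import torch

def expand_mqa_to_mha(qkv_array: torch.Tensor, n_head: int, head_dim: int) -> torch.Tensor:
    # Split the fused weight into q (n_head * head_dim rows) and
    # single-head k and v (head_dim rows each).
    qkv = qkv_array.numpy()
    dims_q = n_head * head_dim
    q, k, v = np.split(qkv, (dims_q, dims_q + head_dim), axis=0)
    # Replicate k and v n_head times so every query head gets its own
    # copy, matching the MHA layout.
    k = np.tile(k, (n_head, 1))
    v = np.tile(v, (n_head, 1))
    # Result shape: (3 * n_head * head_dim, hidden_dim).
    return torch.from_numpy(np.concatenate([q, k, v], axis=0))

# e.g. expand_mqa_to_mha(torch.randn((12 + 2) * 64, 768), 12, 64)
# returns a (3 * 12 * 64, 768) tensor.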
@@ -285,6 +296,7 @@ class GPTBigCodeForCausalLM(nn.Module):
                loaded_weight = loaded_weight.reshape(-1)
            else:
                raise ValueError(f"Unexpected parameter name {name}")
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
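The wte branch in the first hunk pads the embedding's vocab dimension so it divides evenly across tensor-parallel ranks. A standalone sketch of the same arithmetic follows; every concrete number is an assumption for illustration, not from the commit.

import torch

vocab_size = 50257        # checkpoint vocab size (assumed, GPT-2 style)
world_size = 4            # tensor-parallel world size (assumed)
rows_per_rank = 12576     # param.shape[0] on each rank (assumed)

padded_vocab_size = rows_per_rank * world_size   # 50304
num_extra_rows = padded_vocab_size - vocab_size  # 47

loaded_weight = torch.randn(vocab_size, 768)     # stand-in checkpoint tensor
# Match dtype/device of the checkpoint tensor, then pad the vocab dim.
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1]).to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
assert loaded_weight.shape[0] == padded_vocab_size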


@@ -98,7 +98,9 @@ def load_tensor_parallel_weights(
                shard_size * tensor_model_parallel_rank
                :shard_size * (tensor_model_parallel_rank + 1)]
            break
-   assert param.shape == loaded_weight.shape
+   assert param.shape == loaded_weight.shape, (
+       f"{param_name} shape mismatch between model and checkpoint: "
+       f"{param.shape} != {loaded_weight.shape}")
    param.data.copy_(loaded_weight)
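The second file's change only adds context to the failing assert. A toy stand-in showing the check in action; the function below is a simplified illustration, not vLLM's full load_tensor_parallel_weights.

import torch

def load_sharded_weight(param: torch.Tensor, loaded_weight: torch.Tensor,
                        param_name: str, rank: int) -> None:
    # Take this rank's row shard of the full checkpoint tensor.
    shard_size = param.shape[0]
    shard = loaded_weight[shard_size * rank:shard_size * (rank + 1)]
    assert param.shape == shard.shape, (
        f"{param_name} shape mismatch between model and checkpoint: "
        f"{param.shape} != {shard.shape}")
    param.data.copy_(shard)

# With 2 ranks, an (8, 4) checkpoint tensor shards into (4, 4) per rank;
# a wrong local shape now fails with a message that names the parameter.
param = torch.empty(4, 4)
load_sharded_weight(param, torch.randn(8, 4), "demo.weight", rank=1)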