[Fix] Fix GPTBigcoder for distributed execution (#503)

This commit is contained in:
Zhuohan Li 2023-07-24 18:36:33 -07:00 committed by GitHub
parent 1dde34e0f8
commit 7d5a155e4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -54,15 +54,30 @@ class GPTBigCodeAttention(nn.Module):
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
self.num_kv_heads = 1 if config.multi_query else self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(self.hidden_size,
self.hidden_size + 2 * self.kv_dim,
bias=True,
gather_output=False,
perform_initialization=False)
self.multi_query = config.multi_query
if self.multi_query:
self.num_kv_heads = 1
self.kv_dim = self.head_dim
self.c_attn_q = ColumnParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_attn_kv = nn.Linear(self.hidden_size,
2 * self.kv_dim,
bias=True)
else:
self.num_kv_heads = self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.c_attn = ColumnParallelLinear(self.hidden_size,
self.hidden_size +
2 * self.kv_dim,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
@ -80,9 +95,14 @@ class GPTBigCodeAttention(nn.Module):
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
dim=-1)
if self.multi_query:
q, _ = self.c_attn_q(hidden_states)
kv = self.c_attn_kv(hidden_states)
k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
else:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
@ -251,21 +271,9 @@ class GPTBigCodeForCausalLM(nn.Module):
# NOTE: "c_attn.bias" should not be skipped.
continue
param = state_dict[name]
if not name.startswith("transformer."):
name = "transformer." + name
if name == "transformer.wte.weight":
# Consider padding in the vocab size.
padded_vocab_size = param.shape[
0] * tensor_model_parallel_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of
@ -291,9 +299,39 @@ class GPTBigCodeForCausalLM(nn.Module):
# Split the heads when using normal multi-head attention
wk = wk[head_size * head_start:head_size * head_end]
wv = wv[head_size * head_start:head_size * head_end]
# Else, keep the weights as is for multi-query attention
loaded_weight = torch.cat([wq, wk, wv], dim=0)
else:
# For multi-query attention, we split the query
# but replicate the key and value.
loaded_weight_q = wq
loaded_weight_kv = torch.cat([wk, wv], dim=0)
q_weight_name = name.replace("c_attn", "c_attn_q")
kv_weight_name = name.replace("c_attn", "c_attn_kv")
load_tensor_parallel_weights(state_dict[q_weight_name],
loaded_weight_q,
q_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
load_tensor_parallel_weights(state_dict[kv_weight_name],
loaded_weight_kv,
kv_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
continue
loaded_weight = torch.cat([wq, wk, wv], dim=0)
param = state_dict[name]
if name == "transformer.wte.weight":
# Consider padding in the vocab size.
padded_vocab_size = param.shape[
0] * tensor_model_parallel_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,