mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 21:55:50 +08:00
[Fix] Fix GPTBigcoder for distributed execution (#503)
This commit is contained in:
parent
1dde34e0f8
commit
7d5a155e4a
@ -54,15 +54,30 @@ class GPTBigCodeAttention(nn.Module):
|
||||
assert total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||
self.head_dim = self.hidden_size // total_num_heads
|
||||
self.num_kv_heads = 1 if config.multi_query else self.num_heads
|
||||
self.kv_dim = self.num_kv_heads * self.head_dim
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.c_attn = ColumnParallelLinear(self.hidden_size,
|
||||
self.hidden_size + 2 * self.kv_dim,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.multi_query = config.multi_query
|
||||
if self.multi_query:
|
||||
self.num_kv_heads = 1
|
||||
self.kv_dim = self.head_dim
|
||||
self.c_attn_q = ColumnParallelLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.c_attn_kv = nn.Linear(self.hidden_size,
|
||||
2 * self.kv_dim,
|
||||
bias=True)
|
||||
else:
|
||||
self.num_kv_heads = self.num_heads
|
||||
self.kv_dim = self.num_kv_heads * self.head_dim
|
||||
self.c_attn = ColumnParallelLinear(self.hidden_size,
|
||||
self.hidden_size +
|
||||
2 * self.kv_dim,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
|
||||
self.c_proj = RowParallelLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
@ -80,9 +95,14 @@ class GPTBigCodeAttention(nn.Module):
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
|
||||
dim=-1)
|
||||
if self.multi_query:
|
||||
q, _ = self.c_attn_q(hidden_states)
|
||||
kv = self.c_attn_kv(hidden_states)
|
||||
k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
|
||||
else:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
|
||||
dim=-1)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
||||
input_metadata, cache_event)
|
||||
@ -251,21 +271,9 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
# NOTE: "c_attn.bias" should not be skipped.
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
|
||||
if name == "transformer.wte.weight":
|
||||
# Consider padding in the vocab size.
|
||||
padded_vocab_size = param.shape[
|
||||
0] * tensor_model_parallel_world_size
|
||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||
extra_rows = torch.empty(num_extra_rows,
|
||||
loaded_weight.shape[1])
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
|
||||
# For the fused QKV linear layer, manually shard the weights.
|
||||
if "c_attn" in name:
|
||||
# GPT-2's fused QKV has the shape of
|
||||
@ -291,9 +299,39 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
# Split the heads when using normal multi-head attention
|
||||
wk = wk[head_size * head_start:head_size * head_end]
|
||||
wv = wv[head_size * head_start:head_size * head_end]
|
||||
# Else, keep the weights as is for multi-query attention
|
||||
loaded_weight = torch.cat([wq, wk, wv], dim=0)
|
||||
else:
|
||||
# For multi-query attention, we split the query
|
||||
# but replicate the key and value.
|
||||
loaded_weight_q = wq
|
||||
loaded_weight_kv = torch.cat([wk, wv], dim=0)
|
||||
q_weight_name = name.replace("c_attn", "c_attn_q")
|
||||
kv_weight_name = name.replace("c_attn", "c_attn_kv")
|
||||
load_tensor_parallel_weights(state_dict[q_weight_name],
|
||||
loaded_weight_q,
|
||||
q_weight_name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
load_tensor_parallel_weights(state_dict[kv_weight_name],
|
||||
loaded_weight_kv,
|
||||
kv_weight_name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
loaded_weight = torch.cat([wq, wk, wv], dim=0)
|
||||
param = state_dict[name]
|
||||
|
||||
if name == "transformer.wte.weight":
|
||||
# Consider padding in the vocab size.
|
||||
padded_vocab_size = param.shape[
|
||||
0] * tensor_model_parallel_world_size
|
||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||
extra_rows = torch.empty(num_extra_rows,
|
||||
loaded_weight.shape[1])
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user