Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-11 12:55:01 +08:00)
[Fix] Fix GPTBigcoder for distributed execution (#503)
parent 1dde34e0f8
commit 7d5a155e4a
@@ -54,15 +54,30 @@ class GPTBigCodeAttention(nn.Module):
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
-        self.num_kv_heads = 1 if config.multi_query else self.num_heads
-        self.kv_dim = self.num_kv_heads * self.head_dim
         self.scale = self.head_dim**-0.5
 
-        self.c_attn = ColumnParallelLinear(self.hidden_size,
-                                           self.hidden_size + 2 * self.kv_dim,
-                                           bias=True,
-                                           gather_output=False,
-                                           perform_initialization=False)
+        self.multi_query = config.multi_query
+        if self.multi_query:
+            self.num_kv_heads = 1
+            self.kv_dim = self.head_dim
+            self.c_attn_q = ColumnParallelLinear(self.hidden_size,
+                                                 self.hidden_size,
+                                                 bias=True,
+                                                 gather_output=False,
+                                                 perform_initialization=False)
+            self.c_attn_kv = nn.Linear(self.hidden_size,
+                                       2 * self.kv_dim,
+                                       bias=True)
+        else:
+            self.num_kv_heads = self.num_heads
+            self.kv_dim = self.num_kv_heads * self.head_dim
+            self.c_attn = ColumnParallelLinear(self.hidden_size,
+                                               self.hidden_size +
+                                               2 * self.kv_dim,
+                                               bias=True,
+                                               gather_output=False,
+                                               perform_initialization=False)
+
         self.c_proj = RowParallelLinear(self.hidden_size,
                                         self.hidden_size,
                                         bias=True,
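This hunk is the heart of the fix: with multi-query attention there is only a single key/value head, so the fused QKV projection cannot be column-sharded evenly across tensor-parallel ranks. The new code keeps only the query projection column-parallel and replicates the small KV projection as a plain nn.Linear on every rank. Below is a minimal sketch of the per-rank layer shapes this produces, in plain PyTorch with made-up sizes (the hidden size, head count, and tensor-parallel world size of 4 are illustrative assumptions, not values from the commit):

# Sketch of the per-rank projection shapes after the change above.
# Plain PyTorch only; no distributed setup, the world size is assumed.
import torch.nn as nn

hidden_size = 6144                           # illustrative model width
total_num_heads = 48                         # illustrative head count
world_size = 4                               # assumed tensor-parallel degree
head_dim = hidden_size // total_num_heads    # 128
num_heads = total_num_heads // world_size    # query heads owned by one rank

# Query projection: column-sharded, each rank holds num_heads * head_dim
# output features (what a column-parallel layer would allocate per rank).
c_attn_q_shard = nn.Linear(hidden_size, num_heads * head_dim, bias=True)
# KV projection: one shared head, kept whole and replicated on each rank.
c_attn_kv = nn.Linear(hidden_size, 2 * head_dim, bias=True)

print(c_attn_q_shard.weight.shape)  # torch.Size([1536, 6144])
print(c_attn_kv.weight.shape)       # torch.Size([256, 6144])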
@@ -80,9 +95,14 @@ class GPTBigCodeAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        qkv, _ = self.c_attn(hidden_states)
-        q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
-                            dim=-1)
+        if self.multi_query:
+            q, _ = self.c_attn_q(hidden_states)
+            kv = self.c_attn_kv(hidden_states)
+            k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
+        else:
+            qkv, _ = self.c_attn(hidden_states)
+            q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
+                                dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 input_metadata, cache_event)
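The forward pass mirrors that layout: on the multi-query path the sharded c_attn_q produces the queries while the replicated c_attn_kv produces the single key/value pair, and on the multi-head path the old fused split is kept. A small standalone sketch of the two split() calls, using random tensors in place of the projection outputs (all sizes are assumed per-rank values, no vLLM imports):

# Sketch of the two split paths in forward(); tensors stand in for the
# projection outputs and all sizes are assumptions for illustration.
import torch

hidden_size, kv_dim, num_tokens = 1536, 128, 4   # assumed per-rank sizes

# Multi-query branch: q comes from c_attn_q, k/v from the replicated c_attn_kv.
kv = torch.randn(num_tokens, 2 * kv_dim)
k, v = kv.split([kv_dim, kv_dim], dim=-1)        # two [num_tokens, kv_dim] halves

# Multi-head branch: the fused c_attn output is split into q, k, v.
qkv = torch.randn(num_tokens, hidden_size + 2 * kv_dim)
q, k, v = qkv.split([hidden_size, kv_dim, kv_dim], dim=-1)
print(q.shape, k.shape, v.shape)  # torch.Size([4, 1536]) torch.Size([4, 128]) torch.Size([4, 128])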
@@ -251,21 +271,9 @@ class GPTBigCodeForCausalLM(nn.Module):
                 # NOTE: "c_attn.bias" should not be skipped.
                 continue
 
-            param = state_dict[name]
-
             if not name.startswith("transformer."):
                 name = "transformer." + name
 
-            if name == "transformer.wte.weight":
-                # Consider padding in the vocab size.
-                padded_vocab_size = param.shape[
-                    0] * tensor_model_parallel_world_size
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
                 # GPT-2's fused QKV has the shape of
@@ -291,9 +299,39 @@ class GPTBigCodeForCausalLM(nn.Module):
                     # Split the heads when using normal multi-head attention
                     wk = wk[head_size * head_start:head_size * head_end]
                     wv = wv[head_size * head_start:head_size * head_end]
-                # Else, keep the weights as is for multi-query attention
-
-                loaded_weight = torch.cat([wq, wk, wv], dim=0)
+                    loaded_weight = torch.cat([wq, wk, wv], dim=0)
+                else:
+                    # For multi-query attention, we split the query
+                    # but replicate the key and value.
+                    loaded_weight_q = wq
+                    loaded_weight_kv = torch.cat([wk, wv], dim=0)
+                    q_weight_name = name.replace("c_attn", "c_attn_q")
+                    kv_weight_name = name.replace("c_attn", "c_attn_kv")
+                    load_tensor_parallel_weights(state_dict[q_weight_name],
+                                                 loaded_weight_q,
+                                                 q_weight_name,
+                                                 self._column_parallel_weights,
+                                                 self._row_parallel_weights,
+                                                 tensor_model_parallel_rank)
+                    load_tensor_parallel_weights(state_dict[kv_weight_name],
+                                                 loaded_weight_kv,
+                                                 kv_weight_name,
+                                                 self._column_parallel_weights,
+                                                 self._row_parallel_weights,
+                                                 tensor_model_parallel_rank)
+                    continue
+
+            param = state_dict[name]
+
+            if name == "transformer.wte.weight":
+                # Consider padding in the vocab size.
+                padded_vocab_size = param.shape[
+                    0] * tensor_model_parallel_world_size
+                num_extra_rows = padded_vocab_size - self.config.vocab_size
+                extra_rows = torch.empty(num_extra_rows,
+                                         loaded_weight.shape[1])
+                extra_rows = extra_rows.to(loaded_weight)
+                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
 
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
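The weight-loading hunk performs the matching shard on the HuggingFace checkpoint: the fused c_attn tensor is split into query and key/value rows, this rank's slice of the query rows is loaded into c_attn_q, and the key/value rows are concatenated and loaded whole into the replicated c_attn_kv (the same path also runs for the bias, per the NOTE above). A condensed standalone sketch of that slicing, with made-up rank and model dimensions (the query slicing shown here happens slightly earlier in the real loop):

# Sketch of the multi-query weight sharding added above, outside vLLM.
# rank, world_size and the model dimensions are assumptions for illustration.
import torch

hidden_size, total_num_heads, world_size, rank = 1024, 16, 4, 1
head_size = hidden_size // total_num_heads           # 64
kv_size = head_size                                  # one shared KV head
num_heads = total_num_heads // world_size            # query heads per rank
head_start, head_end = rank * num_heads, (rank + 1) * num_heads

# Fused HF "c_attn" weight: [hidden_size + 2 * kv_size, hidden_size].
loaded_weight = torch.randn(hidden_size + 2 * kv_size, hidden_size)
wq, wk, wv = torch.split(loaded_weight, [hidden_size, kv_size, kv_size], dim=0)

loaded_weight_q = wq[head_size * head_start:head_size * head_end]  # this rank's query rows
loaded_weight_kv = torch.cat([wk, wv], dim=0)                      # identical on every rank

print(loaded_weight_q.shape)   # torch.Size([256, 1024])
print(loaded_weight_kv.shape)  # torch.Size([128, 1024])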