support sharding llama2-70b on more than 8 GPUs (#1209)

Co-authored-by: JiCheng <247153481@qq.com>
Zhuohan Li authored 2023-10-02 15:26:33 -07:00, committed by GitHub
parent ebe4d1db3a
commit a60b353005

@@ -103,8 +103,16 @@ class LlamaAttention(nn.Module):
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
         self.total_num_kv_heads = num_kv_heads
-        assert self.total_num_kv_heads % tp_size == 0
-        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads)
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
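
Note (not part of the diff): a minimal standalone sketch of the partition-vs-replicate rule introduced above, with assumed head counts (Llama-2-70B ships 8 KV heads; an MHA model such as Llama-2-7B ships 32). The helper name is hypothetical.

```python
# Sketch only; kv_shard_plan is a hypothetical helper, not a vLLM API.
def kv_shard_plan(total_num_kv_heads: int, tp_size: int):
    """Return (KV heads per GPU, replicas of each KV head)."""
    if total_num_kv_heads >= tp_size:
        # Partition: each GPU owns a disjoint slice of the KV heads.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Replicate: each KV head is duplicated on tp_size // total_num_kv_heads GPUs.
        assert tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    num_kv_heads_replicas = max(1, tp_size // total_num_kv_heads)
    return num_kv_heads, num_kv_heads_replicas

print(kv_shard_plan(8, 8))    # (1, 1): one KV head per GPU, no replication
print(kv_shard_plan(8, 16))   # (1, 2): each KV head lives on two GPUs
print(kv_shard_plan(32, 8))   # (4, 1): MHA-style model, plain partitioning
```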
@@ -114,7 +122,8 @@ class LlamaAttention(nn.Module):
         self.qkv_proj = ParallelLinear.column(
             hidden_size,
-            (self.total_num_heads + 2 * self.total_num_kv_heads) *
+            (self.total_num_heads +
+             2 * self.total_num_kv_heads * num_kv_heads_replicas) *
             self.head_dim,
             bias=False,
             gather_output=False,
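
The larger qkv_proj output dimension keeps the column-parallel split even: after dividing by tp_size, every rank still receives a whole (num_heads + 2 * num_kv_heads) * head_dim slice. A quick arithmetic check, assuming Llama-2-70B's published config (64 query heads, 8 KV heads, head_dim 128) and tp_size = 16; these numbers are assumptions, not values read from the diff.

```python
# Arithmetic check only; the model numbers are assumptions.
total_num_heads, total_num_kv_heads, head_dim = 64, 8, 128
tp_size = 16

num_kv_heads_replicas = max(1, tp_size // total_num_kv_heads)           # 2
qkv_out = (total_num_heads +
           2 * total_num_kv_heads * num_kv_heads_replicas) * head_dim   # 12288
per_rank = qkv_out // tp_size                                           # 768

num_heads = total_num_heads // tp_size                                  # 4
num_kv_heads = max(1, total_num_kv_heads // tp_size)                    # 1
# Each rank holds 4 query heads plus 1 K head and 1 V head.
assert per_rank == (num_heads + 2 * num_kv_heads) * head_dim
```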
@@ -323,11 +332,15 @@ class LlamaForCausalLM(nn.Module):
                 row_parallel_weights.append(f"{layer}.{suffix}")
 
         tp_size = get_tensor_model_parallel_world_size()
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        tp_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
+        num_kv_heads_replicas = max(1,
+                                    tp_size // self.config.num_key_value_heads)
+        num_kv_heads_per_gpu = max(1,
+                                   self.config.num_key_value_heads // tp_size)
         kv_proj_shard_size = (self.config.hidden_size //
                               self.config.num_attention_heads *
-                              self.config.num_key_value_heads // tp_size)
+                              num_kv_heads_per_gpu)
         attention_weight_specs = [
             # (weight_name, shard_size, offset)
             ("q_proj", q_proj_shard_size, 0),
@@ -363,9 +376,13 @@ class LlamaForCausalLM(nn.Module):
                     shard_size //= self.quant_config.pack_factor
                     offset //= self.quant_config.pack_factor
 
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
+                if weight_name in ["k_proj", "v_proj"]:
+                    shard_id = tp_rank // num_kv_heads_replicas
+                else:
+                    shard_id = tp_rank
+                loaded_weight = loaded_weight[shard_size *
+                                              shard_id:shard_size *
+                                              (shard_id + 1)]
                 param_slice = param.data[offset:offset + shard_size]
                 assert param_slice.shape == loaded_weight.shape
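
The shard_id mapping is what keeps replication consistent at load time: every rank that serves a given KV head reads the same k_proj/v_proj slice from the checkpoint, while q_proj is still indexed by tp_rank directly. An illustrative mapping, assuming tp_size = 16 and two replicas per KV head:

```python
# Illustration only; assumes tp_size = 16 with num_kv_heads_replicas = 2.
tp_size, num_kv_heads_replicas = 16, 2
for tp_rank in range(tp_size):
    kv_shard_id = tp_rank // num_kv_heads_replicas
    print(f"rank {tp_rank:2d}: q shard {tp_rank:2d}, k/v shard {kv_shard_id}")
# Ranks 0 and 1 both load k/v shard 0, ranks 2 and 3 load shard 1, and so on.
```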
@@ -384,9 +401,8 @@ class LlamaForCausalLM(nn.Module):
                     param = param.T
 
                 shard_size = param.shape[0] // 2
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
+                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
+                                              (tp_rank + 1)]
                 param_slice = param.data[shard_size * stride_id:shard_size *
                                          (stride_id + 1)]
                 assert param_slice.shape == loaded_weight.shape
@@ -402,10 +418,9 @@ class LlamaForCausalLM(nn.Module):
             if "embed_tokens" in name or "lm_head" in name:
                 load_padded_tensor_parallel_vocab(param, loaded_weight,
-                                                  tensor_model_parallel_rank)
+                                                  tp_rank)
                 continue
 
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          column_parallel_weights,
-                                         row_parallel_weights,
-                                         tensor_model_parallel_rank)
+                                         row_parallel_weights, tp_rank)