support sharding llama2-70b on more than 8 GPUs (#1209)

Co-authored-by: JiCheng <247153481@qq.com>
This commit is contained in:
Zhuohan Li 2023-10-02 15:26:33 -07:00 committed by GitHub
parent ebe4d1db3a
commit a60b353005
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -103,8 +103,16 @@ class LlamaAttention(nn.Module):
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
@ -114,7 +122,8 @@ class LlamaAttention(nn.Module):
self.qkv_proj = ParallelLinear.column(
hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
(self.total_num_heads +
2 * self.total_num_kv_heads * num_kv_heads_replicas) *
self.head_dim,
bias=False,
gather_output=False,
@ -323,11 +332,15 @@ class LlamaForCausalLM(nn.Module):
row_parallel_weights.append(f"{layer}.{suffix}")
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
tp_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
num_kv_heads_replicas = max(1,
tp_size // self.config.num_key_value_heads)
num_kv_heads_per_gpu = max(1,
self.config.num_key_value_heads // tp_size)
kv_proj_shard_size = (self.config.hidden_size //
self.config.num_attention_heads *
self.config.num_key_value_heads // tp_size)
num_kv_heads_per_gpu)
attention_weight_specs = [
# (weight_name, shard_size, offset)
("q_proj", q_proj_shard_size, 0),
@ -363,9 +376,13 @@ class LlamaForCausalLM(nn.Module):
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
if weight_name in ["k_proj", "v_proj"]:
shard_id = tp_rank // num_kv_heads_replicas
else:
shard_id = tp_rank
loaded_weight = loaded_weight[shard_size *
shard_id:shard_size *
(shard_id + 1)]
param_slice = param.data[offset:offset + shard_size]
assert param_slice.shape == loaded_weight.shape
@ -384,9 +401,8 @@ class LlamaForCausalLM(nn.Module):
param = param.T
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
@ -402,10 +418,9 @@ class LlamaForCausalLM(nn.Module):
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
tp_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
column_parallel_weights,
row_parallel_weights,
tensor_model_parallel_rank)
row_parallel_weights, tp_rank)