support sharding llama2-70b on more than 8 GPUs (#1209)
Co-authored-by: JiCheng <247153481@qq.com>
parent ebe4d1db3a
commit a60b353005
@@ -103,8 +103,16 @@ class LlamaAttention(nn.Module):
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
         self.total_num_kv_heads = num_kv_heads
-        assert self.total_num_kv_heads % tp_size == 0
-        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads)
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
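The new branch distinguishes two regimes: when there are at least as many KV heads as tensor-parallel ranks, the heads are partitioned as before; otherwise each KV head is replicated across several ranks. A minimal standalone sketch of that bookkeeping, assuming Llama-2-70B-style GQA settings (64 query heads, 8 KV heads) and a hypothetical tp_size of 16; the numbers are illustrative, not taken from this diff:

total_num_heads = 64
total_num_kv_heads = 8
tp_size = 16

num_heads = total_num_heads // tp_size  # 4 query heads per GPU
if total_num_kv_heads >= tp_size:
    # Enough KV heads for every GPU to own a distinct slice.
    assert total_num_kv_heads % tp_size == 0
else:
    # Fewer KV heads than GPUs: each KV head must live on several GPUs.
    assert tp_size % total_num_kv_heads == 0
num_kv_heads = max(1, total_num_kv_heads // tp_size)           # 1 KV head per GPU
num_kv_heads_replicas = max(1, tp_size // total_num_kv_heads)  # each KV head on 2 GPUs

print(num_heads, num_kv_heads, num_kv_heads_replicas)  # 4 1 2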
@@ -114,7 +122,8 @@ class LlamaAttention(nn.Module):
 
         self.qkv_proj = ParallelLinear.column(
             hidden_size,
-            (self.total_num_heads + 2 * self.total_num_kv_heads) *
+            (self.total_num_heads +
+             2 * self.total_num_kv_heads * num_kv_heads_replicas) *
             self.head_dim,
             bias=False,
             gather_output=False,
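A quick arithmetic check of the widened column-parallel projection, under the same assumed 70B-style numbers as above (hidden_size 8192, so head_dim = 128): repeating the KV columns num_kv_heads_replicas times keeps each rank's local slice at (num_heads + 2 * num_kv_heads) * head_dim.

head_dim = 8192 // 64                          # 128
global_out = (64 + 2 * 8 * 2) * head_dim       # KV columns repeated once per replica
per_rank_out = global_out // 16                # local width of the column-parallel shard
assert per_rank_out == (4 + 2 * 1) * head_dim  # num_heads + 2 * num_kv_heads per GPU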
@@ -323,11 +332,15 @@ class LlamaForCausalLM(nn.Module):
                 row_parallel_weights.append(f"{layer}.{suffix}")
 
         tp_size = get_tensor_model_parallel_world_size()
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        tp_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
+        num_kv_heads_replicas = max(1,
+                                    tp_size // self.config.num_key_value_heads)
+        num_kv_heads_per_gpu = max(1,
+                                   self.config.num_key_value_heads // tp_size)
         kv_proj_shard_size = (self.config.hidden_size //
                               self.config.num_attention_heads *
-                              self.config.num_key_value_heads // tp_size)
+                              num_kv_heads_per_gpu)
         attention_weight_specs = [
             # (weight_name, shard_size, offset)
             ("q_proj", q_proj_shard_size, 0),
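The same shard-size arithmetic, spelled out with the assumed 70B-style configuration used above (hidden_size = 8192, 64 attention heads, 8 KV heads, tp_size = 16):

hidden_size, num_attention_heads, num_key_value_heads, tp_size = 8192, 64, 8, 16

q_proj_shard_size = hidden_size // tp_size                      # 512 rows of q_proj per rank
num_kv_heads_replicas = max(1, tp_size // num_key_value_heads)  # 2
num_kv_heads_per_gpu = max(1, num_key_value_heads // tp_size)   # 1
kv_proj_shard_size = (hidden_size // num_attention_heads *
                      num_kv_heads_per_gpu)                     # 128 = one KV head's rows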
@@ -363,9 +376,13 @@ class LlamaForCausalLM(nn.Module):
                         shard_size //= self.quant_config.pack_factor
                         offset //= self.quant_config.pack_factor
 
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
+                if weight_name in ["k_proj", "v_proj"]:
+                    shard_id = tp_rank // num_kv_heads_replicas
+                else:
+                    shard_id = tp_rank
+                loaded_weight = loaded_weight[shard_size *
+                                              shard_id:shard_size *
+                                              (shard_id + 1)]
                 param_slice = param.data[offset:offset + shard_size]
                 assert param_slice.shape == loaded_weight.shape
 
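The shard_id mapping is what realizes the replication at load time: num_kv_heads_replicas consecutive ranks map to the same slice of the checkpoint's k_proj / v_proj, so every KV head ends up on several GPUs. A small sketch under the same assumed settings (num_kv_heads_replicas = 2, tp_size = 16):

num_kv_heads_replicas, tp_size = 2, 16
shard_ids = [tp_rank // num_kv_heads_replicas for tp_rank in range(tp_size)]
# Ranks 0 and 1 load KV shard 0, ranks 2 and 3 load shard 1, and so on;
# q_proj still uses shard_id = tp_rank, so query heads stay fully partitioned.
assert shard_ids[:4] == [0, 0, 1, 1]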
@@ -384,9 +401,8 @@ class LlamaForCausalLM(nn.Module):
                     param = param.T
 
                 shard_size = param.shape[0] // 2
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
+                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
+                                              (tp_rank + 1)]
                 param_slice = param.data[shard_size * stride_id:shard_size *
                                          (stride_id + 1)]
                 assert param_slice.shape == loaded_weight.shape
@@ -402,10 +418,9 @@ class LlamaForCausalLM(nn.Module):
 
             if "embed_tokens" in name or "lm_head" in name:
                 load_padded_tensor_parallel_vocab(param, loaded_weight,
-                                                  tensor_model_parallel_rank)
+                                                  tp_rank)
                 continue
 
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          column_parallel_weights,
-                                         row_parallel_weights,
-                                         tensor_model_parallel_rank)
+                                         row_parallel_weights, tp_rank)