diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index d3b5f7283e6fb..287d650da9806 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -103,8 +103,16 @@ class LlamaAttention(nn.Module):
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
         self.total_num_kv_heads = num_kv_heads
-        assert self.total_num_kv_heads % tp_size == 0
-        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads)
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -114,7 +122,8 @@ class LlamaAttention(nn.Module):
 
         self.qkv_proj = ParallelLinear.column(
             hidden_size,
-            (self.total_num_heads + 2 * self.total_num_kv_heads) *
+            (self.total_num_heads +
+             2 * self.total_num_kv_heads * num_kv_heads_replicas) *
             self.head_dim,
             bias=False,
             gather_output=False,
@@ -323,11 +332,15 @@ class LlamaForCausalLM(nn.Module):
                 row_parallel_weights.append(f"{layer}.{suffix}")
 
         tp_size = get_tensor_model_parallel_world_size()
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        tp_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
+        num_kv_heads_replicas = max(1,
+                                    tp_size // self.config.num_key_value_heads)
+        num_kv_heads_per_gpu = max(1,
+                                   self.config.num_key_value_heads // tp_size)
         kv_proj_shard_size = (self.config.hidden_size //
                               self.config.num_attention_heads *
-                              self.config.num_key_value_heads // tp_size)
+                              num_kv_heads_per_gpu)
         attention_weight_specs = [
             # (weight_name, shard_size, offset)
             ("q_proj", q_proj_shard_size, 0),
@@ -363,9 +376,13 @@ class LlamaForCausalLM(nn.Module):
                     shard_size //= self.quant_config.pack_factor
                     offset //= self.quant_config.pack_factor
 
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
+                if weight_name in ["k_proj", "v_proj"]:
+                    shard_id = tp_rank // num_kv_heads_replicas
+                else:
+                    shard_id = tp_rank
+                loaded_weight = loaded_weight[shard_size *
+                                              shard_id:shard_size *
+                                              (shard_id + 1)]
                 param_slice = param.data[offset:offset + shard_size]
                 assert param_slice.shape == loaded_weight.shape
 
@@ -384,9 +401,8 @@ class LlamaForCausalLM(nn.Module):
                     param = param.T
 
                 shard_size = param.shape[0] // 2
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
+                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
+                                              (tp_rank + 1)]
                 param_slice = param.data[shard_size * stride_id:shard_size *
                                          (stride_id + 1)]
                 assert param_slice.shape == loaded_weight.shape
@@ -402,10 +418,9 @@ class LlamaForCausalLM(nn.Module):
 
             if "embed_tokens" in name or "lm_head" in name:
                 load_padded_tensor_parallel_vocab(param, loaded_weight,
-                                                  tensor_model_parallel_rank)
+                                                  tp_rank)
                 continue
 
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          column_parallel_weights,
-                                         row_parallel_weights,
-                                         tensor_model_parallel_rank)
+                                         row_parallel_weights, tp_rank)
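
For reference, below is a small standalone sketch (not part of the patch) of the sharding arithmetic the hunks above implement: partition KV heads across tensor-parallel ranks when there are at least as many KV heads as ranks, otherwise replicate each KV head across several ranks, and map replica ranks back to the same checkpoint shard when loading k_proj/v_proj. The helper names `kv_shard_layout` and `kv_shard_id` are illustrative, not functions in vLLM.

```python
# Illustrative sketch of the KV-head partition/replication logic added by this
# patch. Variable names mirror the diff; the functions themselves are hypothetical.


def kv_shard_layout(total_num_kv_heads: int, tp_size: int):
    # Either the KV heads divide evenly across TP ranks (partition case),
    # or the TP ranks divide evenly across the KV heads (replication case,
    # which this patch adds).
    if total_num_kv_heads >= tp_size:
        assert total_num_kv_heads % tp_size == 0
    else:
        assert tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // tp_size)           # KV heads per GPU
    num_kv_heads_replicas = max(1, tp_size // total_num_kv_heads)  # copies of each KV head
    return num_kv_heads, num_kv_heads_replicas


def kv_shard_id(tp_rank: int, num_kv_heads_replicas: int) -> int:
    # Which slice of the checkpoint's k_proj/v_proj a rank loads:
    # ranks holding replicas of the same KV head map to the same shard.
    return tp_rank // num_kv_heads_replicas


if __name__ == "__main__":
    # Example: a GQA model with 8 KV heads (e.g. LLaMA-2-70B) on 16-way TP.
    num_kv_heads, replicas = kv_shard_layout(total_num_kv_heads=8, tp_size=16)
    print(num_kv_heads, replicas)  # -> 1 2: one KV head per GPU, two copies of each
    print([kv_shard_id(rank, replicas) for rank in range(16)])
    # -> [0, 0, 1, 1, ..., 7, 7]: consecutive rank pairs load the same KV shard
```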