From 4f8584756d79234bd8207201273490907340b4bb Mon Sep 17 00:00:00 2001
From: zhaoyang-star
Date: Tue, 22 Aug 2023 13:22:06 +0800
Subject: [PATCH] Fix mqa is false case in gpt_bigcode (#806)

---
 vllm/model_executor/models/gpt_bigcode.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 3bd8677b06289..0e6e71ec724b5 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -49,10 +49,11 @@ class GPTBigCodeAttention(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         total_num_heads = config.num_attention_heads
-        tensor_model_parallel_world_size = (
+        self.tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
-        assert total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = total_num_heads // tensor_model_parallel_world_size
+        assert total_num_heads % self.tensor_model_parallel_world_size == 0
+        self.num_heads = (total_num_heads //
+                          self.tensor_model_parallel_world_size)
         self.head_dim = self.hidden_size // total_num_heads
         self.scale = self.head_dim**-0.5
 
@@ -101,7 +102,10 @@ class GPTBigCodeAttention(nn.Module):
             k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
         else:
             qkv, _ = self.c_attn(hidden_states)
-            q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
+            q, k, v = qkv.split([
+                self.hidden_size // self.tensor_model_parallel_world_size,
+                self.kv_dim, self.kv_dim
+            ],
                                 dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
@@ -255,8 +259,6 @@ class GPTBigCodeForCausalLM(nn.Module):
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
@@ -286,7 +288,8 @@ class GPTBigCodeForCausalLM(nn.Module):
             hidden_size = self.config.hidden_size
             head_size = hidden_size // total_num_heads
             total_kv_size = head_size * total_num_kv_heads
-            num_heads = total_num_heads // tensor_model_parallel_world_size
+            num_heads = (total_num_heads //
+                         self.tensor_model_parallel_world_size)
             head_start = tensor_model_parallel_rank * num_heads
             head_end = (tensor_model_parallel_rank + 1) * num_heads
 
@@ -326,7 +329,7 @@ class GPTBigCodeForCausalLM(nn.Module):
             if name == "transformer.wte.weight":
                 # Consider padding in the vocab size.
                 padded_vocab_size = param.shape[
-                    0] * tensor_model_parallel_world_size
+                    0] * self.tensor_model_parallel_world_size
                 num_extra_rows = padded_vocab_size - self.config.vocab_size
                 extra_rows = torch.empty(num_extra_rows,
                                          loaded_weight.shape[1])
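
Note on the core change (illustrative, not part of the patch): when config.multi_query is
False, the fused c_attn projection is column-parallel, so each tensor-parallel rank only
holds hidden_size // tensor_model_parallel_world_size query features plus its own slice of
the key/value features. Splitting the per-rank qkv tensor with the full hidden_size
therefore cannot work, which is what the change to GPTBigCodeAttention.forward addresses.
The standalone sketch below walks through the same shape arithmetic using only plain
torch; the sizes (hidden_size, total_num_heads, tp_world_size) are made-up assumptions,
not values from any real checkpoint or from the vLLM config.

    # Standalone sketch of the split-size arithmetic; all sizes below are
    # illustrative assumptions.
    import torch

    hidden_size = 2048
    total_num_heads = 16
    head_dim = hidden_size // total_num_heads           # 128
    tp_world_size = 4                                   # tensor-parallel ranks

    # multi_query=False: each rank keeps 1/tp of the query heads and 1/tp of
    # the key/value heads, so the fused per-rank qkv width is:
    q_size_per_rank = hidden_size // tp_world_size                    # 512
    kv_dim_per_rank = (total_num_heads // tp_world_size) * head_dim   # 512

    qkv = torch.randn(1, 8, q_size_per_rank + 2 * kv_dim_per_rank)

    # Old behaviour: splitting with the full hidden_size over-counts the query
    # slice whenever tp_world_size > 1, so torch raises a RuntimeError because
    # the split sizes no longer sum to the last dimension.
    try:
        q, k, v = qkv.split([hidden_size, kv_dim_per_rank, kv_dim_per_rank],
                            dim=-1)
    except RuntimeError as exc:
        print(f"unsharded split fails: {exc}")

    # Patched behaviour: use the per-rank query width, matching the hunk at
    # @@ -101,7 +102,10 @@ above.
    q, k, v = qkv.split(
        [q_size_per_rank, kv_dim_per_rank, kv_dim_per_rank], dim=-1)
    print(q.shape, k.shape, v.shape)    # each (1, 8, 512)

With tp_world_size = 1 the two split sizes coincide, which is why the unsharded split only
breaks once tensor parallelism is enabled.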