diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 1c118ba8c0d5..3bd8677b0628 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -54,15 +54,30 @@ class GPTBigCodeAttention(nn.Module):
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
-        self.num_kv_heads = 1 if config.multi_query else self.num_heads
-        self.kv_dim = self.num_kv_heads * self.head_dim
         self.scale = self.head_dim**-0.5
 
-        self.c_attn = ColumnParallelLinear(self.hidden_size,
-                                           self.hidden_size + 2 * self.kv_dim,
-                                           bias=True,
-                                           gather_output=False,
-                                           perform_initialization=False)
+        self.multi_query = config.multi_query
+        if self.multi_query:
+            self.num_kv_heads = 1
+            self.kv_dim = self.head_dim
+            self.c_attn_q = ColumnParallelLinear(self.hidden_size,
+                                                 self.hidden_size,
+                                                 bias=True,
+                                                 gather_output=False,
+                                                 perform_initialization=False)
+            self.c_attn_kv = nn.Linear(self.hidden_size,
+                                       2 * self.kv_dim,
+                                       bias=True)
+        else:
+            self.num_kv_heads = self.num_heads
+            self.kv_dim = self.num_kv_heads * self.head_dim
+            self.c_attn = ColumnParallelLinear(self.hidden_size,
+                                               self.hidden_size +
+                                               2 * self.kv_dim,
+                                               bias=True,
+                                               gather_output=False,
+                                               perform_initialization=False)
+
         self.c_proj = RowParallelLinear(self.hidden_size,
                                         self.hidden_size,
                                         bias=True,
@@ -80,9 +95,14 @@ class GPTBigCodeAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        qkv, _ = self.c_attn(hidden_states)
-        q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
-                            dim=-1)
+        if self.multi_query:
+            q, _ = self.c_attn_q(hidden_states)
+            kv = self.c_attn_kv(hidden_states)
+            k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
+        else:
+            qkv, _ = self.c_attn(hidden_states)
+            q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
+                                dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 input_metadata, cache_event)
@@ -251,21 +271,9 @@ class GPTBigCodeForCausalLM(nn.Module):
                 # NOTE: "c_attn.bias" should not be skipped.
                 continue
 
-            param = state_dict[name]
-
             if not name.startswith("transformer."):
                 name = "transformer." + name
 
-            if name == "transformer.wte.weight":
-                # Consider padding in the vocab size.
-                padded_vocab_size = param.shape[
-                    0] * tensor_model_parallel_world_size
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
                 # GPT-2's fused QKV has the shape of
@@ -291,9 +299,39 @@ class GPTBigCodeForCausalLM(nn.Module):
                     # Split the heads when using normal multi-head attention
                     wk = wk[head_size * head_start:head_size * head_end]
                     wv = wv[head_size * head_start:head_size * head_end]
-                # Else, keep the weights as is for multi-query attention
+                    loaded_weight = torch.cat([wq, wk, wv], dim=0)
+                else:
+                    # For multi-query attention, we split the query
+                    # but replicate the key and value.
+                    loaded_weight_q = wq
+                    loaded_weight_kv = torch.cat([wk, wv], dim=0)
+                    q_weight_name = name.replace("c_attn", "c_attn_q")
+                    kv_weight_name = name.replace("c_attn", "c_attn_kv")
+                    load_tensor_parallel_weights(state_dict[q_weight_name],
+                                                 loaded_weight_q,
+                                                 q_weight_name,
+                                                 self._column_parallel_weights,
+                                                 self._row_parallel_weights,
+                                                 tensor_model_parallel_rank)
+                    load_tensor_parallel_weights(state_dict[kv_weight_name],
+                                                 loaded_weight_kv,
+                                                 kv_weight_name,
+                                                 self._column_parallel_weights,
+                                                 self._row_parallel_weights,
+                                                 tensor_model_parallel_rank)
+                    continue
 
-                loaded_weight = torch.cat([wq, wk, wv], dim=0)
+            param = state_dict[name]
+
+            if name == "transformer.wte.weight":
+                # Consider padding in the vocab size.
+                padded_vocab_size = param.shape[
+                    0] * tensor_model_parallel_world_size
+                num_extra_rows = padded_vocab_size - self.config.vocab_size
+                extra_rows = torch.empty(num_extra_rows,
+                                         loaded_weight.shape[1])
+                extra_rows = extra_rows.to(loaded_weight)
+                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
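
Reviewer note: the loader change above boils down to "shard the query heads across tensor-parallel ranks, but give every rank a full copy of the single key/value head." The sketch below illustrates that split on toy shapes. `split_fused_c_attn` is a hypothetical helper written only for this note (not vLLM's loader), and the head-slicing convention is assumed to match the `head_start`/`head_end` logic in the hunk above.

```python
import torch


def split_fused_c_attn(fused_weight: torch.Tensor, num_heads: int,
                       head_size: int, tp_rank: int, tp_world_size: int):
    """Hypothetical helper: split GPT-BigCode's fused MQA c_attn weight of
    shape [(num_heads + 2) * head_size, hidden_size] into a per-rank query
    shard and a replicated key/value weight."""
    hidden_size = num_heads * head_size
    # Fused layout: [all query heads | one key head | one value head].
    wq, wk, wv = fused_weight.split([hidden_size, head_size, head_size], dim=0)

    # Query is column-parallel: each rank keeps a contiguous block of heads,
    # analogous to the head_start/head_end slicing in load_weights.
    heads_per_rank = num_heads // tp_world_size
    head_start = tp_rank * heads_per_rank
    head_end = (tp_rank + 1) * heads_per_rank
    wq_shard = wq[head_size * head_start:head_size * head_end]

    # The single K/V head is not sharded; every rank loads the same copy
    # into the replicated linear layer (c_attn_kv in the diff).
    wkv = torch.cat([wk, wv], dim=0)
    return wq_shard, wkv


# Toy check: 8 query heads of size 4 (hidden_size 32), 2-way tensor parallelism.
fused = torch.randn((8 + 2) * 4, 32)
for rank in range(2):
    wq_shard, wkv = split_fused_c_attn(fused, num_heads=8, head_size=4,
                                       tp_rank=rank, tp_world_size=2)
    assert wq_shard.shape == (16, 32)  # 4 query heads per rank
    assert wkv.shape == (8, 32)        # 2 * head_size, identical on all ranks
```

The per-rank shapes then line up with the modules created in `__init__`: `c_attn_q` (ColumnParallelLinear with `gather_output=False`) produces its shard of `hidden_size` outputs per rank, while `c_attn_kv` (a plain `nn.Linear`) always produces `2 * head_dim` outputs on every rank.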