From f780504d1294cbe28221d9d030b040384fa53d5d Mon Sep 17 00:00:00 2001 From: Chenhui Zhang Date: Tue, 16 Jan 2024 07:43:59 +0800 Subject: [PATCH] fix weight loading for GQA with TP (#2379) --- vllm/model_executor/layers/linear.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5190de65d795..5e1d63a6a62e 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -423,7 +423,10 @@ class QKVParallelLinear(ColumnParallelLinear): shard_offset = shard_offset // param.pack_factor param_data = param_data.narrow(output_dim, shard_offset, shard_size) - shard_id = tp_rank // self.num_kv_head_replicas + if loaded_shard_id == "q": + shard_id = tp_rank + else: + shard_id = tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)