mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:34:57 +08:00
fix weight loading for GQA with TP (#2379)
This commit is contained in:
parent
bfc072addf
commit
f780504d12
@ -423,7 +423,10 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
param_data = param_data.narrow(output_dim, shard_offset,
|
||||
shard_size)
|
||||
shard_id = tp_rank // self.num_kv_head_replicas
|
||||
if loaded_shard_id == "q":
|
||||
shard_id = tp_rank
|
||||
else:
|
||||
shard_id = tp_rank // self.num_kv_head_replicas
|
||||
start_idx = shard_id * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user