mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 20:55:37 +08:00
fix weight loading for GQA with TP (#2379)
This commit is contained in:
parent
bfc072addf
commit
f780504d12
@ -423,6 +423,9 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
shard_offset = shard_offset // param.pack_factor
|
shard_offset = shard_offset // param.pack_factor
|
||||||
param_data = param_data.narrow(output_dim, shard_offset,
|
param_data = param_data.narrow(output_dim, shard_offset,
|
||||||
shard_size)
|
shard_size)
|
||||||
|
if loaded_shard_id == "q":
|
||||||
|
shard_id = tp_rank
|
||||||
|
else:
|
||||||
shard_id = tp_rank // self.num_kv_head_replicas
|
shard_id = tp_rank // self.num_kv_head_replicas
|
||||||
start_idx = shard_id * shard_size
|
start_idx = shard_id * shard_size
|
||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user