From eb825c1e7401a6d9ebb2b3d8d693df0069b80ccb Mon Sep 17 00:00:00 2001 From: lirui <30922859+lihuahua123@users.noreply.github.com> Date: Mon, 13 Nov 2023 07:53:12 +0800 Subject: [PATCH] Fix #1474 - AssertionError:assert param_slice.shape == loaded_weight.shape (#1631) --- vllm/model_executor/models/gpt_j.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 3606fdc76fb15..f61eab73b3a89 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -250,7 +250,7 @@ class GPTJForCausalLM(nn.Module): if att_weight_name not in name: continue param = state_dict[name.replace(att_weight_name, "qkv_proj")] - shard_size = param.shape[1] + shard_size = param.shape[0] // 3 loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * (tp_rank + 1)] param_slice = param.data[shard_size * stride_id:shard_size *