[Bugfix] Fix ColumnParallelLinearWithLoRA slice (#11708)

Signed-off-by: ZincCat <zincchloride@outlook.com>
This commit is contained in:
ZincCat 2025-01-03 13:02:34 -08:00 committed by GitHub
parent 80c751e7f6
commit 61fed92c7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -479,7 +479,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
# ColumnParallelLinear.
else:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
shard_size = self.output_size
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
@ -490,7 +490,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
if bias is None:
return bias
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
shard_size = self.output_size
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
bias = bias[start_idx:end_idx]