[Bugfix] Fix ReplicatedLinearWithLoRA (#27065)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-10-17 12:43:16 +08:00 committed by GitHub
parent fe3b9372ad
commit 87bc0c492f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -56,3 +56,15 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is ReplicatedLinear
def slice_lora_a(
self, lora_a: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
"""Slice lora a if splitting for tensor parallelism."""
return lora_a
def slice_lora_b(
self, lora_b: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
"""Slice lora b if splitting with tensor parallelism."""
return lora_b