[Bugfix] Fix fully sharded LoRAs with Mixtral (#11390)

Signed-off-by: Jason Greene <jason.greene@redhat.com>
This commit is contained in:
Jason T. Greene 2024-12-22 09:25:10 -06:00 committed by GitHub
parent 72d9c316d3
commit f1d1bf6288
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 2 deletions

View File

@@ -62,8 +62,9 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
@pytest.mark.parametrize("tp_size", [4])
@pytest.mark.parametrize("fully_shard", [True, False])
def test_mixtral_lora_all_target_modules(mixtral_lora_files_all_target_modules,
tp_size):
tp_size, fully_shard):
"""This LoRA model has all supported Mixtral target modules"""
if torch.cuda.device_count() < tp_size:
@@ -82,6 +83,7 @@ def test_mixtral_lora_all_target_modules(mixtral_lora_files_all_target_modules,
max_loras=4,
distributed_executor_backend="ray",
tensor_parallel_size=tp_size,
fully_sharded_loras=fully_shard,
max_lora_rank=32,
)

View File

@@ -425,8 +425,9 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
if self.base_layer.skip_bias_add else None)
return output, output_bias
# ReplicatedLinear should always be replaced, regardless of the fully
# sharded LoRAs setting, because it is, by definition, copied per GPU.
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,