[LoRA] Relax LoRA condition (#7146)

Jee Jee Li 2024-08-06 09:57:25 +08:00 committed by GitHub
parent e3c664bfcb
commit 9118217f58
4 changed files with 8 additions and 7 deletions
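
For orientation (not part of the diff below): after this change an adapter of rank up to 256 can be loaded by raising max_lora_rank accordingly. The sketch below uses vLLM's offline LLM API; the model name and adapter path are placeholders, not values from the commit.

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # Placeholder base model and a local adapter assumed to be trained at rank 256.
    llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_lora_rank=256)
    outputs = llm.generate(
        ["Give me a short introduction to LoRA."],
        SamplingParams(max_tokens=32),
        lora_request=LoRARequest("my_adapter", 1, "/path/to/rank256_adapter"),
    )
    print(outputs[0].outputs[0].text)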

@@ -420,7 +420,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
 @pytest.mark.parametrize("stage", STAGES)
 def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                                   stage) -> None:

@@ -25,7 +25,7 @@ HIDDEN_SIZES = [3424, 4096, 4097]
 BATCHES = [1, 4, 16, 32]
 NUM_LORA = [1, 4, 8, 16, 32, 64, 128]
 DTYPES = [torch.float16, torch.bfloat16]
-MAX_RANKS = [1, 4, 8, 16, 32, 64, 128]
+MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256]
 SCALES = [0.5]
 SEED = [0]
 CUDA_DEVICES = [f"cuda:{0}"]

@@ -1311,8 +1311,9 @@ class LoRAConfig:
     long_lora_scaling_factors: Optional[Tuple[float]] = None

     def __post_init__(self):
-        # TODO: Increase the range of rank
-        possible_max_ranks = (8, 16, 32, 64)
+        # Setting the maximum rank to 256 should be able to satisfy the vast
+        # majority of applications.
+        possible_max_ranks = (8, 16, 32, 64, 128, 256)
         possible_lora_extra_vocab_size = (0, 256, 512)
         if self.max_lora_rank not in possible_max_ranks:
             raise ValueError(
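
To make the effect of the new whitelist concrete, here is a small standalone sketch of the same check; the helper name is invented for illustration and is not part of the diff above.

    # Standalone illustration of the rank whitelist in LoRAConfig.__post_init__.
    possible_max_ranks = (8, 16, 32, 64, 128, 256)

    def check_max_lora_rank(max_lora_rank: int) -> None:
        if max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({max_lora_rank}) must be one of {possible_max_ranks}.")

    check_max_lora_rank(256)      # accepted after this change
    try:
        check_max_lora_rank(512)  # still rejected
    except ValueError as e:
        print(e)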

@@ -1073,10 +1073,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         lora_config: LoRAConfig,
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
-        # TODO: Verify if this condition can be relaxed
-        if 32000 < self.base_layer.vocab_size > 128512:
+        # TODO: Verify if this condition can be further relaxed
+        if 32000 < self.base_layer.vocab_size > 257024:
             raise ValueError("When using LoRA, vocab size must be "
-                             "32000 >= vocab_size <= 128512")
+                             "32000 >= vocab_size <= 257024")
         self.lora_a_stacked = torch.zeros(
             (
                 max_loras,
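
One plausible reading of the new 257024 ceiling (an assumption, not stated anywhere in the commit): it is the largest vocab size now exercised in the tests plus the largest LoRA extra vocab size allowed by LoRAConfig. The snippet below is just that arithmetic.

    # Assumption about where 257024 comes from; both numbers are taken from the hunks above.
    largest_tested_vocab = 256512   # new value in test_lm_head_logits_processor
    max_lora_extra_vocab = 512      # largest entry in possible_lora_extra_vocab_size
    assert largest_tested_vocab + max_lora_extra_vocab == 257024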