diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py
index 576d95a47154..52b0834cacb8 100644
--- a/tests/lora/test_lora_manager.py
+++ b/tests/lora/test_lora_manager.py
@@ -31,6 +31,8 @@ DEVICES = ([
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ] if current_platform.is_cuda_alike() else ["cpu"])
 
+DEFAULT_DTYPE = torch.get_default_dtype()
+
 
 @pytest.fixture(scope="function", autouse=True)
 def use_v0_only(monkeypatch: pytest.MonkeyPatch):
@@ -125,8 +127,10 @@ def test_replace_submodules(dist_init, dummy_model):
     model = dummy_model
     manager = LoRAModelManager(
         model, 1, 1, 1,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
-        torch.device(DEVICES[0]))
+        LoRAConfig(max_lora_rank=8,
+                   max_cpu_loras=8,
+                   max_loras=8,
+                   lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0]))
     model = manager.model
     assert isinstance(model.get_submodule("dense1"),
                       ColumnParallelLinearWithLoRA)
@@ -155,7 +159,8 @@ def test_lora_model_manager(dist_init, dummy_model, device):
                                2,
                                LoRAConfig(max_lora_rank=8,
                                           max_cpu_loras=3,
-                                          max_loras=2),
+                                          max_loras=2,
+                                          lora_dtype=DEFAULT_DTYPE),
                                device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
@@ -221,7 +226,8 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
                                2,
                                LoRAConfig(max_lora_rank=8,
                                           max_cpu_loras=3,
-                                          max_loras=2),
+                                          max_loras=2,
+                                          lora_dtype=DEFAULT_DTYPE),
                                device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
@@ -316,7 +322,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
                                2,
                                LoRAConfig(max_lora_rank=8,
                                           max_cpu_loras=2,
-                                          max_loras=2),
+                                          max_loras=2,
+                                          lora_dtype=DEFAULT_DTYPE),
                                device=device)
 
     assert all(x is None for x in manager.lora_index_to_id)
@@ -424,7 +431,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
 @pytest.mark.parametrize("device", DEVICES)
 def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                           sql_lora_files, device):
-    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
+    lora_config = LoRAConfig(max_lora_rank=8,
+                             max_cpu_loras=4,
+                             max_loras=4,
+                             lora_dtype=DEFAULT_DTYPE)
     worker_adapter_manager = LRUCacheWorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
         lora_config.lora_extra_vocab_size, lora_config, device,
@@ -504,7 +514,10 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
 def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                 sql_lora_files, device):
     # Should remove every LoRA not specified in the request.
-    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
+    lora_config = LoRAConfig(max_lora_rank=8,
+                             max_cpu_loras=4,
+                             max_loras=4,
+                             lora_dtype=DEFAULT_DTYPE)
     worker_adapter_manager = WorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
         lora_config.lora_extra_vocab_size, lora_config, device,
@@ -600,7 +613,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
                                2,
                                LoRAConfig(max_lora_rank=8,
                                           max_cpu_loras=2,
-                                          max_loras=2),
+                                          max_loras=2,
+                                          lora_dtype=DEFAULT_DTYPE),
                                device=device)
 
     model = manager.model
diff --git a/vllm/config.py b/vllm/config.py
index 3e5a17802f0f..41a30efea039 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2565,18 +2565,41 @@ class SpeculativeConfig:
         return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
 
 
+LoRADType = Literal["auto", "float16", "bfloat16"]
+
+
+@config
 @dataclass
 class LoRAConfig:
-    max_lora_rank: int
-    max_loras: int
+    """Configuration for LoRA."""
+
+    max_lora_rank: int = 16
+    """Max LoRA rank."""
+    max_loras: int = 1
+    """Max number of LoRAs in a single batch."""
     fully_sharded_loras: bool = False
+    """By default, only half of the LoRA computation is sharded with tensor
+    parallelism. Enabling this will use the fully sharded layers. At high
+    sequence length, max rank or tensor parallel size, this is likely faster.
+    """
     max_cpu_loras: Optional[int] = None
-    lora_dtype: Optional[Union[torch.dtype, str]] = None
+    """Maximum number of LoRAs to store in CPU memory. Must be >=
+    `max_loras`."""
+    lora_dtype: Union[torch.dtype, LoRADType] = "auto"
+    """Data type for LoRA. If auto, will default to base model dtype."""
     lora_extra_vocab_size: int = 256
+    """Maximum size of extra vocabulary that can be present in a LoRA adapter
+    (added to the base model vocabulary)."""
     # This is a constant.
     lora_vocab_padding_size: ClassVar[int] = 256
-    long_lora_scaling_factors: Optional[tuple[float]] = None
+    long_lora_scaling_factors: Optional[tuple[float, ...]] = None
+    """Specify multiple scaling factors (which can be different from base model
+    scaling factor - see e.g. Long LoRA) to allow for multiple LoRA adapters
+    trained with those scaling factors to be used at the same time. If not
+    specified, only adapters trained with the base model scaling factor are
+    allowed."""
     bias_enabled: bool = False
+    """Enable bias for LoRA adapters."""
 
     def compute_hash(self) -> str:
         """
@@ -2641,12 +2664,19 @@ class LoRAConfig:
                 "V1 LoRA does not support long LoRA, please use V0.")
 
 
+@config
 @dataclass
 class PromptAdapterConfig:
-    max_prompt_adapters: int
-    max_prompt_adapter_token: int
+    max_prompt_adapters: int = 1
+    """Max number of PromptAdapters in a batch."""
+    max_prompt_adapter_token: int = 0
+    """Max number of PromptAdapter tokens."""
     max_cpu_prompt_adapters: Optional[int] = None
-    prompt_adapter_dtype: Optional[torch.dtype] = None
+    """Maximum number of PromptAdapters to store in CPU memory. Must be >=
+    `max_prompt_adapters`."""
+    prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
+    """Data type for PromptAdapter. If auto, will default to base model dtype.
+ """ def compute_hash(self) -> str: """ @@ -2678,7 +2708,7 @@ class PromptAdapterConfig: self.max_cpu_prompt_adapters = self.max_prompt_adapters def verify_with_model_config(self, model_config: ModelConfig): - if self.prompt_adapter_dtype in (None, "auto"): + if self.prompt_adapter_dtype == "auto": self.prompt_adapter_dtype = model_config.dtype elif isinstance(self.prompt_adapter_dtype, str): self.prompt_adapter_dtype = getattr(torch, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6d6b5ac02b14..9cb2aa797be5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,7 +7,7 @@ import json import re import threading from dataclasses import MISSING, dataclass, fields -from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Type, +from typing import (Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union, cast, get_args, get_origin) import torch @@ -192,18 +192,23 @@ class EngineArgs: get_field(MultiModalConfig, "limit_per_prompt") mm_processor_kwargs: Optional[Dict[str, Any]] = None disable_mm_preprocessor_cache: bool = False + # LoRA fields enable_lora: bool = False - enable_lora_bias: bool = False - max_loras: int = 1 - max_lora_rank: int = 16 + enable_lora_bias: bool = LoRAConfig.bias_enabled + max_loras: int = LoRAConfig.max_loras + max_lora_rank: int = LoRAConfig.max_lora_rank + fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras + max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras + lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype + lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size + long_lora_scaling_factors: Optional[tuple[float, ...]] = \ + LoRAConfig.long_lora_scaling_factors + # PromptAdapter fields enable_prompt_adapter: bool = False - max_prompt_adapters: int = 1 - max_prompt_adapter_token: int = 0 - fully_sharded_loras: bool = False - lora_extra_vocab_size: int = 256 - long_lora_scaling_factors: Optional[Tuple[float]] = None - lora_dtype: Optional[Union[str, torch.dtype]] = 'auto' - max_cpu_loras: Optional[int] = None + max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters + max_prompt_adapter_token: int = \ + PromptAdapterConfig.max_prompt_adapter_token + device: Device = DeviceConfig.device num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs @@ -338,10 +343,21 @@ class EngineArgs: kwargs[name]["choices"] = choices choice_type = type(choices[0]) assert all(type(c) is choice_type for c in choices), ( - f"All choices must be of the same type. " + "All choices must be of the same type. " f"Got {choices} with types {[type(c) for c in choices]}" ) kwargs[name]["type"] = choice_type + elif can_be_type(field_type, tuple): + if is_type_in_union(field_type, tuple): + field_type = get_type_from_union(field_type, tuple) + dtypes = get_args(field_type) + dtype = dtypes[0] + assert all( + d is dtype for d in dtypes if d is not Ellipsis + ), ("All non-Ellipsis tuple elements must be of the same " + f"type. 
Got {dtypes}.") + kwargs[name]["type"] = dtype + kwargs[name]["nargs"] = "+" elif can_be_type(field_type, int): kwargs[name]["type"] = optional_int if optional else int elif can_be_type(field_type, float): @@ -685,70 +701,49 @@ class EngineArgs: 'inputs.') # LoRA related configs - parser.add_argument('--enable-lora', - action='store_true', - help='If True, enable handling of LoRA adapters.') - parser.add_argument('--enable-lora-bias', - action='store_true', - help='If True, enable bias for LoRA adapters.') - parser.add_argument('--max-loras', - type=int, - default=EngineArgs.max_loras, - help='Max number of LoRAs in a single batch.') - parser.add_argument('--max-lora-rank', - type=int, - default=EngineArgs.max_lora_rank, - help='Max LoRA rank.') - parser.add_argument( - '--lora-extra-vocab-size', - type=int, - default=EngineArgs.lora_extra_vocab_size, - help=('Maximum size of extra vocabulary that can be ' - 'present in a LoRA adapter (added to the base ' - 'model vocabulary).')) - parser.add_argument( + lora_kwargs = get_kwargs(LoRAConfig) + lora_group = parser.add_argument_group( + title="LoRAConfig", + description=LoRAConfig.__doc__, + ) + lora_group.add_argument( + '--enable-lora', + action=argparse.BooleanOptionalAction, + help='If True, enable handling of LoRA adapters.') + lora_group.add_argument('--enable-lora-bias', + **lora_kwargs["bias_enabled"]) + lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"]) + lora_group.add_argument('--max-lora-rank', + **lora_kwargs["max_lora_rank"]) + lora_group.add_argument('--lora-extra-vocab-size', + **lora_kwargs["lora_extra_vocab_size"]) + lora_group.add_argument( '--lora-dtype', - type=str, - default=EngineArgs.lora_dtype, - choices=['auto', 'float16', 'bfloat16'], - help=('Data type for LoRA. If auto, will default to ' - 'base model dtype.')) - parser.add_argument( - '--long-lora-scaling-factors', - type=optional_str, - default=EngineArgs.long_lora_scaling_factors, - help=('Specify multiple scaling factors (which can ' - 'be different from base model scaling factor ' - '- see eg. Long LoRA) to allow for multiple ' - 'LoRA adapters trained with those scaling ' - 'factors to be used at the same time. If not ' - 'specified, only adapters trained with the ' - 'base model scaling factor are allowed.')) - parser.add_argument( - '--max-cpu-loras', - type=int, - default=EngineArgs.max_cpu_loras, - help=('Maximum number of LoRAs to store in CPU memory. ' - 'Must be >= than max_loras.')) - parser.add_argument( - '--fully-sharded-loras', - action='store_true', - help=('By default, only half of the LoRA computation is ' - 'sharded with tensor parallelism. ' - 'Enabling this will use the fully sharded layers. 
-                  'At high sequence length, max rank or '
-                  'tensor parallel size, this is likely faster.'))
-        parser.add_argument('--enable-prompt-adapter',
-                            action='store_true',
-                            help='If True, enable handling of PromptAdapters.')
-        parser.add_argument('--max-prompt-adapters',
-                            type=int,
-                            default=EngineArgs.max_prompt_adapters,
-                            help='Max number of PromptAdapters in a batch.')
-        parser.add_argument('--max-prompt-adapter-token',
-                            type=int,
-                            default=EngineArgs.max_prompt_adapter_token,
-                            help='Max number of PromptAdapters tokens')
+            **lora_kwargs["lora_dtype"],
+        )
+        lora_group.add_argument('--long-lora-scaling-factors',
+                                **lora_kwargs["long_lora_scaling_factors"])
+        lora_group.add_argument('--max-cpu-loras',
+                                **lora_kwargs["max_cpu_loras"])
+        lora_group.add_argument('--fully-sharded-loras',
+                                **lora_kwargs["fully_sharded_loras"])
+
+        # PromptAdapter related configs
+        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
+        prompt_adapter_group = parser.add_argument_group(
+            title="PromptAdapterConfig",
+            description=PromptAdapterConfig.__doc__,
+        )
+        prompt_adapter_group.add_argument(
+            '--enable-prompt-adapter',
+            action=argparse.BooleanOptionalAction,
+            help='If True, enable handling of PromptAdapters.')
+        prompt_adapter_group.add_argument(
+            '--max-prompt-adapters',
+            **prompt_adapter_kwargs["max_prompt_adapters"])
+        prompt_adapter_group.add_argument(
+            '--max-prompt-adapter-token',
+            **prompt_adapter_kwargs["max_prompt_adapter_token"])
 
         # Device arguments
         device_kwargs = get_kwargs(DeviceConfig)
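
Note (illustrative, not part of the patch): the config.py hunk above gives every
LoRAConfig field a default, so the dataclass can now be constructed without
arguments and `lora_dtype` falls back to "auto" until it is resolved against the
base model dtype. The sketch below only assumes the fields and defaults shown in
that hunk, and mirrors how the updated tests pin `lora_dtype` explicitly via
DEFAULT_DTYPE instead of relying on the model dtype.

import torch

from vllm.config import LoRAConfig

# All fields now have defaults, matching the values shown in the hunk above.
cfg = LoRAConfig()
assert cfg.max_lora_rank == 16
assert cfg.max_loras == 1
assert cfg.lora_dtype == "auto"  # resolved to the base model dtype later

# The updated tests pass an explicit dtype so they no longer depend on the
# base model's dtype (DEFAULT_DTYPE in tests/lora/test_lora_manager.py):
test_cfg = LoRAConfig(max_lora_rank=8,
                      max_cpu_loras=8,
                      max_loras=8,
                      lora_dtype=torch.get_default_dtype())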