mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 15:44:57 +08:00
Improve configs - LoRAConfig + PromptAdapterConfig (#16980)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
0422ce109f
commit
0fa939e2d1
@ -31,6 +31,8 @@ DEVICES = ([
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
] if current_platform.is_cuda_alike() else ["cpu"])
|
||||
|
||||
DEFAULT_DTYPE = torch.get_default_dtype()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch: pytest.MonkeyPatch):
|
||||
@ -125,8 +127,10 @@ def test_replace_submodules(dist_init, dummy_model):
|
||||
model = dummy_model
|
||||
manager = LoRAModelManager(
|
||||
model, 1, 1, 1,
|
||||
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
|
||||
torch.device(DEVICES[0]))
|
||||
LoRAConfig(max_lora_rank=8,
|
||||
max_cpu_loras=8,
|
||||
max_loras=8,
|
||||
lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0]))
|
||||
model = manager.model
|
||||
assert isinstance(model.get_submodule("dense1"),
|
||||
ColumnParallelLinearWithLoRA)
|
||||
@ -155,7 +159,8 @@ def test_lora_model_manager(dist_init, dummy_model, device):
|
||||
2,
|
||||
LoRAConfig(max_lora_rank=8,
|
||||
max_cpu_loras=3,
|
||||
max_loras=2),
|
||||
max_loras=2,
|
||||
lora_dtype=DEFAULT_DTYPE),
|
||||
device=device)
|
||||
assert all(x is None for x in manager.lora_index_to_id)
|
||||
assert manager.add_adapter(model_lora1)
|
||||
@ -221,7 +226,8 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
|
||||
2,
|
||||
LoRAConfig(max_lora_rank=8,
|
||||
max_cpu_loras=3,
|
||||
max_loras=2),
|
||||
max_loras=2,
|
||||
lora_dtype=DEFAULT_DTYPE),
|
||||
device=device)
|
||||
assert all(x is None for x in manager.lora_index_to_id)
|
||||
assert manager.add_adapter(model_lora1)
|
||||
@ -316,7 +322,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
|
||||
2,
|
||||
LoRAConfig(max_lora_rank=8,
|
||||
max_cpu_loras=2,
|
||||
max_loras=2),
|
||||
max_loras=2,
|
||||
lora_dtype=DEFAULT_DTYPE),
|
||||
device=device)
|
||||
|
||||
assert all(x is None for x in manager.lora_index_to_id)
|
||||
@ -424,7 +431,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
sql_lora_files, device):
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
|
||||
lora_config = LoRAConfig(max_lora_rank=8,
|
||||
max_cpu_loras=4,
|
||||
max_loras=4,
|
||||
lora_dtype=DEFAULT_DTYPE)
|
||||
worker_adapter_manager = LRUCacheWorkerLoRAManager(
|
||||
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
|
||||
lora_config.lora_extra_vocab_size, lora_config, device,
|
||||
@ -504,7 +514,10 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
sql_lora_files, device):
|
||||
# Should remove every LoRA not specified in the request.
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
|
||||
lora_config = LoRAConfig(max_lora_rank=8,
|
||||
max_cpu_loras=4,
|
||||
max_loras=4,
|
||||
lora_dtype=DEFAULT_DTYPE)
|
||||
worker_adapter_manager = WorkerLoRAManager(
|
||||
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
|
||||
lora_config.lora_extra_vocab_size, lora_config, device,
|
||||
@ -600,7 +613,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
|
||||
2,
|
||||
LoRAConfig(max_lora_rank=8,
|
||||
max_cpu_loras=2,
|
||||
max_loras=2),
|
||||
max_loras=2,
|
||||
lora_dtype=DEFAULT_DTYPE),
|
||||
device=device)
|
||||
model = manager.model
|
||||
|
||||
|
||||
@ -2565,18 +2565,41 @@ class SpeculativeConfig:
|
||||
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
|
||||
|
||||
|
||||
LoRADType = Literal["auto", "float16", "bfloat16"]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class LoRAConfig:
|
||||
max_lora_rank: int
|
||||
max_loras: int
|
||||
"""Configuration for LoRA."""
|
||||
|
||||
max_lora_rank: int = 16
|
||||
"""Max LoRA rank."""
|
||||
max_loras: int = 1
|
||||
"""Max number of LoRAs in a single batch."""
|
||||
fully_sharded_loras: bool = False
|
||||
"""By default, only half of the LoRA computation is sharded with tensor
|
||||
parallelism. Enabling this will use the fully sharded layers. At high
|
||||
sequence length, max rank or tensor parallel size, this is likely faster.
|
||||
"""
|
||||
max_cpu_loras: Optional[int] = None
|
||||
lora_dtype: Optional[Union[torch.dtype, str]] = None
|
||||
"""Maximum number of LoRAs to store in CPU memory. Must be >= than
|
||||
`max_loras`."""
|
||||
lora_dtype: Union[torch.dtype, LoRADType] = "auto"
|
||||
"""Data type for LoRA. If auto, will default to base model dtype."""
|
||||
lora_extra_vocab_size: int = 256
|
||||
"""Maximum size of extra vocabulary that can be present in a LoRA adapter
|
||||
(added to the base model vocabulary)."""
|
||||
# This is a constant.
|
||||
lora_vocab_padding_size: ClassVar[int] = 256
|
||||
long_lora_scaling_factors: Optional[tuple[float]] = None
|
||||
long_lora_scaling_factors: Optional[tuple[float, ...]] = None
|
||||
"""Specify multiple scaling factors (which can be different from base model
|
||||
scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
|
||||
trained with those scaling factors to be used at the same time. If not
|
||||
specified, only adapters trained with the base model scaling factor are
|
||||
allowed."""
|
||||
bias_enabled: bool = False
|
||||
"""Enable bias for LoRA adapters."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
@ -2641,12 +2664,19 @@ class LoRAConfig:
|
||||
"V1 LoRA does not support long LoRA, please use V0.")
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class PromptAdapterConfig:
|
||||
max_prompt_adapters: int
|
||||
max_prompt_adapter_token: int
|
||||
max_prompt_adapters: int = 1
|
||||
"""Max number of PromptAdapters in a batch."""
|
||||
max_prompt_adapter_token: int = 0
|
||||
"""Max number of PromptAdapters tokens."""
|
||||
max_cpu_prompt_adapters: Optional[int] = None
|
||||
prompt_adapter_dtype: Optional[torch.dtype] = None
|
||||
"""Maximum number of PromptAdapters to store in CPU memory. Must be >= than
|
||||
`max_prompt_adapters`."""
|
||||
prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
|
||||
"""Data type for PromptAdapter. If auto, will default to base model dtype.
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
@ -2678,7 +2708,7 @@ class PromptAdapterConfig:
|
||||
self.max_cpu_prompt_adapters = self.max_prompt_adapters
|
||||
|
||||
def verify_with_model_config(self, model_config: ModelConfig):
|
||||
if self.prompt_adapter_dtype in (None, "auto"):
|
||||
if self.prompt_adapter_dtype == "auto":
|
||||
self.prompt_adapter_dtype = model_config.dtype
|
||||
elif isinstance(self.prompt_adapter_dtype, str):
|
||||
self.prompt_adapter_dtype = getattr(torch,
|
||||
|
||||
@ -7,7 +7,7 @@ import json
|
||||
import re
|
||||
import threading
|
||||
from dataclasses import MISSING, dataclass, fields
|
||||
from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Type,
|
||||
from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
|
||||
TypeVar, Union, cast, get_args, get_origin)
|
||||
|
||||
import torch
|
||||
@ -192,18 +192,23 @@ class EngineArgs:
|
||||
get_field(MultiModalConfig, "limit_per_prompt")
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None
|
||||
disable_mm_preprocessor_cache: bool = False
|
||||
# LoRA fields
|
||||
enable_lora: bool = False
|
||||
enable_lora_bias: bool = False
|
||||
max_loras: int = 1
|
||||
max_lora_rank: int = 16
|
||||
enable_lora_bias: bool = LoRAConfig.bias_enabled
|
||||
max_loras: int = LoRAConfig.max_loras
|
||||
max_lora_rank: int = LoRAConfig.max_lora_rank
|
||||
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
|
||||
max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
|
||||
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
|
||||
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
|
||||
long_lora_scaling_factors: Optional[tuple[float, ...]] = \
|
||||
LoRAConfig.long_lora_scaling_factors
|
||||
# PromptAdapter fields
|
||||
enable_prompt_adapter: bool = False
|
||||
max_prompt_adapters: int = 1
|
||||
max_prompt_adapter_token: int = 0
|
||||
fully_sharded_loras: bool = False
|
||||
lora_extra_vocab_size: int = 256
|
||||
long_lora_scaling_factors: Optional[Tuple[float]] = None
|
||||
lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
|
||||
max_cpu_loras: Optional[int] = None
|
||||
max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
|
||||
max_prompt_adapter_token: int = \
|
||||
PromptAdapterConfig.max_prompt_adapter_token
|
||||
|
||||
device: Device = DeviceConfig.device
|
||||
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
|
||||
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
|
||||
@ -338,10 +343,21 @@ class EngineArgs:
|
||||
kwargs[name]["choices"] = choices
|
||||
choice_type = type(choices[0])
|
||||
assert all(type(c) is choice_type for c in choices), (
|
||||
f"All choices must be of the same type. "
|
||||
"All choices must be of the same type. "
|
||||
f"Got {choices} with types {[type(c) for c in choices]}"
|
||||
)
|
||||
kwargs[name]["type"] = choice_type
|
||||
elif can_be_type(field_type, tuple):
|
||||
if is_type_in_union(field_type, tuple):
|
||||
field_type = get_type_from_union(field_type, tuple)
|
||||
dtypes = get_args(field_type)
|
||||
dtype = dtypes[0]
|
||||
assert all(
|
||||
d is dtype for d in dtypes if d is not Ellipsis
|
||||
), ("All non-Ellipsis tuple elements must be of the same "
|
||||
f"type. Got {dtypes}.")
|
||||
kwargs[name]["type"] = dtype
|
||||
kwargs[name]["nargs"] = "+"
|
||||
elif can_be_type(field_type, int):
|
||||
kwargs[name]["type"] = optional_int if optional else int
|
||||
elif can_be_type(field_type, float):
|
||||
@ -685,70 +701,49 @@ class EngineArgs:
|
||||
'inputs.')
|
||||
|
||||
# LoRA related configs
|
||||
parser.add_argument('--enable-lora',
|
||||
action='store_true',
|
||||
help='If True, enable handling of LoRA adapters.')
|
||||
parser.add_argument('--enable-lora-bias',
|
||||
action='store_true',
|
||||
help='If True, enable bias for LoRA adapters.')
|
||||
parser.add_argument('--max-loras',
|
||||
type=int,
|
||||
default=EngineArgs.max_loras,
|
||||
help='Max number of LoRAs in a single batch.')
|
||||
parser.add_argument('--max-lora-rank',
|
||||
type=int,
|
||||
default=EngineArgs.max_lora_rank,
|
||||
help='Max LoRA rank.')
|
||||
parser.add_argument(
|
||||
'--lora-extra-vocab-size',
|
||||
type=int,
|
||||
default=EngineArgs.lora_extra_vocab_size,
|
||||
help=('Maximum size of extra vocabulary that can be '
|
||||
'present in a LoRA adapter (added to the base '
|
||||
'model vocabulary).'))
|
||||
parser.add_argument(
|
||||
lora_kwargs = get_kwargs(LoRAConfig)
|
||||
lora_group = parser.add_argument_group(
|
||||
title="LoRAConfig",
|
||||
description=LoRAConfig.__doc__,
|
||||
)
|
||||
lora_group.add_argument(
|
||||
'--enable-lora',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help='If True, enable handling of LoRA adapters.')
|
||||
lora_group.add_argument('--enable-lora-bias',
|
||||
**lora_kwargs["bias_enabled"])
|
||||
lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"])
|
||||
lora_group.add_argument('--max-lora-rank',
|
||||
**lora_kwargs["max_lora_rank"])
|
||||
lora_group.add_argument('--lora-extra-vocab-size',
|
||||
**lora_kwargs["lora_extra_vocab_size"])
|
||||
lora_group.add_argument(
|
||||
'--lora-dtype',
|
||||
type=str,
|
||||
default=EngineArgs.lora_dtype,
|
||||
choices=['auto', 'float16', 'bfloat16'],
|
||||
help=('Data type for LoRA. If auto, will default to '
|
||||
'base model dtype.'))
|
||||
parser.add_argument(
|
||||
'--long-lora-scaling-factors',
|
||||
type=optional_str,
|
||||
default=EngineArgs.long_lora_scaling_factors,
|
||||
help=('Specify multiple scaling factors (which can '
|
||||
'be different from base model scaling factor '
|
||||
'- see eg. Long LoRA) to allow for multiple '
|
||||
'LoRA adapters trained with those scaling '
|
||||
'factors to be used at the same time. If not '
|
||||
'specified, only adapters trained with the '
|
||||
'base model scaling factor are allowed.'))
|
||||
parser.add_argument(
|
||||
'--max-cpu-loras',
|
||||
type=int,
|
||||
default=EngineArgs.max_cpu_loras,
|
||||
help=('Maximum number of LoRAs to store in CPU memory. '
|
||||
'Must be >= than max_loras.'))
|
||||
parser.add_argument(
|
||||
'--fully-sharded-loras',
|
||||
action='store_true',
|
||||
help=('By default, only half of the LoRA computation is '
|
||||
'sharded with tensor parallelism. '
|
||||
'Enabling this will use the fully sharded layers. '
|
||||
'At high sequence length, max rank or '
|
||||
'tensor parallel size, this is likely faster.'))
|
||||
parser.add_argument('--enable-prompt-adapter',
|
||||
action='store_true',
|
||||
help='If True, enable handling of PromptAdapters.')
|
||||
parser.add_argument('--max-prompt-adapters',
|
||||
type=int,
|
||||
default=EngineArgs.max_prompt_adapters,
|
||||
help='Max number of PromptAdapters in a batch.')
|
||||
parser.add_argument('--max-prompt-adapter-token',
|
||||
type=int,
|
||||
default=EngineArgs.max_prompt_adapter_token,
|
||||
help='Max number of PromptAdapters tokens')
|
||||
**lora_kwargs["lora_dtype"],
|
||||
)
|
||||
lora_group.add_argument('--long-lora-scaling-factors',
|
||||
**lora_kwargs["long_lora_scaling_factors"])
|
||||
lora_group.add_argument('--max-cpu-loras',
|
||||
**lora_kwargs["max_cpu_loras"])
|
||||
lora_group.add_argument('--fully-sharded-loras',
|
||||
**lora_kwargs["fully_sharded_loras"])
|
||||
|
||||
# PromptAdapter related configs
|
||||
prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
|
||||
prompt_adapter_group = parser.add_argument_group(
|
||||
title="PromptAdapterConfig",
|
||||
description=PromptAdapterConfig.__doc__,
|
||||
)
|
||||
prompt_adapter_group.add_argument(
|
||||
'--enable-prompt-adapter',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help='If True, enable handling of PromptAdapters.')
|
||||
prompt_adapter_group.add_argument(
|
||||
'--max-prompt-adapters',
|
||||
**prompt_adapter_kwargs["max_prompt_adapters"])
|
||||
prompt_adapter_group.add_argument(
|
||||
'--max-prompt-adapter-token',
|
||||
**prompt_adapter_kwargs["max_prompt_adapter_token"])
|
||||
|
||||
# Device arguments
|
||||
device_kwargs = get_kwargs(DeviceConfig)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user