Improve configs - LoRAConfig + PromptAdapterConfig (#16980)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
commit 0fa939e2d1 (parent 0422ce109f)
@@ -31,6 +31,8 @@ DEVICES = ([
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ] if current_platform.is_cuda_alike() else ["cpu"])

+DEFAULT_DTYPE = torch.get_default_dtype()
+

 @pytest.fixture(scope="function", autouse=True)
 def use_v0_only(monkeypatch: pytest.MonkeyPatch):
@@ -125,8 +127,10 @@ def test_replace_submodules(dist_init, dummy_model):
     model = dummy_model
     manager = LoRAModelManager(
         model, 1, 1, 1,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
-        torch.device(DEVICES[0]))
+        LoRAConfig(max_lora_rank=8,
+                   max_cpu_loras=8,
+                   max_loras=8,
+                   lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0]))
     model = manager.model
     assert isinstance(model.get_submodule("dense1"),
                       ColumnParallelLinearWithLoRA)
@@ -155,7 +159,8 @@ def test_lora_model_manager(dist_init, dummy_model, device):
                                2,
                                LoRAConfig(max_lora_rank=8,
                                           max_cpu_loras=3,
-                                          max_loras=2),
+                                          max_loras=2,
+                                          lora_dtype=DEFAULT_DTYPE),
                                device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
@@ -221,7 +226,8 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
                                2,
                                LoRAConfig(max_lora_rank=8,
                                           max_cpu_loras=3,
-                                          max_loras=2),
+                                          max_loras=2,
+                                          lora_dtype=DEFAULT_DTYPE),
                                device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
@@ -316,7 +322,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
                                   2,
                                   LoRAConfig(max_lora_rank=8,
                                              max_cpu_loras=2,
-                                             max_loras=2),
+                                             max_loras=2,
+                                             lora_dtype=DEFAULT_DTYPE),
                                   device=device)

     assert all(x is None for x in manager.lora_index_to_id)
@@ -424,7 +431,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
 @pytest.mark.parametrize("device", DEVICES)
 def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                           sql_lora_files, device):
-    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
+    lora_config = LoRAConfig(max_lora_rank=8,
+                             max_cpu_loras=4,
+                             max_loras=4,
+                             lora_dtype=DEFAULT_DTYPE)
     worker_adapter_manager = LRUCacheWorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
         lora_config.lora_extra_vocab_size, lora_config, device,
@@ -504,7 +514,10 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
 def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                 sql_lora_files, device):
     # Should remove every LoRA not specified in the request.
-    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
+    lora_config = LoRAConfig(max_lora_rank=8,
+                             max_cpu_loras=4,
+                             max_loras=4,
+                             lora_dtype=DEFAULT_DTYPE)
     worker_adapter_manager = WorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
         lora_config.lora_extra_vocab_size, lora_config, device,
@@ -600,7 +613,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
                                2,
                                LoRAConfig(max_lora_rank=8,
                                           max_cpu_loras=2,
-                                          max_loras=2),
+                                          max_loras=2,
+                                          lora_dtype=DEFAULT_DTYPE),
                                device=device)
     model = manager.model

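After this change the tests construct LoRAConfig with an explicit lora_dtype, since the field now defaults to "auto" rather than None. For illustration only (not part of the patch), a minimal standalone sketch of that construction pattern, assuming LoRAConfig is importable from vllm.config as in this diff:

import torch

from vllm.config import LoRAConfig

# Resolve the process-wide default dtype once, as the updated tests do.
DEFAULT_DTYPE = torch.get_default_dtype()

lora_config = LoRAConfig(max_lora_rank=8,
                         max_cpu_loras=4,
                         max_loras=4,
                         lora_dtype=DEFAULT_DTYPE)
print(lora_config.lora_dtype)  # typically torch.float32 unless changed globally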
@@ -2565,18 +2565,41 @@ class SpeculativeConfig:
         return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"


+LoRADType = Literal["auto", "float16", "bfloat16"]
+
+
+@config
 @dataclass
 class LoRAConfig:
-    max_lora_rank: int
-    max_loras: int
+    """Configuration for LoRA."""
+
+    max_lora_rank: int = 16
+    """Max LoRA rank."""
+    max_loras: int = 1
+    """Max number of LoRAs in a single batch."""
     fully_sharded_loras: bool = False
+    """By default, only half of the LoRA computation is sharded with tensor
+    parallelism. Enabling this will use the fully sharded layers. At high
+    sequence length, max rank or tensor parallel size, this is likely faster.
+    """
     max_cpu_loras: Optional[int] = None
-    lora_dtype: Optional[Union[torch.dtype, str]] = None
+    """Maximum number of LoRAs to store in CPU memory. Must be >= than
+    `max_loras`."""
+    lora_dtype: Union[torch.dtype, LoRADType] = "auto"
+    """Data type for LoRA. If auto, will default to base model dtype."""
     lora_extra_vocab_size: int = 256
+    """Maximum size of extra vocabulary that can be present in a LoRA adapter
+    (added to the base model vocabulary)."""
     # This is a constant.
     lora_vocab_padding_size: ClassVar[int] = 256
-    long_lora_scaling_factors: Optional[tuple[float]] = None
+    long_lora_scaling_factors: Optional[tuple[float, ...]] = None
+    """Specify multiple scaling factors (which can be different from base model
+    scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
+    trained with those scaling factors to be used at the same time. If not
+    specified, only adapters trained with the base model scaling factor are
+    allowed."""
     bias_enabled: bool = False
+    """Enable bias for LoRA adapters."""

     def compute_hash(self) -> str:
         """
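The new LoRADType Literal alias is what lets the CLI layer derive its choices from the annotation instead of hard-coding them. A small sketch of the underlying mechanism, using only the standard typing module (independent of vLLM's get_kwargs):

from typing import Literal, get_args

LoRADType = Literal["auto", "float16", "bfloat16"]

# get_args() recovers the literal values, which a CLI layer can pass to
# argparse as the "choices" list for --lora-dtype.
choices = get_args(LoRADType)
assert choices == ("auto", "float16", "bfloat16")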
@@ -2641,12 +2664,19 @@ class LoRAConfig:
                 "V1 LoRA does not support long LoRA, please use V0.")


+@config
 @dataclass
 class PromptAdapterConfig:
-    max_prompt_adapters: int
-    max_prompt_adapter_token: int
+    max_prompt_adapters: int = 1
+    """Max number of PromptAdapters in a batch."""
+    max_prompt_adapter_token: int = 0
+    """Max number of PromptAdapters tokens."""
     max_cpu_prompt_adapters: Optional[int] = None
-    prompt_adapter_dtype: Optional[torch.dtype] = None
+    """Maximum number of PromptAdapters to store in CPU memory. Must be >= than
+    `max_prompt_adapters`."""
+    prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
+    """Data type for PromptAdapter. If auto, will default to base model dtype.
+    """

     def compute_hash(self) -> str:
         """
@@ -2678,7 +2708,7 @@ class PromptAdapterConfig:
             self.max_cpu_prompt_adapters = self.max_prompt_adapters

     def verify_with_model_config(self, model_config: ModelConfig):
-        if self.prompt_adapter_dtype in (None, "auto"):
+        if self.prompt_adapter_dtype == "auto":
            self.prompt_adapter_dtype = model_config.dtype
        elif isinstance(self.prompt_adapter_dtype, str):
            self.prompt_adapter_dtype = getattr(torch,
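With the default changed from None to "auto", verify_with_model_config only needs to handle the "auto" string and explicit dtype names. A standalone sketch of that resolution logic; the helper name _resolve_dtype is illustrative, not vLLM API:

import torch


def _resolve_dtype(dtype, model_dtype):
    # "auto" falls back to the base model dtype; other strings are looked
    # up as attributes on torch (e.g. "float16" -> torch.float16).
    if dtype == "auto":
        return model_dtype
    if isinstance(dtype, str):
        return getattr(torch, dtype)
    return dtype


assert _resolve_dtype("auto", torch.bfloat16) is torch.bfloat16
assert _resolve_dtype("float16", torch.bfloat16) is torch.float16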
@@ -7,7 +7,7 @@ import json
 import re
 import threading
 from dataclasses import MISSING, dataclass, fields
-from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Type,
+from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
                     TypeVar, Union, cast, get_args, get_origin)

 import torch
@@ -192,18 +192,23 @@ class EngineArgs:
         get_field(MultiModalConfig, "limit_per_prompt")
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
+    # LoRA fields
     enable_lora: bool = False
-    enable_lora_bias: bool = False
-    max_loras: int = 1
-    max_lora_rank: int = 16
+    enable_lora_bias: bool = LoRAConfig.bias_enabled
+    max_loras: int = LoRAConfig.max_loras
+    max_lora_rank: int = LoRAConfig.max_lora_rank
+    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
+    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
+    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
+    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
+    long_lora_scaling_factors: Optional[tuple[float, ...]] = \
+        LoRAConfig.long_lora_scaling_factors
+    # PromptAdapter fields
     enable_prompt_adapter: bool = False
-    max_prompt_adapters: int = 1
-    max_prompt_adapter_token: int = 0
-    fully_sharded_loras: bool = False
-    lora_extra_vocab_size: int = 256
-    long_lora_scaling_factors: Optional[Tuple[float]] = None
-    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
-    max_cpu_loras: Optional[int] = None
+    max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
+    max_prompt_adapter_token: int = \
+        PromptAdapterConfig.max_prompt_adapter_token
+
     device: Device = DeviceConfig.device
     num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
     multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
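The EngineArgs defaults now reference the config dataclasses instead of repeating literal values, making the config class the single source of truth. A minimal sketch of that pattern with hypothetical ExampleConfig/ExampleArgs names (not vLLM classes):

from dataclasses import dataclass


@dataclass
class ExampleConfig:
    """Hypothetical config class holding the canonical defaults."""
    max_items: int = 1
    rank: int = 16


@dataclass
class ExampleArgs:
    # Defaults read off the config class, so changing ExampleConfig.rank
    # automatically changes the CLI-facing default too.
    max_items: int = ExampleConfig.max_items
    rank: int = ExampleConfig.rank


assert ExampleArgs().rank == ExampleConfig.rank == 16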
@@ -338,10 +343,21 @@ class EngineArgs:
                 kwargs[name]["choices"] = choices
                 choice_type = type(choices[0])
                 assert all(type(c) is choice_type for c in choices), (
-                    f"All choices must be of the same type. "
+                    "All choices must be of the same type. "
                     f"Got {choices} with types {[type(c) for c in choices]}"
                 )
                 kwargs[name]["type"] = choice_type
+            elif can_be_type(field_type, tuple):
+                if is_type_in_union(field_type, tuple):
+                    field_type = get_type_from_union(field_type, tuple)
+                dtypes = get_args(field_type)
+                dtype = dtypes[0]
+                assert all(
+                    d is dtype for d in dtypes if d is not Ellipsis
+                ), ("All non-Ellipsis tuple elements must be of the same "
+                    f"type. Got {dtypes}.")
+                kwargs[name]["type"] = dtype
+                kwargs[name]["nargs"] = "+"
             elif can_be_type(field_type, int):
                 kwargs[name]["type"] = optional_int if optional else int
             elif can_be_type(field_type, float):
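The new tuple branch in get_kwargs maps an annotation like Optional[tuple[float, ...]] to an element type plus nargs="+". A standalone argparse sketch of what that produces for --long-lora-scaling-factors (plain argparse, not vLLM's actual parser):

import argparse

parser = argparse.ArgumentParser()
# Element type float, one or more values, as the tuple branch would emit.
parser.add_argument("--long-lora-scaling-factors",
                    type=float,
                    nargs="+",
                    default=None)

args = parser.parse_args(["--long-lora-scaling-factors", "4.0", "8.0"])
assert args.long_lora_scaling_factors == [4.0, 8.0]  # argparse returns a list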
@@ -685,70 +701,49 @@ class EngineArgs:
             'inputs.')

         # LoRA related configs
-        parser.add_argument('--enable-lora',
-                            action='store_true',
-                            help='If True, enable handling of LoRA adapters.')
-        parser.add_argument('--enable-lora-bias',
-                            action='store_true',
-                            help='If True, enable bias for LoRA adapters.')
-        parser.add_argument('--max-loras',
-                            type=int,
-                            default=EngineArgs.max_loras,
-                            help='Max number of LoRAs in a single batch.')
-        parser.add_argument('--max-lora-rank',
-                            type=int,
-                            default=EngineArgs.max_lora_rank,
-                            help='Max LoRA rank.')
-        parser.add_argument(
-            '--lora-extra-vocab-size',
-            type=int,
-            default=EngineArgs.lora_extra_vocab_size,
-            help=('Maximum size of extra vocabulary that can be '
-                  'present in a LoRA adapter (added to the base '
-                  'model vocabulary).'))
-        parser.add_argument(
-            '--lora-dtype',
-            type=str,
-            default=EngineArgs.lora_dtype,
-            choices=['auto', 'float16', 'bfloat16'],
-            help=('Data type for LoRA. If auto, will default to '
-                  'base model dtype.'))
-        parser.add_argument(
-            '--long-lora-scaling-factors',
-            type=optional_str,
-            default=EngineArgs.long_lora_scaling_factors,
-            help=('Specify multiple scaling factors (which can '
-                  'be different from base model scaling factor '
-                  '- see eg. Long LoRA) to allow for multiple '
-                  'LoRA adapters trained with those scaling '
-                  'factors to be used at the same time. If not '
-                  'specified, only adapters trained with the '
-                  'base model scaling factor are allowed.'))
-        parser.add_argument(
-            '--max-cpu-loras',
-            type=int,
-            default=EngineArgs.max_cpu_loras,
-            help=('Maximum number of LoRAs to store in CPU memory. '
-                  'Must be >= than max_loras.'))
-        parser.add_argument(
-            '--fully-sharded-loras',
-            action='store_true',
-            help=('By default, only half of the LoRA computation is '
-                  'sharded with tensor parallelism. '
-                  'Enabling this will use the fully sharded layers. '
-                  'At high sequence length, max rank or '
-                  'tensor parallel size, this is likely faster.'))
-        parser.add_argument('--enable-prompt-adapter',
-                            action='store_true',
-                            help='If True, enable handling of PromptAdapters.')
-        parser.add_argument('--max-prompt-adapters',
-                            type=int,
-                            default=EngineArgs.max_prompt_adapters,
-                            help='Max number of PromptAdapters in a batch.')
-        parser.add_argument('--max-prompt-adapter-token',
-                            type=int,
-                            default=EngineArgs.max_prompt_adapter_token,
-                            help='Max number of PromptAdapters tokens')
+        lora_kwargs = get_kwargs(LoRAConfig)
+        lora_group = parser.add_argument_group(
+            title="LoRAConfig",
+            description=LoRAConfig.__doc__,
+        )
+        lora_group.add_argument(
+            '--enable-lora',
+            action=argparse.BooleanOptionalAction,
+            help='If True, enable handling of LoRA adapters.')
+        lora_group.add_argument('--enable-lora-bias',
+                                **lora_kwargs["bias_enabled"])
+        lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"])
+        lora_group.add_argument('--max-lora-rank',
+                                **lora_kwargs["max_lora_rank"])
+        lora_group.add_argument('--lora-extra-vocab-size',
+                                **lora_kwargs["lora_extra_vocab_size"])
+        lora_group.add_argument(
+            '--lora-dtype',
+            **lora_kwargs["lora_dtype"],
+        )
+        lora_group.add_argument('--long-lora-scaling-factors',
+                                **lora_kwargs["long_lora_scaling_factors"])
+        lora_group.add_argument('--max-cpu-loras',
+                                **lora_kwargs["max_cpu_loras"])
+        lora_group.add_argument('--fully-sharded-loras',
+                                **lora_kwargs["fully_sharded_loras"])
+
+        # PromptAdapter related configs
+        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
+        prompt_adapter_group = parser.add_argument_group(
+            title="PromptAdapterConfig",
+            description=PromptAdapterConfig.__doc__,
+        )
+        prompt_adapter_group.add_argument(
+            '--enable-prompt-adapter',
+            action=argparse.BooleanOptionalAction,
+            help='If True, enable handling of PromptAdapters.')
+        prompt_adapter_group.add_argument(
+            '--max-prompt-adapters',
+            **prompt_adapter_kwargs["max_prompt_adapters"])
+        prompt_adapter_group.add_argument(
+            '--max-prompt-adapter-token',
+            **prompt_adapter_kwargs["max_prompt_adapter_token"])

         # Device arguments
         device_kwargs = get_kwargs(DeviceConfig)
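The rewritten CLI section groups the flags with parser.add_argument_group and switches the enable flags to argparse.BooleanOptionalAction, which also registers --no-enable-lora / --no-enable-prompt-adapter. A standalone sketch of that argparse behaviour; the group title and help text mirror the diff, while the parser itself is purely illustrative:

import argparse

parser = argparse.ArgumentParser()
lora_group = parser.add_argument_group(
    title="LoRAConfig",
    description="Configuration for LoRA.",
)
# BooleanOptionalAction generates both --enable-lora and --no-enable-lora.
lora_group.add_argument("--enable-lora",
                        action=argparse.BooleanOptionalAction,
                        help="If True, enable handling of LoRA adapters.")

args = parser.parse_args(["--no-enable-lora"])
assert args.enable_lora is False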