[LoRA] Change lora_tokenizers capacity (#10796)

Signed-off-by: Xin Yang <xyang19@gmail.com>
Author: Xin Yang, 2024-12-04 09:40:16 -08:00, committed by GitHub
parent c92acb9693
commit 01d079fd8e
7 changed files with 31 additions and 10 deletions
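
Summary: the change threads the full LoRAConfig through init_tokenizer_from_configs instead of a pre-computed enable_lora flag, so TokenizerGroup can size its lora_tokenizers LRU cache as max(max_loras, max_num_seqs) rather than max_num_seqs alone; with the old sizing, the cache could thrash whenever more LoRA adapters were active than concurrent sequences.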

View File

@@ -17,6 +17,7 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
         tokenizer_id="gpt2",
         enable_lora=True,
         max_num_seqs=1,
+        max_loras=1,
         max_input_length=None,
     )
     lora_request = LoRARequest("1", 1, sql_lora_files)
@@ -53,3 +54,22 @@ def test_get_lora_tokenizer(sql_lora_files, tmp_path):
     lora_request = LoRARequest("1", 1, str(tmp_path))
     tokenizer = get_lora_tokenizer(lora_request)
     assert not tokenizer
+
+
+@pytest.mark.parametrize("enable_lora", [True, False])
+@pytest.mark.parametrize("max_num_seqs", [1, 2])
+@pytest.mark.parametrize("max_loras", [1, 2])
+def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras):
+    tokenizer_group = get_tokenizer_group(
+        get_tokenizer_pool_config(None),
+        tokenizer_id="gpt2",
+        enable_lora=enable_lora,
+        max_num_seqs=max_num_seqs,
+        max_loras=max_loras,
+        max_input_length=None,
+    )
+    if enable_lora:
+        assert tokenizer_group.lora_tokenizers.capacity == max(
+            max_num_seqs, max_loras)
+    else:
+        assert tokenizer_group.lora_tokenizers.capacity == 0
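
The capacity rule this test pins down, as a standalone sketch (the expected_capacity helper is hypothetical, not vLLM code):

    def expected_capacity(enable_lora: bool, max_num_seqs: int, max_loras: int) -> int:
        # Mirrors TokenizerGroup: the cache is disabled entirely without LoRA.
        return max(max_num_seqs, max_loras) if enable_lora else 0

    assert expected_capacity(True, max_num_seqs=1, max_loras=2) == 2   # max_loras dominates
    assert expected_capacity(True, max_num_seqs=2, max_loras=1) == 2   # max_num_seqs dominates
    assert expected_capacity(False, max_num_seqs=2, max_loras=2) == 0  # no LoRA, no cache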

View File

@@ -620,7 +620,7 @@ class LLMEngine:
             model_config=self.model_config,
             scheduler_config=self.scheduler_config,
             parallel_config=self.parallel_config,
-            enable_lora=bool(self.lora_config))
+            lora_config=self.lora_config)
 
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)

View File

@@ -94,8 +94,7 @@ class MQLLMEngineClient(EngineClient):
             model_config=self.model_config,
             scheduler_config=engine_config.scheduler_config,
             parallel_config=engine_config.parallel_config,
-            enable_lora=bool(engine_config.lora_config),
-        )
+            lora_config=engine_config.lora_config)
 
         self.input_preprocessor = InputPreprocessor(self.model_config,
                                                     self.tokenizer)

View File

@@ -1,7 +1,7 @@
 from typing import Optional, Type
 
-from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
-                         TokenizerPoolConfig)
+from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig, TokenizerPoolConfig)
 from vllm.executor.ray_utils import ray
 
 from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
@@ -16,10 +16,11 @@ else:
 def init_tokenizer_from_configs(model_config: ModelConfig,
                                 scheduler_config: SchedulerConfig,
                                 parallel_config: ParallelConfig,
-                                enable_lora: bool):
+                                lora_config: LoRAConfig):
     init_kwargs = dict(tokenizer_id=model_config.tokenizer,
-                       enable_lora=enable_lora,
+                       enable_lora=bool(lora_config),
                        max_num_seqs=scheduler_config.max_num_seqs,
+                       max_loras=lora_config.max_loras if lora_config else 0,
                        max_input_length=None,
                        tokenizer_mode=model_config.tokenizer_mode,
                        trust_remote_code=model_config.trust_remote_code,
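
Both new kwargs degrade gracefully when LoRA is off. A minimal sketch of the two derivations, using types.SimpleNamespace as a stand-in for LoRAConfig (illustrative only, not the real vLLM type):

    from types import SimpleNamespace

    for lora_config in (SimpleNamespace(max_loras=4), None):
        enable_lora = bool(lora_config)  # True only when a config exists
        max_loras = lora_config.max_loras if lora_config else 0
        print(enable_lora, max_loras)    # -> True 4, then False 0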

View File

@@ -21,8 +21,9 @@ class TokenizerGroup(BaseTokenizerGroup):
         self.enable_lora = enable_lora
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
+        max_loras = tokenizer_config.get("max_loras", 0)
         self.lora_tokenizers = LRUCache[AnyTokenizer](
-            capacity=max_num_seqs if enable_lora else 0)
+            capacity=max(max_loras, max_num_seqs) if enable_lora else 0)
 
     @classmethod
     def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
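
Why taking the max matters: with two active adapters but capacity = max_num_seqs = 1, alternating requests evict and reload a tokenizer on every call. A self-contained LRU sketch demonstrating this (TinyLRU is illustrative, not vLLM's LRUCache):

    from collections import OrderedDict

    class TinyLRU:
        def __init__(self, capacity: int):
            self.capacity, self.data = capacity, OrderedDict()
            self.misses = 0

        def get_or_load(self, key):
            if key in self.data:
                self.data.move_to_end(key)  # mark as most recently used
            else:
                self.misses += 1            # stands in for reloading a tokenizer
                self.data[key] = f"tokenizer-{key}"
                if len(self.data) > self.capacity:
                    self.data.popitem(last=False)  # evict least recently used
            return self.data[key]

    small, sized = TinyLRU(capacity=1), TinyLRU(capacity=2)
    for lora_id in [1, 2] * 4:          # two adapters, interleaved requests
        small.get_or_load(lora_id)
        sized.get_or_load(lora_id)
    assert small.misses == 8  # every lookup reloads: the cache thrashes
    assert sized.misses == 2  # only the first load of each adapter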

View File

@@ -51,7 +51,7 @@ class AsyncLLM(EngineClient):
             model_config=vllm_config.model_config,
             scheduler_config=vllm_config.scheduler_config,
             parallel_config=vllm_config.parallel_config,
-            enable_lora=bool(vllm_config.lora_config))
+            lora_config=vllm_config.lora_config)
         self.tokenizer.ping()
 
         # Request streams (map of request_id -> AsyncStream).

View File

@@ -46,7 +46,7 @@ class LLMEngine:
             model_config=vllm_config.model_config,
             scheduler_config=vllm_config.scheduler_config,
             parallel_config=vllm_config.parallel_config,
-            enable_lora=bool(vllm_config.lora_config))
+            lora_config=vllm_config.lora_config)
         self.tokenizer.ping()
 
         # Processor (convert Inputs --> EngineCoreRequests)