[LoRA] Change lora_tokenizers capacity (#10796)
Signed-off-by: Xin Yang <xyang19@gmail.com>
This commit is contained in:
parent c92acb9693
commit 01d079fd8e
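In short: when LoRA is enabled, the tokenizer group's LRU cache of per-adapter tokenizers is now sized to max(max_loras, max_num_seqs) instead of max_num_seqs alone, and it stays at 0 when LoRA is disabled. A minimal sketch of that rule, illustrative only and not the vLLM source (parameter names follow the diff below):

def lora_tokenizer_cache_capacity(enable_lora: bool, max_num_seqs: int,
                                  max_loras: int) -> int:
    # With LoRA disabled the cache holds nothing.
    if not enable_lora:
        return 0
    # Before this commit the capacity was max_num_seqs alone; it now also
    # covers the number of adapters that may be loaded at the same time.
    return max(max_loras, max_num_seqs)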
@@ -17,6 +17,7 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
         tokenizer_id="gpt2",
         enable_lora=True,
         max_num_seqs=1,
+        max_loras=1,
         max_input_length=None,
     )
     lora_request = LoRARequest("1", 1, sql_lora_files)
@@ -53,3 +54,22 @@ def test_get_lora_tokenizer(sql_lora_files, tmp_path):
     lora_request = LoRARequest("1", 1, str(tmp_path))
     tokenizer = get_lora_tokenizer(lora_request)
     assert not tokenizer
+
+
+@pytest.mark.parametrize("enable_lora", [True, False])
+@pytest.mark.parametrize("max_num_seqs", [1, 2])
+@pytest.mark.parametrize("max_loras", [1, 2])
+def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras):
+    tokenizer_group = get_tokenizer_group(
+        get_tokenizer_pool_config(None),
+        tokenizer_id="gpt2",
+        enable_lora=enable_lora,
+        max_num_seqs=max_num_seqs,
+        max_loras=max_loras,
+        max_input_length=None,
+    )
+    if enable_lora:
+        assert tokenizer_group.lora_tokenizers.capacity == max(
+            max_num_seqs, max_loras)
+    else:
+        assert tokenizer_group.lora_tokenizers.capacity == 0
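For intuition on why the capacity must also cover max_loras: if more adapters rotate through the group than there are concurrent sequences, an LRU cache sized only by max_num_seqs evicts and rebuilds tokenizers on every pass. The toy LRU below is a self-contained illustration using collections.OrderedDict, not vLLM's LRUCache; the adapter names and counts are made up for the example:

from collections import OrderedDict

class TinyLRU:
    """Minimal LRU cache for illustration; vLLM's LRUCache has a richer API."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self._data = OrderedDict()
        self.misses = 0

    def get_or_build(self, key):
        if key in self._data:
            self._data.move_to_end(key)      # mark as most recently used
            return self._data[key]
        self.misses += 1                     # stands in for rebuilding a tokenizer
        self._data[key] = f"tokenizer-for-{key}"
        if len(self._data) > self.capacity:
            self._data.popitem(last=False)   # evict the least recently used entry
        return self._data[key]

# Two sequences in flight (max_num_seqs=2) but four adapters (max_loras=4):
for capacity in (2, 4):
    cache = TinyLRU(capacity)
    for _ in range(3):                       # three passes over the adapters
        for adapter in ("lora-a", "lora-b", "lora-c", "lora-d"):
            cache.get_or_build(adapter)
    print(capacity, cache.misses)            # capacity=2 -> 12, capacity=4 -> 4

With capacity 2 every lookup misses and a tokenizer is rebuilt each time; with capacity max(2, 4) = 4 only the first pass misses, which is the behaviour the new sizing preserves.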
@@ -620,7 +620,7 @@ class LLMEngine:
             model_config=self.model_config,
             scheduler_config=self.scheduler_config,
             parallel_config=self.parallel_config,
-            enable_lora=bool(self.lora_config))
+            lora_config=self.lora_config)

     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
@@ -94,8 +94,7 @@ class MQLLMEngineClient(EngineClient):
             model_config=self.model_config,
             scheduler_config=engine_config.scheduler_config,
             parallel_config=engine_config.parallel_config,
-            enable_lora=bool(engine_config.lora_config),
-        )
+            lora_config=engine_config.lora_config)
         self.input_preprocessor = InputPreprocessor(self.model_config,
                                                     self.tokenizer)

@@ -1,7 +1,7 @@
 from typing import Optional, Type

-from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
-                         TokenizerPoolConfig)
+from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig, TokenizerPoolConfig)
 from vllm.executor.ray_utils import ray

 from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
@@ -16,10 +16,11 @@ else:
 def init_tokenizer_from_configs(model_config: ModelConfig,
                                 scheduler_config: SchedulerConfig,
                                 parallel_config: ParallelConfig,
-                                enable_lora: bool):
+                                lora_config: LoRAConfig):
     init_kwargs = dict(tokenizer_id=model_config.tokenizer,
-                       enable_lora=enable_lora,
+                       enable_lora=bool(lora_config),
                        max_num_seqs=scheduler_config.max_num_seqs,
+                       max_loras=lora_config.max_loras if lora_config else 0,
                        max_input_length=None,
                        tokenizer_mode=model_config.tokenizer_mode,
                        trust_remote_code=model_config.trust_remote_code,
@@ -21,8 +21,9 @@ class TokenizerGroup(BaseTokenizerGroup):
         self.enable_lora = enable_lora
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
+        max_loras = tokenizer_config.get("max_loras", 0)
         self.lora_tokenizers = LRUCache[AnyTokenizer](
-            capacity=max_num_seqs if enable_lora else 0)
+            capacity=max(max_loras, max_num_seqs) if enable_lora else 0)

     @classmethod
     def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
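Putting the last two hunks together, the flow is: init_tokenizer_from_configs derives max_loras from the LoRAConfig and forwards it in init_kwargs, and TokenizerGroup reads it back out of tokenizer_config (defaulting to 0) to size its lora_tokenizers cache. A condensed sketch of that flow, illustrative only; a plain dict stands in for the real config plumbing, and lora_config only needs a max_loras attribute here:

def build_tokenizer_group_kwargs(lora_config, max_num_seqs: int) -> dict:
    # Mirrors init_tokenizer_from_configs in the hunk above.
    return dict(
        enable_lora=bool(lora_config),
        max_num_seqs=max_num_seqs,
        max_loras=lora_config.max_loras if lora_config else 0,
    )

def lora_cache_capacity(kwargs: dict) -> int:
    # Mirrors how TokenizerGroup sizes its LRU cache in the hunk above.
    max_loras = kwargs.get("max_loras", 0)
    if not kwargs["enable_lora"]:
        return 0
    return max(max_loras, kwargs["max_num_seqs"])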
@@ -51,7 +51,7 @@ class AsyncLLM(EngineClient):
             model_config=vllm_config.model_config,
             scheduler_config=vllm_config.scheduler_config,
             parallel_config=vllm_config.parallel_config,
-            enable_lora=bool(vllm_config.lora_config))
+            lora_config=vllm_config.lora_config)
         self.tokenizer.ping()

         # Request streams (map of request_id -> AsyncStream).
@@ -46,7 +46,7 @@ class LLMEngine:
             model_config=vllm_config.model_config,
             scheduler_config=vllm_config.scheduler_config,
             parallel_config=vllm_config.parallel_config,
-            enable_lora=bool(vllm_config.lora_config))
+            lora_config=vllm_config.lora_config)
         self.tokenizer.ping()

         # Processor (convert Inputs --> EngineCoreRequests)