# vllm/vllm/transformers_utils/tokenizer_group.py

# SPDX-License-Identifier: Apache-2.0

from typing import Optional

from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens,
get_lora_tokenizer,
get_lora_tokenizer_async,
get_tokenizer)
from vllm.utils import LRUCache


class TokenizerGroup:
"""A group of tokenizers that can be used for LoRA adapters."""

    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int], **tokenizer_config):
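        """Load the base tokenizer and set up an LRU cache for per-adapter
        LoRA tokenizers."""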
self.tokenizer_id = tokenizer_id
self.tokenizer_config = tokenizer_config
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
max_loras = tokenizer_config.get("max_loras", 0)
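        # Size the cache so every adapter that can be active at once keeps its
        # tokenizer resident; a capacity of 0 disables it when LoRA is off.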
self.lora_tokenizers = LRUCache[int, AnyTokenizer](
capacity=max(max_loras, max_num_seqs) if enable_lora else 0)

    def get_max_input_len(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
"""Get the maximum input length for the LoRA request."""
return self.max_input_length

    def _raise_if_input_too_long(self,
encoded_tokens: list[int],
lora_request: Optional[LoRARequest] = None):
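        """Raise a ValueError if the encoded input exceeds the length limit."""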
input_length = len(encoded_tokens)
if lora_request:
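            # A long-context LoRA adapter may extend the limit beyond the
            # group-wide maximum input length.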
max_input_length = (lora_request.long_lora_max_len
or self.max_input_length)
else:
max_input_length = self.max_input_length
        if max_input_length is not None and input_length > max_input_length:
            raise ValueError(
                f"Input length ({input_length} tokens) exceeds the maximum "
                f"allowed length ({max_input_length} tokens).")

    def encode(self,
prompt: str,
max_length: Optional[int] = None,
truncation: Optional[bool] = None,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> list[int]:
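        """Encode a prompt, using the adapter's tokenizer when one applies,
        and enforce the input-length limit."""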
tokenizer = self.get_lora_tokenizer(lora_request)
ret = encode_tokens(tokenizer,
prompt,
max_length=max_length,
truncation=truncation,
add_special_tokens=add_special_tokens)
self._raise_if_input_too_long(ret, lora_request)
return ret

    async def encode_async(
self,
prompt: str,
max_length: Optional[int] = None,
truncation: Optional[bool] = None,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> list[int]:
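        """Async counterpart of encode; the LoRA tokenizer is fetched without
        blocking the event loop."""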
tokenizer = await self.get_lora_tokenizer_async(lora_request)
ret = encode_tokens(tokenizer,
prompt,
max_length=max_length,
truncation=truncation,
add_special_tokens=add_special_tokens)
self._raise_if_input_too_long(ret, lora_request)
return ret

    def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
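        """Return the tokenizer to use for the request, caching it per adapter.

        Falls back to the base tokenizer when LoRA is disabled, no request is
        given, or the adapter does not provide its own tokenizer.
        """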
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (get_lora_tokenizer(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers[lora_request.lora_int_id]

    async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
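        """Async counterpart of get_lora_tokenizer with the same caching and
        fallback behavior."""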
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (await get_lora_tokenizer_async(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers[lora_request.lora_int_id]


def init_tokenizer_from_configs(
        model_config: ModelConfig,
        scheduler_config: SchedulerConfig,
        lora_config: Optional[LoRAConfig]) -> TokenizerGroup:
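    """Construct a TokenizerGroup from vLLM's model, scheduler, and LoRA
    configs."""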
return TokenizerGroup(
tokenizer_id=model_config.tokenizer,
enable_lora=bool(lora_config),
max_num_seqs=scheduler_config.max_num_seqs,
max_loras=lora_config.max_loras if lora_config else 0,
max_input_length=None,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=model_config.truncation_side)
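

# A minimal usage sketch (illustrative, not part of vLLM itself): the model
# name and settings below are assumptions chosen for the example.
if __name__ == "__main__":
    group = TokenizerGroup(
        tokenizer_id="facebook/opt-125m",  # hypothetical example model
        enable_lora=False,
        max_num_seqs=8,
        max_input_length=2048,
    )
    # With LoRA disabled, encode() uses the base tokenizer and raises if the
    # prompt exceeds max_input_length.
    print(group.encode("Hello, world!"))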