Robert Shaw ed812a73fa
[ Frontend ] Multiprocessing for OpenAI Server with zeromq (#6883)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
2024-08-02 18:27:28 -07:00

54 lines
2.2 KiB
Python

from typing import Optional, Type
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
TokenizerPoolConfig)
from vllm.executor.ray_utils import ray
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
from .tokenizer_group import TokenizerGroup
# The Ray-backed tokenizer pool is optional: it is only imported when `ray`
# (as exposed by vllm.executor.ray_utils) is truthy.
# NOTE(review): presumably `ray` is falsy when the ray package is not
# installed — confirm against vllm.executor.ray_utils.
if ray:
    from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
        RayTokenizerGroupPool)
else:
    # Placeholder so get_tokenizer_group() can detect the missing class and
    # raise an ImportError when the "ray" pool type is requested.
    RayTokenizerGroupPool = None  # type: ignore
def init_tokenizer_from_configs(model_config: ModelConfig,
                                scheduler_config: SchedulerConfig,
                                parallel_config: ParallelConfig,
                                enable_lora: bool):
    """Create a tokenizer group from the engine configuration objects.

    Gathers the tokenizer-related settings into keyword arguments and
    delegates to :func:`get_tokenizer_group`, which picks the implementation
    based on ``parallel_config.tokenizer_pool_config``.

    Args:
        model_config: Source of ``tokenizer``, ``tokenizer_mode``,
            ``trust_remote_code`` and ``tokenizer_revision``.
        scheduler_config: Source of ``max_num_seqs``.
        parallel_config: Source of ``tokenizer_pool_config``.
        enable_lora: Forwarded unchanged as the ``enable_lora`` keyword.

    Returns:
        The object returned by :func:`get_tokenizer_group`.
    """
    tokenizer_kwargs = {
        "tokenizer_id": model_config.tokenizer,
        "enable_lora": enable_lora,
        "max_num_seqs": scheduler_config.max_num_seqs,
        # Always passed as None from this entry point.
        "max_input_length": None,
        "tokenizer_mode": model_config.tokenizer_mode,
        "trust_remote_code": model_config.trust_remote_code,
        "revision": model_config.tokenizer_revision,
    }
    pool_config = parallel_config.tokenizer_pool_config
    return get_tokenizer_group(pool_config, **tokenizer_kwargs)
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
                        **init_kwargs) -> BaseTokenizerGroup:
    """Select a tokenizer group class and build an instance via ``from_config``.

    Selection rules, checked in this order:

    * ``tokenizer_pool_config`` is ``None`` -> ``TokenizerGroup``.
    * ``pool_type`` is a class that subclasses ``BaseTokenizerGroup`` ->
      that class is used directly.
    * ``pool_type == "ray"`` -> ``RayTokenizerGroupPool``.

    Args:
        tokenizer_pool_config: Pool configuration, or ``None`` for no pool.
            It is also passed through to the chosen class's ``from_config``.
        **init_kwargs: Extra keyword arguments forwarded to ``from_config``.

    Returns:
        The result of ``from_config`` on the selected class.

    Raises:
        ImportError: If the ``"ray"`` pool type is requested but
            ``RayTokenizerGroupPool`` is ``None``.
        ValueError: If ``pool_type`` matches none of the rules above.
    """
    group_cls: Type[BaseTokenizerGroup]
    if tokenizer_pool_config is None:
        group_cls = TokenizerGroup
    else:
        pool_type = tokenizer_pool_config.pool_type
        # isinstance() guards issubclass(), which requires a class argument.
        is_custom_group_cls = isinstance(pool_type, type) and issubclass(
            pool_type, BaseTokenizerGroup)
        if is_custom_group_cls:
            group_cls = pool_type
        elif pool_type == "ray":
            if RayTokenizerGroupPool is None:
                raise ImportError(
                    "RayTokenizerGroupPool is not available. Please install "
                    "the ray package to use the Ray tokenizer group pool.")
            group_cls = RayTokenizerGroupPool
        else:
            raise ValueError(f"Unknown pool type: {pool_type}")
    return group_cls.from_config(tokenizer_pool_config, **init_kwargs)
# Names exported by a star-import of this module.
# NOTE(review): init_tokenizer_from_configs is defined in this file but is not
# listed here — confirm whether the omission is intentional.
__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"]