Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 08:04:58 +08:00)
[CI/Build] Avoid downloading all HF files in RemoteOpenAIServer (#7836)
commit 029c71de11
parent 0b769992ec
@@ -11,13 +11,14 @@ from typing import Any, Callable, Dict, List, Optional
 
 import openai
 import requests
-from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 from typing_extensions import ParamSpec
 
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip
 
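For context on the import change: snapshot_download with no filters pulls every file in a Hugging Face repository, while the engine's DefaultModelLoader only prepares the weight files it would actually load. The snippet below is a rough sketch of that difference, not code from this commit; the repo id and the allow_patterns list are illustrative (allow_patterns is an existing huggingface_hub parameter).

from huggingface_hub import snapshot_download

repo = "facebook/opt-125m"  # illustrative model id

# Old behaviour in RemoteOpenAIServer: download *every* file in the repo,
# including weight formats and assets the server never reads.
snapshot_download(repo)

# Roughly what delegating to the engine's model loader achieves: fetch only
# the files needed to serve the model (pattern list here is illustrative).
snapshot_download(repo, allow_patterns=["*.safetensors", "*.json", "*.txt"])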
@@ -60,39 +61,50 @@ class RemoteOpenAIServer:
 
     def __init__(self,
                  model: str,
-                 cli_args: List[str],
+                 vllm_serve_args: List[str],
                  *,
                  env_dict: Optional[Dict[str, str]] = None,
                  auto_port: bool = True,
                  max_wait_seconds: Optional[float] = None) -> None:
-        if not model.startswith("/"):
-            # download the model if it's not a local path
-            # to exclude the model download time from the server start time
-            snapshot_download(model)
         if auto_port:
-            if "-p" in cli_args or "--port" in cli_args:
-                raise ValueError("You have manually specified the port"
+            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
+                raise ValueError("You have manually specified the port "
                                  "when `auto_port=True`.")
 
-            cli_args = cli_args + ["--port", str(get_open_port())]
+            # Don't mutate the input args
+            vllm_serve_args = vllm_serve_args + [
+                "--port", str(get_open_port())
+            ]
 
         parser = FlexibleArgumentParser(
             description="vLLM's remote OpenAI server.")
         parser = make_arg_parser(parser)
-        args = parser.parse_args(cli_args)
+        args = parser.parse_args(["--model", model, *vllm_serve_args])
         self.host = str(args.host or 'localhost')
         self.port = int(args.port)
 
+        # download the model before starting the server to avoid timeout
+        is_local = os.path.isdir(model)
+        if not is_local:
+            engine_args = AsyncEngineArgs.from_cli_args(args)
+            engine_config = engine_args.create_engine_config()
+            dummy_loader = DefaultModelLoader(engine_config.load_config)
+            dummy_loader._prepare_weights(engine_config.model_config.model,
+                                          engine_config.model_config.revision,
+                                          fall_back_to_pt=True)
+
         env = os.environ.copy()
         # the current process might initialize cuda,
         # to be safe, we should use spawn method
         env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
         if env_dict is not None:
             env.update(env_dict)
-        self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args,
-                                     env=env,
-                                     stdout=sys.stdout,
-                                     stderr=sys.stderr)
+        self.proc = subprocess.Popen(
+            ["vllm", "serve", model, *vllm_serve_args],
+            env=env,
+            stdout=sys.stdout,
+            stderr=sys.stderr,
+        )
         max_wait_seconds = max_wait_seconds or 240
         self._wait_for_server(url=self.url_for("health"),
                               timeout=max_wait_seconds)
 
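A minimal usage sketch of the updated helper, not taken from this commit: the import path and the get_client() helper are assumed from vLLM's test utilities, and the model id and serve flags are illustrative.

from tests.utils import RemoteOpenAIServer  # assumed import path in the test suite

MODEL = "facebook/opt-125m"  # illustrative model id
serve_args = ["--dtype", "half", "--max-model-len", "2048"]  # illustrative flags

# The second positional argument is now `vllm_serve_args` (renamed from
# `cli_args`); `--port` is appended automatically while auto_port=True, and
# non-local models are pre-downloaded before the server subprocess starts.
with RemoteOpenAIServer(MODEL, serve_args) as server:
    client = server.get_client()  # assumed helper returning an OpenAI-compatible client
    print(client.models.list())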
@@ -742,7 +742,7 @@ class EngineArgs:
         engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
         return engine_args
 
-    def create_engine_config(self, ) -> EngineConfig:
+    def create_engine_config(self) -> EngineConfig:
         # gguf file needs a specific model loader and doesn't use hf_repo
         if self.model.endswith(".gguf"):
             self.quantization = self.load_format = "gguf"
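The gguf branch shown in the context lines above bypasses the usual HF repo handling; this hunk appears to belong to vLLM's EngineArgs. A minimal sketch of that behaviour, assuming a hypothetical local GGUF file path (the file must exist for create_engine_config() to succeed):

from vllm.engine.arg_utils import EngineArgs

# Hypothetical local GGUF checkpoint; any model path ending in ".gguf"
# triggers the branch above.
engine_args = EngineArgs(model="/models/llama-3-8b-q4_k_m.gguf")
engine_config = engine_args.create_engine_config()

# create_engine_config() forces the GGUF-specific quantization and loader:
assert engine_args.quantization == "gguf"
assert engine_args.load_format == "gguf"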