mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:15:31 +08:00
[misc] improve model support check in another process (#9208)
This commit is contained in:
parent
cf25b93bdd
commit
de895f1697
@ -4,6 +4,7 @@ sphinx-copybutton==0.5.2
|
||||
myst-parser==2.0.0
|
||||
sphinx-argparse==0.4.0
|
||||
msgspec
|
||||
cloudpickle
|
||||
|
||||
# packages to install to build the documentation
|
||||
pydantic >= 2.8
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import importlib
|
||||
import string
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
import tempfile
|
||||
from functools import lru_cache, partial
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import cloudpickle
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
@ -282,36 +283,28 @@ class ModelRegistry:
|
||||
|
||||
raise
|
||||
|
||||
valid_name_characters = string.ascii_letters + string.digits + "._"
|
||||
if any(s not in valid_name_characters for s in mod_name):
|
||||
raise ValueError(f"Unsafe module name detected for {model_arch}")
|
||||
if any(s not in valid_name_characters for s in cls_name):
|
||||
raise ValueError(f"Unsafe class name detected for {model_arch}")
|
||||
if any(s not in valid_name_characters for s in func.__module__):
|
||||
raise ValueError(f"Unsafe module name detected for {func}")
|
||||
if any(s not in valid_name_characters for s in func.__name__):
|
||||
raise ValueError(f"Unsafe class name detected for {func}")
|
||||
with tempfile.NamedTemporaryFile() as output_file:
|
||||
# `cloudpickle` allows pickling lambda functions directly
|
||||
input_bytes = cloudpickle.dumps(
|
||||
(mod_name, cls_name, func, output_file.name))
|
||||
# cannot use `sys.executable __file__` here because the script
|
||||
# contains relative imports
|
||||
returned = subprocess.run(
|
||||
[sys.executable, "-m", "vllm.model_executor.models.registry"],
|
||||
input=input_bytes,
|
||||
capture_output=True)
|
||||
|
||||
err_id = uuid.uuid4()
|
||||
|
||||
stmts = ";".join([
|
||||
f"from {mod_name} import {cls_name}",
|
||||
f"from {func.__module__} import {func.__name__}",
|
||||
f"assert {func.__name__}({cls_name}), '{err_id}'",
|
||||
])
|
||||
|
||||
result = subprocess.run([sys.executable, "-c", stmts],
|
||||
capture_output=True)
|
||||
|
||||
if result.returncode != 0:
|
||||
err_lines = [line.decode() for line in result.stderr.splitlines()]
|
||||
if err_lines and err_lines[-1] != f"AssertionError: {err_id}":
|
||||
err_str = "\n".join(err_lines)
|
||||
raise RuntimeError(
|
||||
"An unexpected error occurred while importing the model in "
|
||||
f"another process. Error log:\n{err_str}")
|
||||
|
||||
return result.returncode == 0
|
||||
# check if the subprocess is successful
|
||||
try:
|
||||
returned.check_returncode()
|
||||
except Exception as e:
|
||||
# wrap raised exception to provide more information
|
||||
raise RuntimeError(f"Error happened when testing "
|
||||
f"model support for{mod_name}.{cls_name}:\n"
|
||||
f"{returned.stderr.decode()}") from e
|
||||
with open(output_file.name, "rb") as f:
|
||||
result = pickle.load(f)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def is_text_generation_model(architectures: Union[str, List[str]]) -> bool:
|
||||
@ -364,3 +357,13 @@ class ModelRegistry:
|
||||
default=False)
|
||||
|
||||
return any(is_pp(arch) for arch in architectures)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
(mod_name, cls_name, func,
|
||||
output_file) = pickle.loads(sys.stdin.buffer.read())
|
||||
mod = importlib.import_module(mod_name)
|
||||
klass = getattr(mod, cls_name)
|
||||
result = func(klass)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(pickle.dumps(result))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user