[misc] improve model support check in another process (#9208)

This commit is contained in:
youkaichao 2024-10-09 21:58:27 -07:00 committed by GitHub
parent cf25b93bdd
commit de895f1697
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 31 deletions

View File

@ -4,6 +4,7 @@ sphinx-copybutton==0.5.2
myst-parser==2.0.0 myst-parser==2.0.0
sphinx-argparse==0.4.0 sphinx-argparse==0.4.0
msgspec msgspec
cloudpickle
# packages to install to build the documentation # packages to install to build the documentation
pydantic >= 2.8 pydantic >= 2.8

View File

@ -1,11 +1,12 @@
import importlib import importlib
import string import pickle
import subprocess import subprocess
import sys import sys
import uuid import tempfile
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Callable, Dict, List, Optional, Tuple, Type, Union from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import cloudpickle
import torch.nn as nn import torch.nn as nn
from vllm.logger import init_logger from vllm.logger import init_logger
@ -282,36 +283,28 @@ class ModelRegistry:
raise raise
valid_name_characters = string.ascii_letters + string.digits + "._" with tempfile.NamedTemporaryFile() as output_file:
if any(s not in valid_name_characters for s in mod_name): # `cloudpickle` allows pickling lambda functions directly
raise ValueError(f"Unsafe module name detected for {model_arch}") input_bytes = cloudpickle.dumps(
if any(s not in valid_name_characters for s in cls_name): (mod_name, cls_name, func, output_file.name))
raise ValueError(f"Unsafe class name detected for {model_arch}") # cannot use `sys.executable __file__` here because the script
if any(s not in valid_name_characters for s in func.__module__): # contains relative imports
raise ValueError(f"Unsafe module name detected for {func}") returned = subprocess.run(
if any(s not in valid_name_characters for s in func.__name__): [sys.executable, "-m", "vllm.model_executor.models.registry"],
raise ValueError(f"Unsafe class name detected for {func}") input=input_bytes,
capture_output=True)
err_id = uuid.uuid4() # check if the subprocess is successful
try:
stmts = ";".join([ returned.check_returncode()
f"from {mod_name} import {cls_name}", except Exception as e:
f"from {func.__module__} import {func.__name__}", # wrap raised exception to provide more information
f"assert {func.__name__}({cls_name}), '{err_id}'", raise RuntimeError(f"Error happened when testing "
]) f"model support for{mod_name}.{cls_name}:\n"
f"{returned.stderr.decode()}") from e
result = subprocess.run([sys.executable, "-c", stmts], with open(output_file.name, "rb") as f:
capture_output=True) result = pickle.load(f)
return result
if result.returncode != 0:
err_lines = [line.decode() for line in result.stderr.splitlines()]
if err_lines and err_lines[-1] != f"AssertionError: {err_id}":
err_str = "\n".join(err_lines)
raise RuntimeError(
"An unexpected error occurred while importing the model in "
f"another process. Error log:\n{err_str}")
return result.returncode == 0
@staticmethod @staticmethod
def is_text_generation_model(architectures: Union[str, List[str]]) -> bool: def is_text_generation_model(architectures: Union[str, List[str]]) -> bool:
@ -364,3 +357,13 @@ class ModelRegistry:
default=False) default=False)
return any(is_pp(arch) for arch in architectures) return any(is_pp(arch) for arch in architectures)
if __name__ == "__main__":
(mod_name, cls_name, func,
output_file) = pickle.loads(sys.stdin.buffer.read())
mod = importlib.import_module(mod_name)
klass = getattr(mod, cls_name)
result = func(klass)
with open(output_file, "wb") as f:
f.write(pickle.dumps(result))