mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-29 05:39:17 +08:00
[Bugfix] Override dunder methods of placeholder modules (#11882)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
310aca88c9
commit
0bd1ff4346
@ -7,9 +7,9 @@ import pytest
|
||||
import torch
|
||||
from vllm_test_utils import monitor
|
||||
|
||||
from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
|
||||
get_open_port, memory_profiling, merge_async_iterators,
|
||||
supports_kw)
|
||||
from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
|
||||
StoreBoolean, deprecate_kwargs, get_open_port,
|
||||
memory_profiling, merge_async_iterators, supports_kw)
|
||||
|
||||
from .utils import error_on_warning, fork_new_process_for_each_test
|
||||
|
||||
@ -323,3 +323,44 @@ def test_memory_profiling():
|
||||
del weights
|
||||
lib.cudaFree(handle1)
|
||||
lib.cudaFree(handle2)
|
||||
|
||||
|
||||
def test_placeholder_module_error_handling():
|
||||
placeholder = PlaceholderModule("placeholder_1234")
|
||||
|
||||
def build_ctx():
|
||||
return pytest.raises(ModuleNotFoundError,
|
||||
match="No module named")
|
||||
|
||||
with build_ctx():
|
||||
int(placeholder)
|
||||
|
||||
with build_ctx():
|
||||
placeholder()
|
||||
|
||||
with build_ctx():
|
||||
_ = placeholder.some_attr
|
||||
|
||||
with build_ctx():
|
||||
# Test conflict with internal __name attribute
|
||||
_ = placeholder.name
|
||||
|
||||
# OK to print the placeholder or use it in a f-string
|
||||
_ = repr(placeholder)
|
||||
_ = str(placeholder)
|
||||
|
||||
# No error yet; only error when it is used downstream
|
||||
placeholder_attr = placeholder.placeholder_attr("attr")
|
||||
|
||||
with build_ctx():
|
||||
int(placeholder_attr)
|
||||
|
||||
with build_ctx():
|
||||
placeholder_attr()
|
||||
|
||||
with build_ctx():
|
||||
_ = placeholder_attr.some_attr
|
||||
|
||||
with build_ctx():
|
||||
# Test conflict with internal __module attribute
|
||||
_ = placeholder_attr.module
|
||||
|
||||
189
vllm/utils.py
189
vllm/utils.py
@ -46,7 +46,7 @@ import zmq
|
||||
import zmq.asyncio
|
||||
from packaging.version import Version
|
||||
from torch.library import Library
|
||||
from typing_extensions import ParamSpec, TypeIs, assert_never
|
||||
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import enable_trace_function_call, init_logger
|
||||
@ -1627,24 +1627,183 @@ def get_vllm_optional_dependencies():
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PlaceholderModule:
|
||||
class _PlaceholderBase:
|
||||
"""
|
||||
Disallows downstream usage of placeholder modules.
|
||||
|
||||
We need to explicitly override each dunder method because
|
||||
:meth:`__getattr__` is not called when they are accessed.
|
||||
|
||||
See also:
|
||||
[Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
|
||||
"""
|
||||
|
||||
def __getattr__(self, key: str) -> Never:
|
||||
"""
|
||||
The main class should implement this to throw an error
|
||||
for attribute accesses representing downstream usage.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# [Basic customization]
|
||||
|
||||
def __lt__(self, other: object):
|
||||
return self.__getattr__("__lt__")
|
||||
|
||||
def __le__(self, other: object):
|
||||
return self.__getattr__("__le__")
|
||||
|
||||
def __eq__(self, other: object):
|
||||
return self.__getattr__("__eq__")
|
||||
|
||||
def __ne__(self, other: object):
|
||||
return self.__getattr__("__ne__")
|
||||
|
||||
def __gt__(self, other: object):
|
||||
return self.__getattr__("__gt__")
|
||||
|
||||
def __ge__(self, other: object):
|
||||
return self.__getattr__("__ge__")
|
||||
|
||||
def __hash__(self):
|
||||
return self.__getattr__("__hash__")
|
||||
|
||||
def __bool__(self):
|
||||
return self.__getattr__("__bool__")
|
||||
|
||||
# [Callable objects]
|
||||
|
||||
def __call__(self, *args: object, **kwargs: object):
|
||||
return self.__getattr__("__call__")
|
||||
|
||||
# [Container types]
|
||||
|
||||
def __len__(self):
|
||||
return self.__getattr__("__len__")
|
||||
|
||||
def __getitem__(self, key: object):
|
||||
return self.__getattr__("__getitem__")
|
||||
|
||||
def __setitem__(self, key: object, value: object):
|
||||
return self.__getattr__("__setitem__")
|
||||
|
||||
def __delitem__(self, key: object):
|
||||
return self.__getattr__("__delitem__")
|
||||
|
||||
# __missing__ is optional according to __getitem__ specification,
|
||||
# so it is skipped
|
||||
|
||||
# __iter__ and __reversed__ have a default implementation
|
||||
# based on __len__ and __getitem__, so they are skipped.
|
||||
|
||||
# [Numeric Types]
|
||||
|
||||
def __add__(self, other: object):
|
||||
return self.__getattr__("__add__")
|
||||
|
||||
def __sub__(self, other: object):
|
||||
return self.__getattr__("__sub__")
|
||||
|
||||
def __mul__(self, other: object):
|
||||
return self.__getattr__("__mul__")
|
||||
|
||||
def __matmul__(self, other: object):
|
||||
return self.__getattr__("__matmul__")
|
||||
|
||||
def __truediv__(self, other: object):
|
||||
return self.__getattr__("__truediv__")
|
||||
|
||||
def __floordiv__(self, other: object):
|
||||
return self.__getattr__("__floordiv__")
|
||||
|
||||
def __mod__(self, other: object):
|
||||
return self.__getattr__("__mod__")
|
||||
|
||||
def __divmod__(self, other: object):
|
||||
return self.__getattr__("__divmod__")
|
||||
|
||||
def __pow__(self, other: object, modulo: object = ...):
|
||||
return self.__getattr__("__pow__")
|
||||
|
||||
def __lshift__(self, other: object):
|
||||
return self.__getattr__("__lshift__")
|
||||
|
||||
def __rshift__(self, other: object):
|
||||
return self.__getattr__("__rshift__")
|
||||
|
||||
def __and__(self, other: object):
|
||||
return self.__getattr__("__and__")
|
||||
|
||||
def __xor__(self, other: object):
|
||||
return self.__getattr__("__xor__")
|
||||
|
||||
def __or__(self, other: object):
|
||||
return self.__getattr__("__or__")
|
||||
|
||||
# r* and i* methods have lower priority than
|
||||
# the methods for left operand so they are skipped
|
||||
|
||||
def __neg__(self):
|
||||
return self.__getattr__("__neg__")
|
||||
|
||||
def __pos__(self):
|
||||
return self.__getattr__("__pos__")
|
||||
|
||||
def __abs__(self):
|
||||
return self.__getattr__("__abs__")
|
||||
|
||||
def __invert__(self):
|
||||
return self.__getattr__("__invert__")
|
||||
|
||||
# __complex__, __int__ and __float__ have a default implementation
|
||||
# based on __index__, so they are skipped.
|
||||
|
||||
def __index__(self):
|
||||
return self.__getattr__("__index__")
|
||||
|
||||
def __round__(self, ndigits: object = ...):
|
||||
return self.__getattr__("__round__")
|
||||
|
||||
def __trunc__(self):
|
||||
return self.__getattr__("__trunc__")
|
||||
|
||||
def __floor__(self):
|
||||
return self.__getattr__("__floor__")
|
||||
|
||||
def __ceil__(self):
|
||||
return self.__getattr__("__ceil__")
|
||||
|
||||
# [Context managers]
|
||||
|
||||
def __enter__(self):
|
||||
return self.__getattr__("__enter__")
|
||||
|
||||
def __exit__(self, *args: object, **kwargs: object):
|
||||
return self.__getattr__("__exit__")
|
||||
|
||||
|
||||
class PlaceholderModule(_PlaceholderBase):
|
||||
"""
|
||||
A placeholder object to use when a module does not exist.
|
||||
|
||||
This enables more informative errors when trying to access attributes
|
||||
of a module that does not exists.
|
||||
"""
|
||||
name: str
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Apply name mangling to avoid conflicting with module attributes
|
||||
self.__name = name
|
||||
|
||||
def placeholder_attr(self, attr_path: str):
|
||||
return _PlaceholderModuleAttr(self, attr_path)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
name = self.name
|
||||
name = self.__name
|
||||
|
||||
try:
|
||||
importlib.import_module(self.name)
|
||||
importlib.import_module(name)
|
||||
except ImportError as exc:
|
||||
for extra, names in get_vllm_optional_dependencies().items():
|
||||
if name in names:
|
||||
@ -1657,17 +1816,21 @@ class PlaceholderModule:
|
||||
"when the original module can be imported")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _PlaceholderModuleAttr:
|
||||
module: PlaceholderModule
|
||||
attr_path: str
|
||||
class _PlaceholderModuleAttr(_PlaceholderBase):
|
||||
|
||||
def __init__(self, module: PlaceholderModule, attr_path: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Apply name mangling to avoid conflicting with module attributes
|
||||
self.__module = module
|
||||
self.__attr_path = attr_path
|
||||
|
||||
def placeholder_attr(self, attr_path: str):
|
||||
return _PlaceholderModuleAttr(self.module,
|
||||
f"{self.attr_path}.{attr_path}")
|
||||
return _PlaceholderModuleAttr(self.__module,
|
||||
f"{self.__attr_path}.{attr_path}")
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
getattr(self.module, f"{self.attr_path}.{key}")
|
||||
getattr(self.__module, f"{self.__attr_path}.{key}")
|
||||
|
||||
raise AssertionError("PlaceholderModule should not be used "
|
||||
"when the original module can be imported")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user