[Bugfix] Override dunder methods of placeholder modules (#11882)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-01-09 17:02:53 +08:00 committed by GitHub
parent 310aca88c9
commit 0bd1ff4346
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 220 additions and 16 deletions

View File

@ -7,9 +7,9 @@ import pytest
import torch
from vllm_test_utils import monitor
from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
get_open_port, memory_profiling, merge_async_iterators,
supports_kw)
from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
StoreBoolean, deprecate_kwargs, get_open_port,
memory_profiling, merge_async_iterators, supports_kw)
from .utils import error_on_warning, fork_new_process_for_each_test
@ -323,3 +323,44 @@ def test_memory_profiling():
del weights
lib.cudaFree(handle1)
lib.cudaFree(handle2)
def test_placeholder_module_error_handling():
placeholder = PlaceholderModule("placeholder_1234")
def build_ctx():
return pytest.raises(ModuleNotFoundError,
match="No module named")
with build_ctx():
int(placeholder)
with build_ctx():
placeholder()
with build_ctx():
_ = placeholder.some_attr
with build_ctx():
# Test conflict with internal __name attribute
_ = placeholder.name
# OK to print the placeholder or use it in a f-string
_ = repr(placeholder)
_ = str(placeholder)
# No error yet; only error when it is used downstream
placeholder_attr = placeholder.placeholder_attr("attr")
with build_ctx():
int(placeholder_attr)
with build_ctx():
placeholder_attr()
with build_ctx():
_ = placeholder_attr.some_attr
with build_ctx():
# Test conflict with internal __module attribute
_ = placeholder_attr.module

View File

@ -46,7 +46,7 @@ import zmq
import zmq.asyncio
from packaging.version import Version
from torch.library import Library
from typing_extensions import ParamSpec, TypeIs, assert_never
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
@ -1627,24 +1627,183 @@ def get_vllm_optional_dependencies():
}
@dataclass(frozen=True)
class PlaceholderModule:
class _PlaceholderBase:
"""
Disallows downstream usage of placeholder modules.
We need to explicitly override each dunder method because
:meth:`__getattr__` is not called when they are accessed.
See also:
[Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
"""
def __getattr__(self, key: str) -> Never:
"""
The main class should implement this to throw an error
for attribute accesses representing downstream usage.
"""
raise NotImplementedError
# [Basic customization]
def __lt__(self, other: object):
return self.__getattr__("__lt__")
def __le__(self, other: object):
return self.__getattr__("__le__")
def __eq__(self, other: object):
return self.__getattr__("__eq__")
def __ne__(self, other: object):
return self.__getattr__("__ne__")
def __gt__(self, other: object):
return self.__getattr__("__gt__")
def __ge__(self, other: object):
return self.__getattr__("__ge__")
def __hash__(self):
return self.__getattr__("__hash__")
def __bool__(self):
return self.__getattr__("__bool__")
# [Callable objects]
def __call__(self, *args: object, **kwargs: object):
return self.__getattr__("__call__")
# [Container types]
def __len__(self):
return self.__getattr__("__len__")
def __getitem__(self, key: object):
return self.__getattr__("__getitem__")
def __setitem__(self, key: object, value: object):
return self.__getattr__("__setitem__")
def __delitem__(self, key: object):
return self.__getattr__("__delitem__")
# __missing__ is optional according to __getitem__ specification,
# so it is skipped
# __iter__ and __reversed__ have a default implementation
# based on __len__ and __getitem__, so they are skipped.
# [Numeric Types]
def __add__(self, other: object):
return self.__getattr__("__add__")
def __sub__(self, other: object):
return self.__getattr__("__sub__")
def __mul__(self, other: object):
return self.__getattr__("__mul__")
def __matmul__(self, other: object):
return self.__getattr__("__matmul__")
def __truediv__(self, other: object):
return self.__getattr__("__truediv__")
def __floordiv__(self, other: object):
return self.__getattr__("__floordiv__")
def __mod__(self, other: object):
return self.__getattr__("__mod__")
def __divmod__(self, other: object):
return self.__getattr__("__divmod__")
def __pow__(self, other: object, modulo: object = ...):
return self.__getattr__("__pow__")
def __lshift__(self, other: object):
return self.__getattr__("__lshift__")
def __rshift__(self, other: object):
return self.__getattr__("__rshift__")
def __and__(self, other: object):
return self.__getattr__("__and__")
def __xor__(self, other: object):
return self.__getattr__("__xor__")
def __or__(self, other: object):
return self.__getattr__("__or__")
# r* and i* methods have lower priority than
# the methods for left operand so they are skipped
def __neg__(self):
return self.__getattr__("__neg__")
def __pos__(self):
return self.__getattr__("__pos__")
def __abs__(self):
return self.__getattr__("__abs__")
def __invert__(self):
return self.__getattr__("__invert__")
# __complex__, __int__ and __float__ have a default implementation
# based on __index__, so they are skipped.
def __index__(self):
return self.__getattr__("__index__")
def __round__(self, ndigits: object = ...):
return self.__getattr__("__round__")
def __trunc__(self):
return self.__getattr__("__trunc__")
def __floor__(self):
return self.__getattr__("__floor__")
def __ceil__(self):
return self.__getattr__("__ceil__")
# [Context managers]
def __enter__(self):
return self.__getattr__("__enter__")
def __exit__(self, *args: object, **kwargs: object):
return self.__getattr__("__exit__")
class PlaceholderModule(_PlaceholderBase):
"""
A placeholder object to use when a module does not exist.
This enables more informative errors when trying to access attributes
of a module that does not exists.
"""
name: str
def __init__(self, name: str) -> None:
super().__init__()
# Apply name mangling to avoid conflicting with module attributes
self.__name = name
def placeholder_attr(self, attr_path: str):
return _PlaceholderModuleAttr(self, attr_path)
def __getattr__(self, key: str):
name = self.name
name = self.__name
try:
importlib.import_module(self.name)
importlib.import_module(name)
except ImportError as exc:
for extra, names in get_vllm_optional_dependencies().items():
if name in names:
@ -1657,17 +1816,21 @@ class PlaceholderModule:
"when the original module can be imported")
@dataclass(frozen=True)
class _PlaceholderModuleAttr:
module: PlaceholderModule
attr_path: str
class _PlaceholderModuleAttr(_PlaceholderBase):
def __init__(self, module: PlaceholderModule, attr_path: str) -> None:
super().__init__()
# Apply name mangling to avoid conflicting with module attributes
self.__module = module
self.__attr_path = attr_path
def placeholder_attr(self, attr_path: str):
return _PlaceholderModuleAttr(self.module,
f"{self.attr_path}.{attr_path}")
return _PlaceholderModuleAttr(self.__module,
f"{self.__attr_path}.{attr_path}")
def __getattr__(self, key: str):
getattr(self.module, f"{self.attr_path}.{key}")
getattr(self.__module, f"{self.__attr_path}.{key}")
raise AssertionError("PlaceholderModule should not be used "
"when the original module can be imported")