mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 04:04:57 +08:00
[misc] add hint for AttributeError (#5462)
This commit is contained in:
parent
51602eefd3
commit
622d45128c
@ -1,13 +1,16 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
|
import functools
|
||||||
from typing import List, Optional, Tuple, Type
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
# Best-effort import of the compiled C/CUDA extension. If the build is
# missing or broken we only warn here; the ops registered under
# torch.ops._C will then be absent and fail at call time instead.
try:
    import vllm._C
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)
|
|
||||||
with contextlib.suppress(ImportError):
|
with contextlib.suppress(ImportError):
|
||||||
@ -23,6 +26,25 @@ def is_custom_op_supported(op_name: str) -> bool:
|
|||||||
return op is not None
|
return op is not None
|
||||||
|
|
||||||
|
|
||||||
|
def hint_on_error(fn):
    """Wrap a custom-op wrapper so an ``AttributeError`` gets a rebuild hint.

    An ``AttributeError`` from ``torch.ops._C.<op>`` usually means the
    installed/compiled vllm extension is stale and lacks the op. This
    decorator logs an actionable hint before re-raising the error.

    Args:
        fn: The function to wrap (its signature is preserved verbatim).

    Returns:
        A wrapper with the same call interface as ``fn``.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except AttributeError as e:
            msg = (
                "Error in calling custom op %s: %s\n"
                "Possibly you have built or installed an obsolete version of vllm.\n"
                # NOTE: trailing space added so the two fragments do not
                # render as "vllm,or remove".
                "Please try a clean build and install of vllm, "
                "or remove old built files such as vllm/*cpython*.so and build/ ."
            )
            logger.error(msg, fn.__name__, e)
            # Bare `raise` re-raises the active exception with its
            # original traceback intact (unlike `raise e`).
            raise

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """In-place fused SiLU-and-multiply activation.

    Delegates to the compiled ``vllm._C`` custom op; the result is written
    into ``out`` (nothing is returned).
    # NOTE(review): presumably computes silu(x[..., :d]) * x[..., d:] as in
    # the usual gated-activation kernels — confirm against the C++ kernel.
    """
    torch.ops._C.silu_and_mul(out, x)
||||||
@ -459,3 +481,25 @@ def dispatch_bgmv_low_level(
|
|||||||
h_out,
|
h_out,
|
||||||
y_offset,
|
y_offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
# Wrap every op wrapper defined in THIS module (identified by
# __code__.co_filename) whose annotations mention torch.Tensor with
# hint_on_error, so stale-build AttributeErrors get an actionable message.
names_and_values = globals()
names_and_values_to_update = {}
# prepare variables to avoid dict size change during iteration
k, v, arg = None, None, None
# plain-function type; filters out classes, modules, builtins, etc.
fn_type = type(lambda x: x)
for k, v in names_and_values.items():
    # find functions that are defined in this file and have torch.Tensor
    # in their annotations. `arg == "torch.Tensor"` is used to handle
    # the case when users use `import __annotations__` to turn type
    # hints into strings.
    if isinstance(v, fn_type) \
            and v.__code__.co_filename == __file__ \
            and any(arg is torch.Tensor or arg == "torch.Tensor"
                    for arg in v.__annotations__.values()):
        names_and_values_to_update[k] = hint_on_error(v)

names_and_values.update(names_and_values_to_update)
# clean up module namespace (loop temporaries were pre-bound above so
# these del statements cannot raise NameError)
del names_and_values_to_update, names_and_values, v, k, fn_type
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user