mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 01:49:19 +08:00
[torch.compile] add warning for unsupported models (#10622)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
7c2134beda
commit
65813781a2
@ -5,6 +5,7 @@ from contextlib import contextmanager
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class CompilationCounter:
|
class CompilationCounter:
|
||||||
|
num_models_seen: int = 0
|
||||||
num_graphs_seen: int = 0
|
num_graphs_seen: int = 0
|
||||||
# including the splitting ops
|
# including the splitting ops
|
||||||
num_piecewise_graphs_seen: int = 0
|
num_piecewise_graphs_seen: int = 0
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||||
from vllm.config import CompilationLevel, VllmConfig
|
from vllm.config import CompilationLevel, VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -130,6 +131,7 @@ def _support_torch_compile(cls: type,
|
|||||||
] or not supports_dynamo()
|
] or not supports_dynamo()
|
||||||
if self.do_not_compile:
|
if self.do_not_compile:
|
||||||
return
|
return
|
||||||
|
compilation_counter.num_models_seen += 1
|
||||||
TorchCompileWrapperWithCustomDispatcher.__init__(
|
TorchCompileWrapperWithCustomDispatcher.__init__(
|
||||||
self, compilation_level=vllm_config.compilation_config.level)
|
self, compilation_level=vllm_config.compilation_config.level)
|
||||||
|
|
||||||
|
|||||||
@ -80,6 +80,9 @@ def set_current_vllm_config(vllm_config: "VllmConfig"):
|
|||||||
"""
|
"""
|
||||||
global _current_vllm_config
|
global _current_vllm_config
|
||||||
old_vllm_config = _current_vllm_config
|
old_vllm_config = _current_vllm_config
|
||||||
|
from vllm.compilation.counter import compilation_counter
|
||||||
|
from vllm.config import CompilationLevel
|
||||||
|
num_models_seen = compilation_counter.num_models_seen
|
||||||
try:
|
try:
|
||||||
_current_vllm_config = vllm_config
|
_current_vllm_config = vllm_config
|
||||||
yield
|
yield
|
||||||
@ -88,6 +91,18 @@ def set_current_vllm_config(vllm_config: "VllmConfig"):
|
|||||||
vllm_config.compilation_config.enabled_custom_ops)
|
vllm_config.compilation_config.enabled_custom_ops)
|
||||||
logger.debug("disabled custom ops: %s",
|
logger.debug("disabled custom ops: %s",
|
||||||
vllm_config.compilation_config.disabled_custom_ops)
|
vllm_config.compilation_config.disabled_custom_ops)
|
||||||
|
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
|
||||||
|
and compilation_counter.num_models_seen == num_models_seen:
|
||||||
|
# If the model supports compilation,
|
||||||
|
# compilation_counter.num_models_seen should be increased
|
||||||
|
# by at least 1.
|
||||||
|
# If it is not increased, it means the model does not support
|
||||||
|
# compilation (does not have @support_torch_compile decorator).
|
||||||
|
logger.warning(
|
||||||
|
"`torch.compile` is turned on, but the model %s"
|
||||||
|
" does not support it. Please open an issue on GitHub"
|
||||||
|
" if you want it to be supported.",
|
||||||
|
vllm_config.model_config.model)
|
||||||
_current_vllm_config = old_vllm_config
|
_current_vllm_config = old_vllm_config
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user