Summary
-------
Add get_cached_compilation_config() to vllm/config.py and switch the three
CustomOp call sites in vllm/model_executor/custom_op.py over to it.

* get_cached_compilation_config() is wrapped in functools.lru_cache(maxsize=1)
  and returns get_current_vllm_config().compilation_config, so repeated
  lookups reuse the first result instead of calling get_current_vllm_config()
  again.
* set_current_vllm_config() calls get_cached_compilation_config.cache_clear()
  in its `finally` block, i.e. when the context is left.
* The CustomOp call sites previously read
  get_current_vllm_config().compilation_config directly.

NOTE(review): the hunk below only shows the cache being cleared on context
exit. If the (not shown) entry path of set_current_vllm_config() does not
also clear it, a value cached before entering -- or cached by an enclosing
context -- would presumably still be returned inside the new context.
Confirm against the full function before merging.

NOTE(review): whitespace in context lines was reconstructed from a copy of
this patch whose line breaks had been lost; verify with
`git apply --check` against the target tree.

diff --git a/vllm/config.py b/vllm/config.py
index 1100e1077401..34952279c9d1 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -15,7 +15,7 @@ from collections.abc import Mapping
 from contextlib import contextmanager
 from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
                          replace)
-from functools import cached_property
+from functools import cached_property, lru_cache
 from importlib.util import find_spec
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
                     Protocol, TypeVar, Union, cast, get_args)
@@ -5123,6 +5123,14 @@ def set_current_vllm_config(vllm_config: VllmConfig,
     finally:
         _current_vllm_config = old_vllm_config
         _current_prefix = old_prefix
+        # Clear the compilation config cache when context changes
+        get_cached_compilation_config.cache_clear()
+
+
+@lru_cache(maxsize=1)
+def get_cached_compilation_config():
+    """Cache config to avoid repeated calls to get_current_vllm_config()"""
+    return get_current_vllm_config().compilation_config
 
 
 def get_current_vllm_config() -> VllmConfig:
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index f6e79cd676f8..6b5a107396c9 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -5,7 +5,7 @@ from typing import Optional
 
 import torch.nn as nn
 
-from vllm.config import get_current_vllm_config
+from vllm.config import get_cached_compilation_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
@@ -86,7 +86,7 @@ class CustomOp(nn.Module):
     def dispatch_forward(self):
         # NOTE(woosuk): Here we assume that vLLM was built for only one
         # specific backend. Currently, we do not support dynamic dispatching.
-        compilation_config = get_current_vllm_config().compilation_config
+        compilation_config = get_cached_compilation_config()
         enabled = self.enabled()
         if enabled:
             compilation_config.enabled_custom_ops.update([self.__class__.name])
@@ -115,7 +115,7 @@ class CustomOp(nn.Module):
     @classmethod
     def enabled(cls) -> bool:
         # if no name, then it was not registered
-        compilation_config = get_current_vllm_config().compilation_config
+        compilation_config = get_cached_compilation_config()
         custom_ops = compilation_config.custom_ops
         if not hasattr(cls, "name"):
             logger.warning_once(
@@ -138,7 +138,7 @@ class CustomOp(nn.Module):
         Specifying 'all' or 'none' in custom_op takes precedence.
         """
         from vllm.config import CompilationLevel
-        compilation_config = get_current_vllm_config().compilation_config
+        compilation_config = get_cached_compilation_config()
         default_on = (compilation_config.level < CompilationLevel.PIECEWISE
                       or not compilation_config.use_inductor)
         count_none = compilation_config.custom_ops.count("none")