Optimize configuration access with LRU cache in custom ops (#22204)

Signed-off-by: zitian zhao <zitian.zhao@tencentmusic.com>
This commit is contained in:
ZiTian.Zhao 2025-08-05 12:43:24 +08:00 committed by GitHub
parent bd3db7f469
commit 4b3e4474d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 5 deletions

View File

@ -15,7 +15,7 @@ from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
replace)
from functools import cached_property
from functools import cached_property, lru_cache
from importlib.util import find_spec
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args)
@ -5123,6 +5123,14 @@ def set_current_vllm_config(vllm_config: VllmConfig,
finally:
_current_vllm_config = old_vllm_config
_current_prefix = old_prefix
# Clear the compilation config cache when context changes
get_cached_compilation_config.cache_clear()
@lru_cache(maxsize=1)
def get_cached_compilation_config():
"""Cache config to avoid repeated calls to get_current_vllm_config()"""
return get_current_vllm_config().compilation_config
def get_current_vllm_config() -> VllmConfig:

View File

@ -5,7 +5,7 @@ from typing import Optional
import torch.nn as nn
from vllm.config import get_current_vllm_config
from vllm.config import get_cached_compilation_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
@ -86,7 +86,7 @@ class CustomOp(nn.Module):
def dispatch_forward(self):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
compilation_config = get_current_vllm_config().compilation_config
compilation_config = get_cached_compilation_config()
enabled = self.enabled()
if enabled:
compilation_config.enabled_custom_ops.update([self.__class__.name])
@ -115,7 +115,7 @@ class CustomOp(nn.Module):
@classmethod
def enabled(cls) -> bool:
# if no name, then it was not registered
compilation_config = get_current_vllm_config().compilation_config
compilation_config = get_cached_compilation_config()
custom_ops = compilation_config.custom_ops
if not hasattr(cls, "name"):
logger.warning_once(
@ -138,7 +138,7 @@ class CustomOp(nn.Module):
Specifying 'all' or 'none' in custom_op takes precedence.
"""
from vllm.config import CompilationLevel
compilation_config = get_current_vllm_config().compilation_config
compilation_config = get_cached_compilation_config()
default_on = (compilation_config.level < CompilationLevel.PIECEWISE
or not compilation_config.use_inductor)
count_none = compilation_config.custom_ops.count("none")