mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:35:24 +08:00
Optimize configuration access with LRU cache in custom ops (#22204)
Signed-off-by: zitian zhao <zitian.zhao@tencentmusic.com>
This commit is contained in:
parent
bd3db7f469
commit
4b3e4474d7
@ -15,7 +15,7 @@ from collections.abc import Mapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
|
||||
replace)
|
||||
from functools import cached_property
|
||||
from functools import cached_property, lru_cache
|
||||
from importlib.util import find_spec
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
|
||||
Protocol, TypeVar, Union, cast, get_args)
|
||||
@ -5123,6 +5123,14 @@ def set_current_vllm_config(vllm_config: VllmConfig,
|
||||
finally:
|
||||
_current_vllm_config = old_vllm_config
|
||||
_current_prefix = old_prefix
|
||||
# Clear the compilation config cache when context changes
|
||||
get_cached_compilation_config.cache_clear()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_cached_compilation_config():
|
||||
"""Cache config to avoid repeated calls to get_current_vllm_config()"""
|
||||
return get_current_vllm_config().compilation_config
|
||||
|
||||
|
||||
def get_current_vllm_config() -> VllmConfig:
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config import get_cached_compilation_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -86,7 +86,7 @@ class CustomOp(nn.Module):
|
||||
def dispatch_forward(self):
|
||||
# NOTE(woosuk): Here we assume that vLLM was built for only one
|
||||
# specific backend. Currently, we do not support dynamic dispatching.
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
compilation_config = get_cached_compilation_config()
|
||||
enabled = self.enabled()
|
||||
if enabled:
|
||||
compilation_config.enabled_custom_ops.update([self.__class__.name])
|
||||
@ -115,7 +115,7 @@ class CustomOp(nn.Module):
|
||||
@classmethod
|
||||
def enabled(cls) -> bool:
|
||||
# if no name, then it was not registered
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
compilation_config = get_cached_compilation_config()
|
||||
custom_ops = compilation_config.custom_ops
|
||||
if not hasattr(cls, "name"):
|
||||
logger.warning_once(
|
||||
@ -138,7 +138,7 @@ class CustomOp(nn.Module):
|
||||
Specifying 'all' or 'none' in custom_op takes precedence.
|
||||
"""
|
||||
from vllm.config import CompilationLevel
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
compilation_config = get_cached_compilation_config()
|
||||
default_on = (compilation_config.level < CompilationLevel.PIECEWISE
|
||||
or not compilation_config.use_inductor)
|
||||
count_none = compilation_config.custom_ops.count("none")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user