[Perf] Cache vllm.env.__getattr__ result to avoid recomputation (#26146)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
Jialin Ouyang 2025-10-14 14:03:21 -07:00 committed by GitHub
parent b92ab3deda
commit 380f17527c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 84 additions and 2 deletions

View File

@@ -6,7 +6,54 @@ from unittest.mock import patch
import pytest import pytest
import vllm.envs as envs
from vllm.envs import (
enable_envs_cache,
env_list_with_choices,
env_with_choices,
environment_variables,
)
def test_getattr_without_cache(monkeypatch: pytest.MonkeyPatch):
    """Without the cache enabled, envs re-reads the environment on each access."""
    # Defaults before any override is applied.
    assert envs.VLLM_HOST_IP == ""
    assert envs.VLLM_PORT is None
    # Overrides must be picked up immediately by the next attribute access.
    monkeypatch.setenv("VLLM_HOST_IP", "1.1.1.1")
    monkeypatch.setenv("VLLM_PORT", "1234")
    assert getattr(envs, "VLLM_HOST_IP") == "1.1.1.1"
    assert getattr(envs, "VLLM_PORT") == 1234
    # The plain __getattr__ carries no functools.cache instrumentation.
    assert not hasattr(envs.__getattr__, "cache_info")
def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
    """enable_envs_cache() wraps envs.__getattr__ in functools.cache so later
    lookups are served from the cache instead of re-reading the environment."""
    monkeypatch.setenv("VLLM_HOST_IP", "1.1.1.1")
    monkeypatch.setenv("VLLM_PORT", "1234")
    # Initially __getattr__ is not decorated with functools.cache
    assert not hasattr(envs.__getattr__, "cache_info")
    # Enable envs cache and ignore ongoing environment changes
    enable_envs_cache()
    # __getattr__ is now decorated with functools.cache
    assert hasattr(envs.__getattr__, "cache_info")
    start_hits = envs.__getattr__.cache_info().hits
    # 2 more hits due to VLLM_HOST_IP and VLLM_PORT accesses
    # (enable_envs_cache() pre-warmed every variable, so these are hits)
    assert envs.VLLM_HOST_IP == "1.1.1.1"
    assert envs.VLLM_PORT == 1234
    assert envs.__getattr__.cache_info().hits == start_hits + 2
    # All environment variables are cached
    for environment_variable in environment_variables:
        envs.__getattr__(environment_variable)
    assert envs.__getattr__.cache_info().hits == start_hits + 2 + len(
        environment_variables
    )
    # Reset envs.__getattr__ back to the non-cached version to
    # avoid affecting other tests
    envs.__getattr__ = envs.__getattr__.__wrapped__
class TestEnvWithChoices: class TestEnvWithChoices:

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import hashlib import hashlib
import json import json
import os import os
@@ -1408,12 +1409,36 @@ environment_variables: dict[str, Callable[[], Any]] = {
def __getattr__(name: str):
    """
    Gets environment variables lazily.

    NOTE: After enable_envs_cache() invocation (which triggered after service
    initialization), all environment variables will be cached.
    """
    getter = environment_variables.get(name)
    if getter is not None:
        return getter()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def enable_envs_cache() -> None:
    """
    Enables caching of environment variables. This is useful for performance
    reasons, as it avoids the need to re-evaluate environment variables on
    every call.

    NOTE: Currently, it's invoked after service initialization to reduce
    runtime overhead. This also means that environment variables should NOT
    be updated after the service is initialized.
    """
    global __getattr__
    # Guard against double-wrapping: both the engine-core and worker startup
    # paths call this, and wrapping an already-cached __getattr__ again would
    # nest functools.cache layers (and break a single __wrapped__ unwrap).
    if not hasattr(__getattr__, "cache_info"):
        # Tag __getattr__ with functools.cache
        __getattr__ = functools.cache(__getattr__)
    # Warm the cache so every subsequent lookup is a hit
    for key in environment_variables:
        __getattr__(key)
def __dir__(): def __dir__():
return list(environment_variables.keys()) return list(environment_variables.keys())

View File

@ -20,6 +20,7 @@ import zmq
from vllm.config import ParallelConfig, VllmConfig from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.distributed.parallel_state import is_global_first_rank from vllm.distributed.parallel_state import is_global_first_rank
from vllm.envs import enable_envs_cache
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -601,6 +602,10 @@ class EngineCoreProc(EngineCore):
# If enable, attach GC debugger after static variable freeze. # If enable, attach GC debugger after static variable freeze.
maybe_attach_gc_debug_callback() maybe_attach_gc_debug_callback()
# Enable environment variable cache (e.g. assume no more
# environment variable overrides after this point)
enable_envs_cache()
@contextmanager @contextmanager
def _perform_handshakes( def _perform_handshakes(
self, self,

View File

@ -33,6 +33,7 @@ from vllm.distributed.parallel_state import (
get_pp_group, get_pp_group,
get_tp_group, get_tp_group,
) )
from vllm.envs import enable_envs_cache
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import ( from vllm.utils import (
_maybe_force_spawn, _maybe_force_spawn,
@ -455,6 +456,10 @@ class WorkerProc:
# Load model # Load model
self.worker.load_model() self.worker.load_model()
# Enable environment variable cache (e.g. assume no more
# environment variable overrides after this point)
enable_envs_cache()
@staticmethod @staticmethod
def make_worker_process( def make_worker_process(
vllm_config: VllmConfig, vllm_config: VllmConfig,