[torch.compile] caching of config fields should be opt-out by default (#26468)

Signed-off-by: vnadathur <glvikramn@gmail.com>
Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com>
Signed-off-by: Srreyansh Sethi <srreyansh.sethi@gmail.com>
Signed-off-by: Srreyansh Sethi <107075589+WorldExplored@users.noreply.github.com>
Co-authored-by: WorldExplored <srreyansh.sethi@gmail.com>
Co-authored-by: Srreyansh Sethi <107075589+worldexplored@users.noreply.github.com>
Co-authored-by: vnadathur <236933696+vnadathur@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
vnadathur 2025-11-19 06:13:54 -08:00 committed by GitHub
parent 2c8b9182b5
commit 1ffe934c8a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 599 additions and 190 deletions

View File

@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from enum import Enum
import pytest
from vllm.config.utils import get_hash_factors, hash_factors, normalize_value
# Helpers
def endswith_fqname(obj, suffix: str) -> bool:
# normalize_value(type) returns fully-qualified name
# Compare suffix to avoid brittle import paths.
out = normalize_value(obj)
return isinstance(out, str) and out.endswith(suffix)
def expected_path(p_str: str = ".") -> str:
import pathlib
p = pathlib.Path(p_str)
return p.expanduser().resolve().as_posix()
# Minimal dataclass to test get_hash_factors.
# Avoid importing heavy vLLM configs.
@dataclass
class SimpleConfig:
a: object
b: object | None = None
class DummyLogprobsMode(Enum):
RAW_LOGITS = "raw_logits"
def test_hash_factors_deterministic():
"""Test that hash_factors produces consistent SHA-256 hashes"""
factors = {"a": 1, "b": "test"}
hash1 = hash_factors(factors)
hash2 = hash_factors(factors)
assert hash1 == hash2
# Dict key insertion order should not affect the hash.
factors_reordered = {"b": "test", "a": 1}
assert hash_factors(factors_reordered) == hash1
assert len(hash1) == 64
assert all(c in "0123456789abcdef" for c in hash1)
@pytest.mark.parametrize(
"inp, expected",
[
(None, None),
(True, True),
(1, 1),
(1.0, 1.0),
("x", "x"),
(b"ab", "6162"),
(bytearray(b"ab"), "6162"),
([1, 2], (1, 2)),
({"b": 2, "a": 1}, (("a", 1), ("b", 2))),
],
)
def test_normalize_value_matrix(inp, expected):
"""Parametric input→expected normalization table."""
assert normalize_value(inp) == expected
def test_normalize_value_enum():
# Enums normalize to (module.QualName, value).
# DummyLogprobsMode uses a string payload.
out = normalize_value(DummyLogprobsMode.RAW_LOGITS)
assert isinstance(out, tuple)
assert out[0].endswith("DummyLogprobsMode")
# Expect string payload 'raw_logits'.
assert out[1] == "raw_logits"
def test_normalize_value_set_order_insensitive():
# Sets are unordered; normalize_value sorts elements for determinism.
assert normalize_value({3, 1, 2}) == normalize_value({1, 2, 3})
def test_normalize_value_path_normalization():
from pathlib import Path # local import to avoid global dependency
# Paths expand/resolve to absolute strings.
# Stabilizes hashing across working dirs.
assert normalize_value(Path(".")) == expected_path(".")
def test_normalize_value_uuid_and_to_json():
# Objects may normalize via uuid() or to_json_string().
class HasUUID:
def uuid(self):
return "test-uuid"
class ToJson:
def to_json_string(self):
return '{"x":1}'
assert normalize_value(HasUUID()) == "test-uuid"
assert normalize_value(ToJson()) == '{"x":1}'
@pytest.mark.parametrize(
"bad",
[
(lambda x: x),
(type("CallableInstance", (), {"__call__": lambda self: 0}))(),
(lambda: (lambda: 0))(), # nested function instance
],
)
def test_error_cases(bad):
"""Inputs expected to raise TypeError."""
# Reject functions/lambdas/callable instances
# to avoid under-hashing.
with pytest.raises(TypeError):
normalize_value(bad)
def test_enum_vs_int_disambiguation():
# int stays primitive
nf_int = normalize_value(1)
assert nf_int == 1
# enum becomes ("module.QualName", value)
nf_enum = normalize_value(DummyLogprobsMode.RAW_LOGITS)
assert isinstance(nf_enum, tuple) and len(nf_enum) == 2
enum_type, enum_val = nf_enum
assert enum_type.endswith(".DummyLogprobsMode")
assert enum_val == "raw_logits"
# Build factor dicts from configs with int vs enum
f_int = get_hash_factors(SimpleConfig(1), set())
f_enum = get_hash_factors(SimpleConfig(DummyLogprobsMode.RAW_LOGITS), set())
# The int case remains a primitive value
assert f_int["a"] == 1
# The enum case becomes a tagged tuple ("module.QualName", "raw_logits")
assert isinstance(f_enum["a"], tuple) and f_enum["a"][1] == "raw_logits"
# Factor dicts must differ so we don't collide primitives with Enums.
assert f_int != f_enum
# Hash digests must differ correspondingly
assert hash_factors(f_int) != hash_factors(f_enum)
# Hash functions produce stable hex strings
h_int = hash_factors(f_int)
h_enum = hash_factors(f_enum)
assert isinstance(h_int, str) and len(h_int) == 64
assert isinstance(h_enum, str) and len(h_enum) == 64
def test_classes_are_types():
"""Types normalize to FQNs; include real vLLM types."""
# Only classes allowed; functions/lambdas are rejected.
# Canonical form is the fully-qualified name.
assert isinstance(normalize_value(str), str)
class LocalDummy:
pass
assert endswith_fqname(LocalDummy, ".LocalDummy")

View File

@ -4,12 +4,14 @@
import ast import ast
import dataclasses import dataclasses
import hashlib import hashlib
import json
import operator import operator
import os import os
import pprint import pprint
import time import time
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial
from typing import Any from typing import Any
import torch import torch
@ -23,7 +25,9 @@ from vllm.compilation.partition_rules import (
should_split, should_split,
) )
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.utils import hash_factors
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logging_utils import lazy
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
@ -580,35 +584,47 @@ class VllmBackend:
def __call__( def __call__(
self, graph: fx.GraphModule, example_inputs self, graph: fx.GraphModule, example_inputs
) -> VllmSerializableFunction: ) -> VllmSerializableFunction:
from .caching import _compute_code_hash, compilation_config_hash_factors
vllm_config = self.vllm_config vllm_config = self.vllm_config
# Minimal hashing here with existing utilities, reused below.
env_factors = envs.compile_factors()
env_hash = hash_factors(env_factors)
# Compute config/compiler/code hashes once and reuse
config_hash = vllm_config.compute_hash()
compiler_hash = self.compiler_manager.compute_hash(vllm_config)
forward_code_files = list(sorted(self.compilation_config.traced_files))
logger.debug(
"Traced files (to be considered for compilation cache):\n%s",
lazy(lambda: "\n".join(forward_code_files)),
)
hash_content = []
for filepath in forward_code_files:
hash_content.append(filepath)
if filepath == "<string>":
# This means the function was dynamically generated, with
# e.g. exec(). We can't actually check these.
continue
try:
with open(filepath) as f:
hash_content.append(f.read())
except Exception:
logger.warning("Failed to read file %s", filepath)
continue
code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest()
# Clear after consumption
self.compilation_config.traced_files.clear()
if not self.compilation_config.cache_dir: if not self.compilation_config.cache_dir:
# no provided cache dir, generate one based on the known factors # no provided cache dir, generate one based on the known factors
# that affects the compilation. if none of the factors change, # that affects the compilation. if none of the factors change,
# the cache dir will be the same so that we can reuse the compiled # the cache dir will be the same so that we can reuse the compiled
# graph. # graph.
factors = [env_hash, config_hash, code_hash, compiler_hash]
factors = compilation_config_hash_factors(vllm_config) # Use SHA-256 for cache key hashing to be consistent across
# 2. factors come from the code files that are traced by Dynamo ( # compute_hash functions. Truncate for a short cache dir name.
# it mainly summarizes how the model is used in forward pass) hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
code_hash = _compute_code_hash(self.compilation_config.traced_files)
self.compilation_config.traced_files.clear()
factors.append(code_hash)
# 3. compiler hash
compiler_hash = self.compiler_manager.compute_hash(vllm_config)
factors.append(compiler_hash)
# combine all factors to generate the cache dir
hash_key = hashlib.md5(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
cache_dir = os.path.join( cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT, envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key
"torch_compile_cache",
hash_key,
) )
self.compilation_config.cache_dir = cache_dir self.compilation_config.cache_dir = cache_dir
@ -621,6 +637,7 @@ class VllmBackend:
os.makedirs(local_cache_dir, exist_ok=True) os.makedirs(local_cache_dir, exist_ok=True)
self.compilation_config.local_cache_dir = local_cache_dir self.compilation_config.local_cache_dir = local_cache_dir
# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
disable_cache = not is_compile_cache_enabled( disable_cache = not is_compile_cache_enabled(
self.compilation_config.inductor_compile_config self.compilation_config.inductor_compile_config
) )
@ -638,6 +655,50 @@ class VllmBackend:
local_cache_dir, disable_cache, self.prefix local_cache_dir, disable_cache, self.prefix
) )
# Reuses existing cache key
logger.debug(
"torch.compile cache factors: env=%s cfg=%s comp=%s code=%s dir=%s",
env_hash,
config_hash,
compiler_hash,
code_hash,
local_cache_dir,
)
# Persist and log only hash-relevant factors together.
try:
logger.debug(
"Compile env factors (raw):\n%s\nVllm config hash: %s",
lazy(partial(pprint.pformat, env_factors, width=120)),
config_hash,
)
meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
if not os.path.exists(meta_path):
with open(meta_path, "w") as f:
json.dump(
{
"env": env_factors, # raw factors used for env_hash
"config_hash": config_hash,
"code_hash": code_hash,
"compiler_hash": compiler_hash,
},
f,
indent=2,
sort_keys=True,
)
except Exception:
# Best-effort only; metadata write failures are non-fatal.
logger.warning(
(
"Could not write compile cache metadata at %s; continuing without "
"metadata. Compiled cache remains valid; diagnostics may be "
"limited."
),
local_cache_dir,
exc_info=True,
)
# when dynamo calls the backend, it means the bytecode # when dynamo calls the backend, it means the bytecode
# transform and analysis are done # transform and analysis are done
compilation_counter.num_graphs_seen += 1 compilation_counter.num_graphs_seen += 1

View File

@ -127,7 +127,7 @@ class PostGradPassManager(CustomGraphPass):
affects compilation caching. Its uuid depends on the UUIDs of all affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info. dependent passes and the pass config. See InductorPass for more info.
""" """
state = {"pass_config": self.pass_config.uuid(), "passes": []} state = {"pass_config": self.pass_config.compute_hash(), "passes": []}
for pass_ in self.passes: for pass_ in self.passes:
state["passes"].append(pass_.uuid()) state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid()) state["passes"].append(self.fix_functionalization.uuid())

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from dataclasses import field from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
@ -160,13 +159,29 @@ class CacheConfig:
excluding anything before input ids/embeddings and after excluding anything before input ids/embeddings and after
the final hidden states. the final hidden states.
""" """
factors: list[Any] = [] ignored_factors = {
factors.append(self.cache_dtype) # Runtime/derived knobs that don't affect compiled graph shape
factors.append(self.mamba_cache_dtype) "gpu_memory_utilization",
factors.append(self.mamba_ssm_cache_dtype) "swap_space",
# `cpu_offload_gb` does not use `torch.compile` yet. "is_attention_free",
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() "num_gpu_blocks_override",
return hash_str "enable_prefix_caching",
"prefix_caching_hash_algo",
# `cpu_offload_gb` does not use `torch.compile` yet.
"cpu_offload_gb",
"cpu_kvcache_space_bytes",
"mamba_page_size_padded",
# Post-init/derived counters
"num_gpu_blocks",
"num_cpu_blocks",
# WIP feature toggle not impacting compiled graph shape
"kv_sharing_fast_prefill",
}
from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, ignored_factors)
return hash_factors(factors)
def metrics_info(self): def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus # convert cache_config to dict(key: str, value: str) for prometheus

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum import enum
import hashlib
from collections import Counter from collections import Counter
from collections.abc import Callable from collections.abc import Callable
from dataclasses import asdict, field from dataclasses import asdict, field
@ -160,7 +159,7 @@ class PassConfig:
current_platform.get_device_capability().to_int(), {} current_platform.get_device_capability().to_int(), {}
) )
def uuid(self): def compute_hash(self) -> str:
""" """
Produces a hash unique to the pass configuration. Produces a hash unique to the pass configuration.
Any new fields that affect compilation should be added to the hash. Any new fields that affect compilation should be added to the hash.
@ -506,28 +505,33 @@ class CompilationConfig:
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs Provide a hash that uniquely identifies all the configs
that affect the structure of the computation that affect the structure of the computation
graph from input ids/embeddings to the final hidden states, graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after excluding anything before input ids/embeddings and after
the final hidden states. the final hidden states.
""" """
factors: list[Any] = [] # Opt-out: default-include declared fields; keep a tiny exclude set;
factors.append(self.mode) # normalize types; keep SHA-256. For nested opaque configs, include a
factors.append(self.backend) # stable identifier (e.g., pass_config.compute_hash()) instead of object id.
factors.append(self.custom_ops)
factors.append(self.splitting_ops) ignored_factors = {
factors.append(self.use_inductor) # Paths/dirs and runtime/metrics that dont affect compiled graph
factors.append(self.use_inductor_graph_partition) "debug_dump_path",
factors.append(self.inductor_compile_config) "cache_dir",
factors.append(self.inductor_passes) "local_cache_dir",
factors.append(self.pass_config.uuid()) "bs_to_padded_graph_size",
factors.append(self.compile_cache_save_format) "traced_files",
return hashlib.sha256(str(factors).encode()).hexdigest() "compilation_time",
"static_forward_context",
"pass_config", # handled separately below
}
from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, ignored_factors)
factors["pass_config"] = self.pass_config.compute_hash()
return hash_factors(factors)
def __repr__(self) -> str: def __repr__(self) -> str:
exclude = { exclude = {

View File

@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import json
import warnings import warnings
from collections.abc import Callable from collections.abc import Callable
from dataclasses import InitVar, field from dataclasses import InitVar, field
@ -18,7 +16,7 @@ import vllm.envs as envs
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
from vllm.config.pooler import PoolerConfig from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType from vllm.config.scheduler import RunnerType
from vllm.config.utils import assert_hashable, config, getattr_iter from vllm.config.utils import config, getattr_iter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
@ -324,50 +322,50 @@ class ModelConfig:
excluding anything before input ids/embeddings and after excluding anything before input ids/embeddings and after
the final hidden states. the final hidden states.
""" """
factors: list[Any] = [] ignored_factors = {
factors.append(self.model) "runner",
factors.append(self.dtype) "convert",
factors.append(self.quantization) "task",
factors.append(self.revision) "tokenizer",
factors.append(self.code_revision) "tokenizer_mode",
factors.append(self.max_model_len) "seed",
factors.append(self.max_logprobs) "hf_config_path",
factors.append(self.disable_sliding_window) "allowed_local_media_path",
factors.append(self.trust_remote_code) "allowed_media_domains",
factors.append(self.generation_config) "tokenizer_revision",
factors.append(self.model_impl) "spec_target_max_model_len",
factors.append(self.override_generation_config) "enforce_eager",
factors.append(self.video_pruning_rate) "logprobs_mode",
factors.append(self.enable_prompt_embeds) "disable_cascade_attn",
"skip_tokenizer_init",
"enable_prompt_embeds",
"served_model_name",
"config_format",
"hf_token",
"hf_overrides",
"logits_processor_pattern",
"enable_sleep_mode",
"override_attention_dtype",
"logits_processors",
"io_processor_plugin",
"pooler_config",
"override_pooler_config",
"multimodal_config",
"limit_mm_per_prompt",
"media_io_kwargs",
"mm_processor_kwargs",
"mm_processor_cache_gb",
"mm_processor_cache_type",
"mm_shm_cache_max_object_size_mb",
"mm_encoder_tp_mode",
"interleave_mm_strings",
"skip_mm_profiling",
}
# hf_config can control how the model looks! from vllm.config.utils import get_hash_factors, hash_factors
try:
hf_config_json = self.hf_config.to_json_string(use_diff=False)
except TypeError:
from transformers import PretrainedConfig
from vllm.utils.jsontree import json_map_leaves factors = get_hash_factors(self, ignored_factors)
return hash_factors(factors)
# Handle nested HF configs with unserializable values gracefully
hf_config_json = (
json.dumps(
json_map_leaves(
lambda v: v.to_dict()
if isinstance(v, PretrainedConfig)
else str(v),
self.hf_config.to_dict(),
),
indent=2,
sort_keys=True,
)
+ "\n"
)
factors.append(hf_config_json)
str_factors = str(factors)
assert_hashable(str_factors)
return hashlib.sha256(str(factors).encode()).hexdigest()
def _update_nested( def _update_nested(
self, self,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import os import os
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
@ -448,19 +447,41 @@ class ParallelConfig:
This hash is also used for DP worker configuration validation This hash is also used for DP worker configuration validation
to prevent hangs from mismatched collective communication patterns. to prevent hangs from mismatched collective communication patterns.
""" """
factors: list[Any] = [] ignored_factors = {
factors.append(self.pipeline_parallel_size) # Derived/runtime topology, networking, or launch details
factors.append(self.tensor_parallel_size) "data_parallel_rank",
factors.append(self.enable_expert_parallel) "data_parallel_rank_local",
factors.append(self.data_parallel_size) "data_parallel_backend",
factors.append(self.all2all_backend) "data_parallel_external_lb",
factors.append(self.enable_eplb) "data_parallel_hybrid_lb",
if self.enable_eplb: "data_parallel_master_ip",
factors.append(self.eplb_config.log_balancedness) "data_parallel_master_port",
factors.append(self.eplb_config.window_size) "_data_parallel_master_port_list",
factors.append(self.eplb_config.step_interval) "data_parallel_rpc_port",
factors.append(self.eplb_config.num_redundant_experts) "rank",
return hashlib.sha256(str(factors).encode()).hexdigest() "master_addr",
"master_port",
"node_rank",
"nnodes",
"max_parallel_loading_workers",
"disable_custom_all_reduce",
"ray_workers_use_nsight",
"ray_runtime_env",
"placement_group",
"distributed_executor_backend",
"worker_cls",
"sd_worker_cls",
"worker_extension_cls",
"_api_process_count",
"_api_process_rank",
}
from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, ignored_factors)
# Explicitly include backend affecting env factor as before
factors["VLLM_ALL2ALL_BACKEND"] = str(envs.VLLM_ALL2ALL_BACKEND)
return hash_factors(factors)
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Set all2all_backend from env var if not specified, with deprecation warning # Set all2all_backend from env var if not specified, with deprecation warning

View File

@ -3,14 +3,19 @@
"""Utility functions for vLLM config dataclasses.""" """Utility functions for vLLM config dataclasses."""
import ast import ast
import enum
import hashlib
import inspect import inspect
import json
import pathlib
import textwrap import textwrap
from collections.abc import Iterable from collections.abc import Iterable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
from itertools import pairwise from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar from typing import TYPE_CHECKING, Any, Protocol, TypeVar
import regex as re import regex as re
import torch
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from typing_extensions import runtime_checkable from typing_extensions import runtime_checkable
@ -176,3 +181,115 @@ def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT:
) )
processed_overrides[field_name] = value processed_overrides[field_name] = value
return replace(config, **processed_overrides) return replace(config, **processed_overrides)
def normalize_value(x):
"""Return a stable, JSON-serializable canonical form for hashing.
Order: primitives, special types (Enum, callable, torch.dtype, Path), then
generic containers (Mapping/Set/Sequence) with recursion.
"""
# Fast path
if x is None or isinstance(x, (bool, int, float, str)):
return x
# Enums: tag with FQN to avoid primitive collisions.
# Ex: Enum(1) vs int(1) -> ("module.QualName", value).
if isinstance(x, enum.Enum):
enum_type = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
return (enum_type, normalize_value(x.value))
# Classes (types) are accepted and canonicalized by their fully-qualified
# name (module.qualname) for a stable identifier.
# Instances are only accepted if they expose uuid(); otherwise they are
# rejected to avoid under-hashing object state.
# Callables: accept classes only; reject funcs/lambdas/methods.
# Used by LogitsProcessor types and ModelConfig.hf_overrides.
if isinstance(x, type):
module = getattr(x, "__module__", "")
qual = getattr(x, "__qualname__", getattr(x, "__name__", ""))
return ".".join([p for p in (module, qual) if p]) or repr(x)
# Prefer stable uuid identifiers for objects that provide them, even if
# they are callable instances (e.g., InductorPass wrappers).
if hasattr(x, "uuid") and callable(getattr(x, "uuid", None)):
return x.uuid()
if callable(x):
raise TypeError("normalize_value: function or callable instance unsupported")
# Torch dtype: stringify (torch.float64 -> "torch.float64").
# We rely on the string form here; dtype-bearing fields that need additional
# disambiguation should encode that at the config layer.
if isinstance(x, torch.dtype):
return str(x)
# Bytes
if isinstance(x, (bytes, bytearray)):
return x.hex()
# Paths (canonicalize)
if isinstance(x, pathlib.Path):
try:
return str(x.expanduser().resolve())
except Exception:
return str(x)
# Dataclasses: represent as (FQN, sorted(field,value) tuple) for stability.
if is_dataclass(x):
type_fqn = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
items = tuple(
(f.name, normalize_value(getattr(x, f.name)))
for f in sorted(fields(x), key=lambda f: f.name)
)
return (type_fqn, items)
# Containers (generic)
if isinstance(x, Mapping):
return tuple(sorted((str(k), normalize_value(v)) for k, v in x.items()))
if isinstance(x, Set):
return tuple(sorted(repr(normalize_value(v)) for v in x))
if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)):
return tuple(normalize_value(v) for v in x)
# PretrainedConfig
if hasattr(x, "to_json_string") and callable(x.to_json_string):
return x.to_json_string()
# Unsupported type: e.g., modules, generators, open files, or objects
# without a stable JSON/UUID representation. Hard-error to avoid
# under-hashing.
# If you hit this, either reshape your config to use supported primitives
# and containers, or extend normalize_value to provide a stable encoding
# (e.g., via uuid() or to_json_string()) for this type.
raise TypeError(
f"normalize_value: unsupported type '{type(x).__name__}'. "
"Ensure config values use supported primitives/containers or add a "
"stable representation for this type."
)
def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, object]:
"""Gets the factors used for hashing a config class.
- Includes all dataclass fields not in `ignored_factors`.
- Errors on non-normalizable values.
"""
factors: dict[str, object] = {}
for dc_field in fields(config):
factor = dc_field.name
if factor in ignored_factors:
continue
value = getattr(config, factor, None)
try:
factors[factor] = normalize_value(value)
except TypeError as e:
raise TypeError(
f"get_hash_factors: unsupported type for key '{factor}' "
f"({type(value).__name__})"
) from e
return factors
def hash_factors(items: dict[str, object]) -> str:
"""Return a SHA-256 hex digest of the canonical items structure."""
return hashlib.sha256(json.dumps(items, sort_keys=True).encode()).hexdigest()

View File

@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools import functools
import hashlib
import json import json
import logging
import os import os
import sys import sys
import tempfile import tempfile
@ -426,6 +426,8 @@ def get_vllm_port() -> int | None:
# --8<-- [start:env-vars-definition] # --8<-- [start:env-vars-definition]
logger = logging.getLogger(__name__)
environment_variables: dict[str, Callable[[], Any]] = { environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ================== # ================== Installation Time Env Vars ==================
# Target device of vLLM, supporting [cuda (by default), # Target device of vLLM, supporting [cuda (by default),
@ -1540,85 +1542,88 @@ def is_set(name: str):
raise AttributeError(f"module {__name__!r} has no attribute {name!r}") raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def compute_hash() -> str: def compile_factors() -> dict[str, object]:
""" """Return env vars used for torch.compile cache keys.
WARNING: Whenever a new key is added to this environment
variables, ensure that it is included in the factors list if
it affects the computation graph. For example, different values
of VLLM_PP_LAYER_PARTITION will generate different computation
graphs, so it is included in the factors list. The env vars that
affect the choice of different kernels or attention backends should
also be included in the factors list.
"""
# The values of envs may affects the computation graph. Start with every known vLLM env var; drop entries in `ignored_factors`;
# TODO(DefTruth): hash all environment variables? hash everything else. This keeps the cache key aligned across workers."""
# for key in environment_variables:
# factorize(key)
environment_variables_to_hash = [
"VLLM_PP_LAYER_PARTITION",
"VLLM_MLA_DISABLE",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
"VLLM_USE_TRITON_AWQ",
"VLLM_DP_RANK",
"VLLM_DP_SIZE",
"VLLM_USE_STANDALONE_COMPILE",
"VLLM_FUSED_MOE_CHUNK_SIZE",
"VLLM_FLASHINFER_MOE_BACKEND",
"VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
"VLLM_ATTENTION_BACKEND",
"VLLM_USE_FLASHINFER_SAMPLER",
"VLLM_DISABLED_KERNELS",
"VLLM_USE_DEEP_GEMM",
"VLLM_MOE_USE_DEEP_GEMM",
"VLLM_USE_DEEP_GEMM_E8M0",
"VLLM_USE_FUSED_MOE_GROUPED_TOPK",
"VLLM_USE_FLASHINFER_MOE_FP16",
"VLLM_USE_FLASHINFER_MOE_FP8",
"VLLM_USE_FLASHINFER_MOE_FP4",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
"VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE",
"VLLM_USE_CUDNN_PREFILL",
"VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
"VLLM_USE_TRTLLM_ATTENTION",
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
"VLLM_ROCM_USE_AITER",
"VLLM_ROCM_USE_AITER_PAGED_ATTN",
"VLLM_ROCM_USE_AITER_LINEAR",
"VLLM_ROCM_USE_AITER_MOE",
"VLLM_ROCM_USE_AITER_RMSNORM",
"VLLM_ROCM_USE_AITER_MLA",
"VLLM_ROCM_USE_AITER_MHA",
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
"VLLM_ROCM_USE_AITER_TRITON_ROPE",
"VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
"VLLM_ROCM_USE_AITER_TRITON_GEMM",
"VLLM_ROCM_USE_SKINNY_GEMM",
"VLLM_ROCM_FP8_PADDING",
"VLLM_ROCM_MOE_PADDING",
"VLLM_ROCM_CUSTOM_PAGED_ATTN",
"VLLM_ROCM_QUICK_REDUCE_QUANTIZATION",
"VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
"VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
"VLLM_NVFP4_GEMM_BACKEND",
"VLLM_USE_FBGEMM",
"VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE",
"VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL",
]
for key in environment_variables_to_hash:
# if this goes out of sync with environment_variables,
# it's not a user error, it's a bug
assert key in environment_variables, (
"Please update environment_variables_to_hash in envs.py"
)
factors = [environment_variables[key]() for key in environment_variables_to_hash] ignored_factors: set[str] = {
"MAX_JOBS",
"VLLM_RPC_BASE_PATH",
"VLLM_USE_MODELSCOPE",
"VLLM_RINGBUFFER_WARNING_INTERVAL",
"VLLM_DEBUG_DUMP_PATH",
"VLLM_PORT",
"VLLM_CACHE_ROOT",
"LD_LIBRARY_PATH",
"VLLM_SERVER_DEV_MODE",
"VLLM_DP_MASTER_IP",
"VLLM_DP_MASTER_PORT",
"VLLM_RANDOMIZE_DP_DUMMY_INPUTS",
"VLLM_CI_USE_S3",
"VLLM_MODEL_REDIRECT_PATH",
"VLLM_HOST_IP",
"S3_ACCESS_KEY_ID",
"S3_SECRET_ACCESS_KEY",
"S3_ENDPOINT_URL",
"VLLM_USAGE_STATS_SERVER",
"VLLM_NO_USAGE_STATS",
"VLLM_DO_NOT_TRACK",
"VLLM_LOGGING_LEVEL",
"VLLM_LOGGING_PREFIX",
"VLLM_LOGGING_STREAM",
"VLLM_LOGGING_CONFIG_PATH",
"VLLM_LOG_STATS_INTERVAL",
"VLLM_DEBUG_LOG_API_SERVER_RESPONSE",
"VLLM_TUNED_CONFIG_FOLDER",
"VLLM_ENGINE_ITERATION_TIMEOUT_S",
"VLLM_HTTP_TIMEOUT_KEEP_ALIVE",
"VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS",
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH",
"VLLM_SLEEP_WHEN_IDLE",
"VLLM_IMAGE_FETCH_TIMEOUT",
"VLLM_VIDEO_FETCH_TIMEOUT",
"VLLM_AUDIO_FETCH_TIMEOUT",
"VLLM_MEDIA_URL_ALLOW_REDIRECTS",
"VLLM_MEDIA_LOADING_THREAD_COUNT",
"VLLM_MAX_AUDIO_CLIP_FILESIZE_MB",
"VLLM_VIDEO_LOADER_BACKEND",
"VLLM_MEDIA_CONNECTOR",
"VLLM_ASSETS_CACHE",
"VLLM_ASSETS_CACHE_MODEL_CLEAN",
"VLLM_MM_INPUT_CACHE_GIB",
"VLLM_WORKER_MULTIPROC_METHOD",
"VLLM_ENABLE_V1_MULTIPROCESSING",
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
"VLLM_CPU_KVCACHE_SPACE",
"VLLM_CPU_OMP_THREADS_BIND",
"VLLM_CPU_NUM_OF_RESERVED_CPU",
"VLLM_CPU_MOE_PREPACK",
"VLLM_CPU_SGL_KERNEL",
"VLLM_TEST_FORCE_LOAD_FORMAT",
"LOCAL_RANK",
"CUDA_VISIBLE_DEVICES",
}
from vllm.config.utils import normalize_value
factors: dict[str, object] = {}
for factor, getter in environment_variables.items():
if factor in ignored_factors:
continue
try:
raw = getter()
except Exception as exc: # pragma: no cover - defensive logging
logger.warning(
"Skipping environment variable %s while hashing compile factors: %s",
factor,
exc,
)
continue
factors[factor] = normalize_value(raw)
ray_noset_env_vars = [ ray_noset_env_vars = [
# Refer to # Refer to
@ -1641,8 +1646,8 @@ def compute_hash() -> str:
"RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR",
"RAY_EXPERIMENTAL_NOSET_RBLN_RT_VISIBLE_DEVICES", "RAY_EXPERIMENTAL_NOSET_RBLN_RT_VISIBLE_DEVICES",
] ]
factors.extend([os.getenv(var) for var in ray_noset_env_vars])
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() for var in ray_noset_env_vars:
factors[var] = normalize_value(os.getenv(var))
return hash_str return factors

View File

@ -2,9 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.logging_utils.formatter import NewLineFormatter from vllm.logging_utils.formatter import NewLineFormatter
from vllm.logging_utils.lazy import lazy
from vllm.logging_utils.log_time import logtime from vllm.logging_utils.log_time import logtime
__all__ = [ __all__ = [
"NewLineFormatter", "NewLineFormatter",
"lazy",
"logtime", "logtime",
] ]

View File

@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any
class lazy:
"""Wrap a zero-argument callable evaluated only during log formatting."""
__slots__ = ("_factory",)
def __init__(self, factory: Callable[[], Any]) -> None:
self._factory = factory
def __str__(self) -> str:
return str(self._factory())
def __repr__(self) -> str:
return str(self)