[torch.compile] caching of config fields should be opt-out by default (#26468)

Signed-off-by: vnadathur <glvikramn@gmail.com>
Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com>
Signed-off-by: Srreyansh Sethi <srreyansh.sethi@gmail.com>
Signed-off-by: Srreyansh Sethi <107075589+WorldExplored@users.noreply.github.com>
Co-authored-by: WorldExplored <srreyansh.sethi@gmail.com>
Co-authored-by: Srreyansh Sethi <107075589+worldexplored@users.noreply.github.com>
Co-authored-by: vnadathur <236933696+vnadathur@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>

This commit is contained in: parent 2c8b9182b5, commit 1ffe934c8a
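The change replaces hand-maintained opt-in factor lists with an opt-out scheme: every declared config field participates in the compile-cache hash unless it is explicitly listed as ignored. A minimal self-contained sketch of that pattern (the dataclass and field names below are illustrative, not the vLLM API):

import hashlib
import json
from dataclasses import dataclass, fields


@dataclass
class ExampleConfig:
    # Illustrative stand-in for a vLLM config dataclass.
    dtype: str = "float16"
    max_len: int = 4096
    log_stats: bool = False  # runtime-only knob, excluded from the hash


# Opt-out set: everything NOT listed here participates in the hash,
# so a newly added field is cache-relevant by default.
IGNORED = {"log_stats"}


def compute_hash(cfg: ExampleConfig) -> str:
    factors = {
        f.name: getattr(cfg, f.name) for f in fields(cfg) if f.name not in IGNORED
    }
    return hashlib.sha256(json.dumps(factors, sort_keys=True).encode()).hexdigest()


print(compute_hash(ExampleConfig()))  # stable hex digest, independent of log_stats

The point of the inversion is that forgetting to classify a new field can only cause extra recompilation, never a stale cache hit.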
tests/config/test_config_utils.py | 166 (new file)
@@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from enum import Enum

import pytest

from vllm.config.utils import get_hash_factors, hash_factors, normalize_value

# Helpers


def endswith_fqname(obj, suffix: str) -> bool:
    # normalize_value(type) returns fully-qualified name
    # Compare suffix to avoid brittle import paths.
    out = normalize_value(obj)
    return isinstance(out, str) and out.endswith(suffix)


def expected_path(p_str: str = ".") -> str:
    import pathlib

    p = pathlib.Path(p_str)
    return p.expanduser().resolve().as_posix()


# Minimal dataclass to test get_hash_factors.
# Avoid importing heavy vLLM configs.
@dataclass
class SimpleConfig:
    a: object
    b: object | None = None


class DummyLogprobsMode(Enum):
    RAW_LOGITS = "raw_logits"


def test_hash_factors_deterministic():
    """Test that hash_factors produces consistent SHA-256 hashes"""
    factors = {"a": 1, "b": "test"}
    hash1 = hash_factors(factors)
    hash2 = hash_factors(factors)

    assert hash1 == hash2
    # Dict key insertion order should not affect the hash.
    factors_reordered = {"b": "test", "a": 1}
    assert hash_factors(factors_reordered) == hash1
    assert len(hash1) == 64
    assert all(c in "0123456789abcdef" for c in hash1)


@pytest.mark.parametrize(
    "inp, expected",
    [
        (None, None),
        (True, True),
        (1, 1),
        (1.0, 1.0),
        ("x", "x"),
        (b"ab", "6162"),
        (bytearray(b"ab"), "6162"),
        ([1, 2], (1, 2)),
        ({"b": 2, "a": 1}, (("a", 1), ("b", 2))),
    ],
)
def test_normalize_value_matrix(inp, expected):
    """Parametric input→expected normalization table."""
    assert normalize_value(inp) == expected


def test_normalize_value_enum():
    # Enums normalize to (module.QualName, value).
    # DummyLogprobsMode uses a string payload.
    out = normalize_value(DummyLogprobsMode.RAW_LOGITS)
    assert isinstance(out, tuple)
    assert out[0].endswith("DummyLogprobsMode")
    # Expect string payload 'raw_logits'.
    assert out[1] == "raw_logits"


def test_normalize_value_set_order_insensitive():
    # Sets are unordered; normalize_value sorts elements for determinism.
    assert normalize_value({3, 1, 2}) == normalize_value({1, 2, 3})


def test_normalize_value_path_normalization():
    from pathlib import Path  # local import to avoid global dependency

    # Paths expand/resolve to absolute strings.
    # Stabilizes hashing across working dirs.
    assert normalize_value(Path(".")) == expected_path(".")


def test_normalize_value_uuid_and_to_json():
    # Objects may normalize via uuid() or to_json_string().
    class HasUUID:
        def uuid(self):
            return "test-uuid"

    class ToJson:
        def to_json_string(self):
            return '{"x":1}'

    assert normalize_value(HasUUID()) == "test-uuid"
    assert normalize_value(ToJson()) == '{"x":1}'


@pytest.mark.parametrize(
    "bad",
    [
        (lambda x: x),
        (type("CallableInstance", (), {"__call__": lambda self: 0}))(),
        (lambda: (lambda: 0))(),  # nested function instance
    ],
)
def test_error_cases(bad):
    """Inputs expected to raise TypeError."""
    # Reject functions/lambdas/callable instances
    # to avoid under-hashing.
    with pytest.raises(TypeError):
        normalize_value(bad)


def test_enum_vs_int_disambiguation():
    # int stays primitive
    nf_int = normalize_value(1)
    assert nf_int == 1

    # enum becomes ("module.QualName", value)
    nf_enum = normalize_value(DummyLogprobsMode.RAW_LOGITS)
    assert isinstance(nf_enum, tuple) and len(nf_enum) == 2
    enum_type, enum_val = nf_enum
    assert enum_type.endswith(".DummyLogprobsMode")
    assert enum_val == "raw_logits"

    # Build factor dicts from configs with int vs enum
    f_int = get_hash_factors(SimpleConfig(1), set())
    f_enum = get_hash_factors(SimpleConfig(DummyLogprobsMode.RAW_LOGITS), set())
    # The int case remains a primitive value
    assert f_int["a"] == 1
    # The enum case becomes a tagged tuple ("module.QualName", "raw_logits")
    assert isinstance(f_enum["a"], tuple) and f_enum["a"][1] == "raw_logits"
    # Factor dicts must differ so we don't collide primitives with Enums.
    assert f_int != f_enum
    # Hash digests must differ correspondingly
    assert hash_factors(f_int) != hash_factors(f_enum)

    # Hash functions produce stable hex strings
    h_int = hash_factors(f_int)
    h_enum = hash_factors(f_enum)
    assert isinstance(h_int, str) and len(h_int) == 64
    assert isinstance(h_enum, str) and len(h_enum) == 64


def test_classes_are_types():
    """Types normalize to FQNs; include real vLLM types."""
    # Only classes allowed; functions/lambdas are rejected.
    # Canonical form is the fully-qualified name.
    assert isinstance(normalize_value(str), str)

    class LocalDummy:
        pass

    assert endswith_fqname(LocalDummy, ".LocalDummy")
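These tests exercise only vllm.config.utils and deliberately avoid importing the heavy vLLM config classes, so the file can be run in isolation, for example with "pytest tests/config/test_config_utils.py -q".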
@@ -4,12 +4,14 @@
import ast
import dataclasses
import hashlib
import json
import operator
import os
import pprint
import time
from collections.abc import Callable, Sequence
from contextlib import contextmanager
from functools import partial
from typing import Any

import torch

@@ -23,7 +25,9 @@ from vllm.compilation.partition_rules import (
    should_split,
)
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.utils import hash_factors
from vllm.logger import init_logger
from vllm.logging_utils import lazy
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer

@@ -580,35 +584,47 @@ class VllmBackend:
    def __call__(
        self, graph: fx.GraphModule, example_inputs
    ) -> VllmSerializableFunction:
        from .caching import _compute_code_hash, compilation_config_hash_factors

        vllm_config = self.vllm_config
        # Minimal hashing here with existing utilities, reused below.

        env_factors = envs.compile_factors()
        env_hash = hash_factors(env_factors)
        # Compute config/compiler/code hashes once and reuse
        config_hash = vllm_config.compute_hash()
        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
        forward_code_files = list(sorted(self.compilation_config.traced_files))

        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            lazy(lambda: "\n".join(forward_code_files)),
        )
        hash_content = []
        for filepath in forward_code_files:
            hash_content.append(filepath)
            if filepath == "<string>":
                # This means the function was dynamically generated, with
                # e.g. exec(). We can't actually check these.
                continue
            try:
                with open(filepath) as f:
                    hash_content.append(f.read())
            except Exception:
                logger.warning("Failed to read file %s", filepath)
                continue
        code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest()
        # Clear after consumption
        self.compilation_config.traced_files.clear()
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affect the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.

            factors = compilation_config_hash_factors(vllm_config)
            # 2. factors come from the code files that are traced by Dynamo (
            # it mainly summarizes how the model is used in forward pass)
            code_hash = _compute_code_hash(self.compilation_config.traced_files)
            self.compilation_config.traced_files.clear()
            factors.append(code_hash)

            # 3. compiler hash
            compiler_hash = self.compiler_manager.compute_hash(vllm_config)
            factors.append(compiler_hash)

            # combine all factors to generate the cache dir
            hash_key = hashlib.md5(
                str(factors).encode(), usedforsecurity=False
            ).hexdigest()[:10]

            factors = [env_hash, config_hash, code_hash, compiler_hash]
            # Use SHA-256 for cache key hashing to be consistent across
            # compute_hash functions. Truncate for a short cache dir name.
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
            cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                hash_key,
                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key
            )
            self.compilation_config.cache_dir = cache_dir
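The cache directory name is derived from four SHA-256 digests (environment, config, traced code, compiler), combined and truncated. A self-contained sketch of that assembly, with placeholder byte strings standing in for the real factor inputs and a literal path standing in for envs.VLLM_CACHE_ROOT:

import hashlib
import os

# Illustrative stand-ins for the four factor hashes computed above.
env_hash = hashlib.sha256(b"env-factors").hexdigest()
config_hash = hashlib.sha256(b"vllm-config").hexdigest()
code_hash = hashlib.sha256(b"traced-source").hexdigest()
compiler_hash = hashlib.sha256(b"compiler").hexdigest()

# Combine the factors and truncate for a short, human-readable directory name.
factors = [env_hash, config_hash, code_hash, compiler_hash]
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]

cache_dir = os.path.join("~/.cache/vllm", "torch_compile_cache", hash_key)
print(cache_dir)  # -> ~/.cache/vllm/torch_compile_cache/<10-char key>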
@@ -621,6 +637,7 @@ class VllmBackend:
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
        disable_cache = not is_compile_cache_enabled(
            self.compilation_config.inductor_compile_config
        )

@@ -638,6 +655,50 @@ class VllmBackend:
            local_cache_dir, disable_cache, self.prefix
        )

        # Reuses existing cache key

        logger.debug(
            "torch.compile cache factors: env=%s cfg=%s comp=%s code=%s dir=%s",
            env_hash,
            config_hash,
            compiler_hash,
            code_hash,
            local_cache_dir,
        )

        # Persist and log only hash-relevant factors together.
        try:
            logger.debug(
                "Compile env factors (raw):\n%s\nVllm config hash: %s",
                lazy(partial(pprint.pformat, env_factors, width=120)),
                config_hash,
            )
            meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
            if not os.path.exists(meta_path):
                with open(meta_path, "w") as f:
                    json.dump(
                        {
                            "env": env_factors,  # raw factors used for env_hash
                            "config_hash": config_hash,
                            "code_hash": code_hash,
                            "compiler_hash": compiler_hash,
                        },
                        f,
                        indent=2,
                        sort_keys=True,
                    )
        except Exception:
            # Best-effort only; metadata write failures are non-fatal.
            logger.warning(
                (
                    "Could not write compile cache metadata at %s; continuing without "
                    "metadata. Compiled cache remains valid; diagnostics may be "
                    "limited."
                ),
                local_cache_dir,
                exc_info=True,
            )

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
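For diagnosing unexpected cache misses, the cache_key_factors.json written above can be inspected directly. A hedged sketch of doing so; the directory layout shown here is hypothetical and the real path is chosen by vLLM at runtime:

import json
import os

# Hypothetical local cache dir; substitute the directory vLLM actually logged.
local_cache_dir = os.path.expanduser(
    "~/.cache/vllm/torch_compile_cache/3f2a9c01de/rank_0_0"
)
meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")

# The file holds the raw env factors plus the config/code/compiler hashes,
# so two runs that land in different cache dirs can be diffed factor by factor.
with open(meta_path) as f:
    factors = json.load(f)
print(sorted(factors))  # ['code_hash', 'compiler_hash', 'config_hash', 'env']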
@@ -127,7 +127,7 @@ class PostGradPassManager(CustomGraphPass):
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.
        """
        state = {"pass_config": self.pass_config.uuid(), "passes": []}
        state = {"pass_config": self.pass_config.compute_hash(), "passes": []}
        for pass_ in self.passes:
            state["passes"].append(pass_.uuid())
        state["passes"].append(self.fix_functionalization.uuid())
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal

@@ -160,13 +159,29 @@ class CacheConfig:
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: list[Any] = []
        factors.append(self.cache_dtype)
        factors.append(self.mamba_cache_dtype)
        factors.append(self.mamba_ssm_cache_dtype)
        # `cpu_offload_gb` does not use `torch.compile` yet.
        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str
        ignored_factors = {
            # Runtime/derived knobs that don't affect compiled graph shape
            "gpu_memory_utilization",
            "swap_space",
            "is_attention_free",
            "num_gpu_blocks_override",
            "enable_prefix_caching",
            "prefix_caching_hash_algo",
            # `cpu_offload_gb` does not use `torch.compile` yet.
            "cpu_offload_gb",
            "cpu_kvcache_space_bytes",
            "mamba_page_size_padded",
            # Post-init/derived counters
            "num_gpu_blocks",
            "num_cpu_blocks",
            # WIP feature toggle not impacting compiled graph shape
            "kv_sharing_fast_prefill",
        }

        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, ignored_factors)
        return hash_factors(factors)

    def metrics_info(self):
        # convert cache_config to dict(key: str, value: str) for prometheus
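With the opt-out scheme, two CacheConfig instances that differ only in an ignored runtime knob should hash identically. A sketch of the expected behavior, assuming CacheConfig can be constructed directly with these keyword fields:

from vllm.config import CacheConfig

# Both configs describe the same compiled graph; they differ only in a
# runtime knob that is listed in `ignored_factors`.
a = CacheConfig(gpu_memory_utilization=0.90)
b = CacheConfig(gpu_memory_utilization=0.45)
assert a.compute_hash() == b.compute_hash()

# A field that does shape the graph (e.g. the KV cache dtype) changes the hash.
c = CacheConfig(cache_dtype="fp8")
assert c.compute_hash() != a.compute_hash()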
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import enum
import hashlib
from collections import Counter
from collections.abc import Callable
from dataclasses import asdict, field

@@ -160,7 +159,7 @@ class PassConfig:
            current_platform.get_device_capability().to_int(), {}
        )

    def uuid(self):
    def compute_hash(self) -> str:
        """
        Produces a hash unique to the pass configuration.
        Any new fields that affect compilation should be added to the hash.

@@ -506,28 +505,33 @@ class CompilationConfig:

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: list[Any] = []
        factors.append(self.mode)
        factors.append(self.backend)
        factors.append(self.custom_ops)
        factors.append(self.splitting_ops)
        factors.append(self.use_inductor)
        factors.append(self.use_inductor_graph_partition)
        factors.append(self.inductor_compile_config)
        factors.append(self.inductor_passes)
        factors.append(self.pass_config.uuid())
        factors.append(self.compile_cache_save_format)
        return hashlib.sha256(str(factors).encode()).hexdigest()
        # Opt-out: default-include declared fields; keep a tiny exclude set;
        # normalize types; keep SHA-256. For nested opaque configs, include a
        # stable identifier (e.g., pass_config.compute_hash()) instead of object id.

        ignored_factors = {
            # Paths/dirs and runtime/metrics that don't affect compiled graph
            "debug_dump_path",
            "cache_dir",
            "local_cache_dir",
            "bs_to_padded_graph_size",
            "traced_files",
            "compilation_time",
            "static_forward_context",
            "pass_config",  # handled separately below
        }

        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, ignored_factors)
        factors["pass_config"] = self.pass_config.compute_hash()
        return hash_factors(factors)

    def __repr__(self) -> str:
        exclude = {
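The nested pass_config is excluded from the generic field walk and folded back in as its own digest, which is the suggested pattern for any nested or opaque config. A minimal sketch using illustrative dataclasses rather than the real vLLM classes:

from dataclasses import dataclass, field

from vllm.config.utils import get_hash_factors, hash_factors


@dataclass
class InnerConfig:
    enable_fusion: bool = True

    def compute_hash(self) -> str:
        return hash_factors(get_hash_factors(self, set()))


@dataclass
class OuterConfig:
    backend: str = "inductor"
    inner: InnerConfig = field(default_factory=InnerConfig)

    def compute_hash(self) -> str:
        # Exclude the nested config from the generic field walk...
        factors = get_hash_factors(self, {"inner"})
        # ...and fold it back in as a stable, separately computed digest.
        factors["inner"] = self.inner.compute_hash()
        return hash_factors(factors)


print(OuterConfig().compute_hash())  # 64-char SHA-256 hex digest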
@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
import json
import warnings
from collections.abc import Callable
from dataclasses import InitVar, field

@@ -18,7 +16,7 @@ import vllm.envs as envs
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
from vllm.config.utils import assert_hashable, config, getattr_iter
from vllm.config.utils import config, getattr_iter
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (

@@ -324,50 +322,50 @@ class ModelConfig:
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: list[Any] = []
        factors.append(self.model)
        factors.append(self.dtype)
        factors.append(self.quantization)
        factors.append(self.revision)
        factors.append(self.code_revision)
        factors.append(self.max_model_len)
        factors.append(self.max_logprobs)
        factors.append(self.disable_sliding_window)
        factors.append(self.trust_remote_code)
        factors.append(self.generation_config)
        factors.append(self.model_impl)
        factors.append(self.override_generation_config)
        factors.append(self.video_pruning_rate)
        factors.append(self.enable_prompt_embeds)
        ignored_factors = {
            "runner",
            "convert",
            "task",
            "tokenizer",
            "tokenizer_mode",
            "seed",
            "hf_config_path",
            "allowed_local_media_path",
            "allowed_media_domains",
            "tokenizer_revision",
            "spec_target_max_model_len",
            "enforce_eager",
            "logprobs_mode",
            "disable_cascade_attn",
            "skip_tokenizer_init",
            "enable_prompt_embeds",
            "served_model_name",
            "config_format",
            "hf_token",
            "hf_overrides",
            "logits_processor_pattern",
            "enable_sleep_mode",
            "override_attention_dtype",
            "logits_processors",
            "io_processor_plugin",
            "pooler_config",
            "override_pooler_config",
            "multimodal_config",
            "limit_mm_per_prompt",
            "media_io_kwargs",
            "mm_processor_kwargs",
            "mm_processor_cache_gb",
            "mm_processor_cache_type",
            "mm_shm_cache_max_object_size_mb",
            "mm_encoder_tp_mode",
            "interleave_mm_strings",
            "skip_mm_profiling",
        }

        # hf_config can control how the model looks!
        try:
            hf_config_json = self.hf_config.to_json_string(use_diff=False)
        except TypeError:
            from transformers import PretrainedConfig
        from vllm.config.utils import get_hash_factors, hash_factors

            from vllm.utils.jsontree import json_map_leaves

            # Handle nested HF configs with unserializable values gracefully
            hf_config_json = (
                json.dumps(
                    json_map_leaves(
                        lambda v: v.to_dict()
                        if isinstance(v, PretrainedConfig)
                        else str(v),
                        self.hf_config.to_dict(),
                    ),
                    indent=2,
                    sort_keys=True,
                )
                + "\n"
            )

        factors.append(hf_config_json)

        str_factors = str(factors)
        assert_hashable(str_factors)
        return hashlib.sha256(str(factors).encode()).hexdigest()
        factors = get_hash_factors(self, ignored_factors)
        return hash_factors(factors)

    def _update_nested(
        self,
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
import os
from typing import TYPE_CHECKING, Any, Literal

@@ -448,19 +447,41 @@ class ParallelConfig:
        This hash is also used for DP worker configuration validation
        to prevent hangs from mismatched collective communication patterns.
        """
        factors: list[Any] = []
        factors.append(self.pipeline_parallel_size)
        factors.append(self.tensor_parallel_size)
        factors.append(self.enable_expert_parallel)
        factors.append(self.data_parallel_size)
        factors.append(self.all2all_backend)
        factors.append(self.enable_eplb)
        if self.enable_eplb:
            factors.append(self.eplb_config.log_balancedness)
            factors.append(self.eplb_config.window_size)
            factors.append(self.eplb_config.step_interval)
            factors.append(self.eplb_config.num_redundant_experts)
        return hashlib.sha256(str(factors).encode()).hexdigest()
        ignored_factors = {
            # Derived/runtime topology, networking, or launch details
            "data_parallel_rank",
            "data_parallel_rank_local",
            "data_parallel_backend",
            "data_parallel_external_lb",
            "data_parallel_hybrid_lb",
            "data_parallel_master_ip",
            "data_parallel_master_port",
            "_data_parallel_master_port_list",
            "data_parallel_rpc_port",
            "rank",
            "master_addr",
            "master_port",
            "node_rank",
            "nnodes",
            "max_parallel_loading_workers",
            "disable_custom_all_reduce",
            "ray_workers_use_nsight",
            "ray_runtime_env",
            "placement_group",
            "distributed_executor_backend",
            "worker_cls",
            "sd_worker_cls",
            "worker_extension_cls",
            "_api_process_count",
            "_api_process_rank",
        }

        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, ignored_factors)
        # Explicitly include backend affecting env factor as before
        factors["VLLM_ALL2ALL_BACKEND"] = str(envs.VLLM_ALL2ALL_BACKEND)
        return hash_factors(factors)

    def __post_init__(self) -> None:
        # Set all2all_backend from env var if not specified, with deprecation warning
@@ -3,14 +3,19 @@
"""Utility functions for vLLM config dataclasses."""

import ast
import enum
import hashlib
import inspect
import json
import pathlib
import textwrap
from collections.abc import Iterable
from collections.abc import Iterable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar

import regex as re
import torch
from pydantic.fields import FieldInfo
from typing_extensions import runtime_checkable

@@ -176,3 +181,115 @@ def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT:
        )
        processed_overrides[field_name] = value
    return replace(config, **processed_overrides)


def normalize_value(x):
    """Return a stable, JSON-serializable canonical form for hashing.
    Order: primitives, special types (Enum, callable, torch.dtype, Path), then
    generic containers (Mapping/Set/Sequence) with recursion.
    """
    # Fast path
    if x is None or isinstance(x, (bool, int, float, str)):
        return x

    # Enums: tag with FQN to avoid primitive collisions.
    # Ex: Enum(1) vs int(1) -> ("module.QualName", value).
    if isinstance(x, enum.Enum):
        enum_type = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
        return (enum_type, normalize_value(x.value))

    # Classes (types) are accepted and canonicalized by their fully-qualified
    # name (module.qualname) for a stable identifier.
    # Instances are only accepted if they expose uuid(); otherwise they are
    # rejected to avoid under-hashing object state.

    # Callables: accept classes only; reject funcs/lambdas/methods.
    # Used by LogitsProcessor types and ModelConfig.hf_overrides.
    if isinstance(x, type):
        module = getattr(x, "__module__", "")
        qual = getattr(x, "__qualname__", getattr(x, "__name__", ""))
        return ".".join([p for p in (module, qual) if p]) or repr(x)

    # Prefer stable uuid identifiers for objects that provide them, even if
    # they are callable instances (e.g., InductorPass wrappers).
    if hasattr(x, "uuid") and callable(getattr(x, "uuid", None)):
        return x.uuid()

    if callable(x):
        raise TypeError("normalize_value: function or callable instance unsupported")

    # Torch dtype: stringify (torch.float64 -> "torch.float64").
    # We rely on the string form here; dtype-bearing fields that need additional
    # disambiguation should encode that at the config layer.
    if isinstance(x, torch.dtype):
        return str(x)

    # Bytes
    if isinstance(x, (bytes, bytearray)):
        return x.hex()

    # Paths (canonicalize)
    if isinstance(x, pathlib.Path):
        try:
            return str(x.expanduser().resolve())
        except Exception:
            return str(x)

    # Dataclasses: represent as (FQN, sorted(field,value) tuple) for stability.
    if is_dataclass(x):
        type_fqn = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
        items = tuple(
            (f.name, normalize_value(getattr(x, f.name)))
            for f in sorted(fields(x), key=lambda f: f.name)
        )
        return (type_fqn, items)

    # Containers (generic)
    if isinstance(x, Mapping):
        return tuple(sorted((str(k), normalize_value(v)) for k, v in x.items()))
    if isinstance(x, Set):
        return tuple(sorted(repr(normalize_value(v)) for v in x))
    if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)):
        return tuple(normalize_value(v) for v in x)

    # PretrainedConfig
    if hasattr(x, "to_json_string") and callable(x.to_json_string):
        return x.to_json_string()

    # Unsupported type: e.g., modules, generators, open files, or objects
    # without a stable JSON/UUID representation. Hard-error to avoid
    # under-hashing.
    # If you hit this, either reshape your config to use supported primitives
    # and containers, or extend normalize_value to provide a stable encoding
    # (e.g., via uuid() or to_json_string()) for this type.
    raise TypeError(
        f"normalize_value: unsupported type '{type(x).__name__}'. "
        "Ensure config values use supported primitives/containers or add a "
        "stable representation for this type."
    )


def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, object]:
    """Gets the factors used for hashing a config class.
    - Includes all dataclass fields not in `ignored_factors`.
    - Errors on non-normalizable values.
    """
    factors: dict[str, object] = {}
    for dc_field in fields(config):
        factor = dc_field.name
        if factor in ignored_factors:
            continue
        value = getattr(config, factor, None)
        try:
            factors[factor] = normalize_value(value)
        except TypeError as e:
            raise TypeError(
                f"get_hash_factors: unsupported type for key '{factor}' "
                f"({type(value).__name__})"
            ) from e
    return factors


def hash_factors(items: dict[str, object]) -> str:
    """Return a SHA-256 hex digest of the canonical items structure."""
    return hashlib.sha256(json.dumps(items, sort_keys=True).encode()).hexdigest()
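As the error message suggests, an otherwise unsupported object becomes hashable once it exposes a stable identifier such as uuid(). A small sketch of that escape hatch; CustomKernelChoice is a made-up class, not part of vLLM:

from vllm.config.utils import hash_factors, normalize_value


class CustomKernelChoice:
    """Hypothetical opaque config value with a stable identity."""

    def __init__(self, name: str, version: int) -> None:
        self.name = name
        self.version = version

    def uuid(self) -> str:
        # Any deterministic string works; normalize_value() will pick it up.
        return f"{self.name}:{self.version}"


choice = CustomKernelChoice("triton_mla", 2)
assert normalize_value(choice) == "triton_mla:2"
print(hash_factors({"kernel": normalize_value(choice)}))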
vllm/envs.py | 167
@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
import hashlib
import json
import logging
import os
import sys
import tempfile

@@ -426,6 +426,8 @@ def get_vllm_port() -> int | None:

# --8<-- [start:env-vars-definition]

logger = logging.getLogger(__name__)

environment_variables: dict[str, Callable[[], Any]] = {
    # ================== Installation Time Env Vars ==================
    # Target device of vLLM, supporting [cuda (by default),

@@ -1540,85 +1542,88 @@ def is_set(name: str):
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def compute_hash() -> str:
    """
    WARNING: Whenever a new key is added to this environment
    variables, ensure that it is included in the factors list if
    it affects the computation graph. For example, different values
    of VLLM_PP_LAYER_PARTITION will generate different computation
    graphs, so it is included in the factors list. The env vars that
    affect the choice of different kernels or attention backends should
    also be included in the factors list.
    """
def compile_factors() -> dict[str, object]:
    """Return env vars used for torch.compile cache keys.

    # The values of envs may affects the computation graph.
    # TODO(DefTruth): hash all environment variables?
    # for key in environment_variables:
    #     factorize(key)
    environment_variables_to_hash = [
        "VLLM_PP_LAYER_PARTITION",
        "VLLM_MLA_DISABLE",
        "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
        "VLLM_USE_TRITON_AWQ",
        "VLLM_DP_RANK",
        "VLLM_DP_SIZE",
        "VLLM_USE_STANDALONE_COMPILE",
        "VLLM_FUSED_MOE_CHUNK_SIZE",
        "VLLM_FLASHINFER_MOE_BACKEND",
        "VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
        "VLLM_ATTENTION_BACKEND",
        "VLLM_USE_FLASHINFER_SAMPLER",
        "VLLM_DISABLED_KERNELS",
        "VLLM_USE_DEEP_GEMM",
        "VLLM_MOE_USE_DEEP_GEMM",
        "VLLM_USE_DEEP_GEMM_E8M0",
        "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
        "VLLM_USE_FLASHINFER_MOE_FP16",
        "VLLM_USE_FLASHINFER_MOE_FP8",
        "VLLM_USE_FLASHINFER_MOE_FP4",
        "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
        "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
        "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
        "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE",
        "VLLM_USE_CUDNN_PREFILL",
        "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
        "VLLM_USE_TRTLLM_ATTENTION",
        "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
        "VLLM_ROCM_USE_AITER",
        "VLLM_ROCM_USE_AITER_PAGED_ATTN",
        "VLLM_ROCM_USE_AITER_LINEAR",
        "VLLM_ROCM_USE_AITER_MOE",
        "VLLM_ROCM_USE_AITER_RMSNORM",
        "VLLM_ROCM_USE_AITER_MLA",
        "VLLM_ROCM_USE_AITER_MHA",
        "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
        "VLLM_ROCM_USE_AITER_TRITON_ROPE",
        "VLLM_ROCM_USE_AITER_FP8BMM",
        "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
        "VLLM_ROCM_USE_AITER_TRITON_GEMM",
        "VLLM_ROCM_USE_SKINNY_GEMM",
        "VLLM_ROCM_FP8_PADDING",
        "VLLM_ROCM_MOE_PADDING",
        "VLLM_ROCM_CUSTOM_PAGED_ATTN",
        "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION",
        "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
        "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
        "VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
        "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
        "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
        "VLLM_NVFP4_GEMM_BACKEND",
        "VLLM_USE_FBGEMM",
        "VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE",
        "VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL",
    ]
    for key in environment_variables_to_hash:
        # if this goes out of sync with environment_variables,
        # it's not a user error, it's a bug
        assert key in environment_variables, (
            "Please update environment_variables_to_hash in envs.py"
        )
    Start with every known vLLM env var; drop entries in `ignored_factors`;
    hash everything else. This keeps the cache key aligned across workers."""

    factors = [environment_variables[key]() for key in environment_variables_to_hash]
    ignored_factors: set[str] = {
        "MAX_JOBS",
        "VLLM_RPC_BASE_PATH",
        "VLLM_USE_MODELSCOPE",
        "VLLM_RINGBUFFER_WARNING_INTERVAL",
        "VLLM_DEBUG_DUMP_PATH",
        "VLLM_PORT",
        "VLLM_CACHE_ROOT",
        "LD_LIBRARY_PATH",
        "VLLM_SERVER_DEV_MODE",
        "VLLM_DP_MASTER_IP",
        "VLLM_DP_MASTER_PORT",
        "VLLM_RANDOMIZE_DP_DUMMY_INPUTS",
        "VLLM_CI_USE_S3",
        "VLLM_MODEL_REDIRECT_PATH",
        "VLLM_HOST_IP",
        "S3_ACCESS_KEY_ID",
        "S3_SECRET_ACCESS_KEY",
        "S3_ENDPOINT_URL",
        "VLLM_USAGE_STATS_SERVER",
        "VLLM_NO_USAGE_STATS",
        "VLLM_DO_NOT_TRACK",
        "VLLM_LOGGING_LEVEL",
        "VLLM_LOGGING_PREFIX",
        "VLLM_LOGGING_STREAM",
        "VLLM_LOGGING_CONFIG_PATH",
        "VLLM_LOG_STATS_INTERVAL",
        "VLLM_DEBUG_LOG_API_SERVER_RESPONSE",
        "VLLM_TUNED_CONFIG_FOLDER",
        "VLLM_ENGINE_ITERATION_TIMEOUT_S",
        "VLLM_HTTP_TIMEOUT_KEEP_ALIVE",
        "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS",
        "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH",
        "VLLM_SLEEP_WHEN_IDLE",
        "VLLM_IMAGE_FETCH_TIMEOUT",
        "VLLM_VIDEO_FETCH_TIMEOUT",
        "VLLM_AUDIO_FETCH_TIMEOUT",
        "VLLM_MEDIA_URL_ALLOW_REDIRECTS",
        "VLLM_MEDIA_LOADING_THREAD_COUNT",
        "VLLM_MAX_AUDIO_CLIP_FILESIZE_MB",
        "VLLM_VIDEO_LOADER_BACKEND",
        "VLLM_MEDIA_CONNECTOR",
        "VLLM_ASSETS_CACHE",
        "VLLM_ASSETS_CACHE_MODEL_CLEAN",
        "VLLM_MM_INPUT_CACHE_GIB",
        "VLLM_WORKER_MULTIPROC_METHOD",
        "VLLM_ENABLE_V1_MULTIPROCESSING",
        "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
        "VLLM_CPU_KVCACHE_SPACE",
        "VLLM_CPU_OMP_THREADS_BIND",
        "VLLM_CPU_NUM_OF_RESERVED_CPU",
        "VLLM_CPU_MOE_PREPACK",
        "VLLM_CPU_SGL_KERNEL",
        "VLLM_TEST_FORCE_LOAD_FORMAT",
        "LOCAL_RANK",
        "CUDA_VISIBLE_DEVICES",
    }

    from vllm.config.utils import normalize_value

    factors: dict[str, object] = {}
    for factor, getter in environment_variables.items():
        if factor in ignored_factors:
            continue

        try:
            raw = getter()
        except Exception as exc:  # pragma: no cover - defensive logging
            logger.warning(
                "Skipping environment variable %s while hashing compile factors: %s",
                factor,
                exc,
            )
            continue

        factors[factor] = normalize_value(raw)

    ray_noset_env_vars = [
        # Refer to

@@ -1641,8 +1646,8 @@ def compute_hash() -> str:
        "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR",
        "RAY_EXPERIMENTAL_NOSET_RBLN_RT_VISIBLE_DEVICES",
    ]
    factors.extend([os.getenv(var) for var in ray_noset_env_vars])

    hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
    for var in ray_noset_env_vars:
        factors[var] = normalize_value(os.getenv(var))

    return hash_str
    return factors
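Combined with hash_factors, this yields one dictionary that can be hashed or diffed directly. A short usage sketch (values differ per machine, so the output is not reproducible):

import vllm.envs as envs
from vllm.config.utils import hash_factors

factors = envs.compile_factors()   # {env var name -> normalized value}
env_hash = hash_factors(factors)   # 64-char SHA-256 hex digest
print(len(factors), env_hash[:10])

# Because the result is a plain dict, two workers (or two runs) can diff their
# factor dictionaries key by key when their cache keys unexpectedly disagree.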
@@ -2,9 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.logging_utils.formatter import NewLineFormatter
from vllm.logging_utils.lazy import lazy
from vllm.logging_utils.log_time import logtime

__all__ = [
    "NewLineFormatter",
    "lazy",
    "logtime",
]
vllm/logging_utils/lazy.py | 20 (new file)
@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from typing import Any


class lazy:
    """Wrap a zero-argument callable evaluated only during log formatting."""

    __slots__ = ("_factory",)

    def __init__(self, factory: Callable[[], Any]) -> None:
        self._factory = factory

    def __str__(self) -> str:
        return str(self._factory())

    def __repr__(self) -> str:
        return str(self)
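The wrapper defers building an expensive log argument until the record is actually formatted, so suppressed debug messages cost almost nothing. A small usage sketch:

import logging

from vllm.logging_utils import lazy

logger = logging.getLogger("example")
logging.basicConfig(level=logging.INFO)

traced_files = [f"/tmp/file_{i}.py" for i in range(3)]

# The join below runs only if the DEBUG record is actually emitted;
# at INFO level the lambda is never called.
logger.debug("Traced files:\n%s", lazy(lambda: "\n".join(traced_files)))
logger.info("Seen %s files", lazy(lambda: len(traced_files)))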