mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 22:35:46 +08:00
[BUG] Make 'binary' default option for saving torch compile artifacts when using standalone_compile (#27616)
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
This commit is contained in:
parent
f7d2946e99
commit
cac4c10ef0
@ -27,6 +27,8 @@ With all these factors taken into consideration, usually we can guarantee that t
|
|||||||
|
|
||||||
A unique aspect of vLLM's `torch.compile` integration, is that we guarantee all the compilation finishes before we serve any requests. No requests will trigger new compilations. Otherwise, the engine would be blocked on that request, and the response time will have unexpected spikes.
|
A unique aspect of vLLM's `torch.compile` integration, is that we guarantee all the compilation finishes before we serve any requests. No requests will trigger new compilations. Otherwise, the engine would be blocked on that request, and the response time will have unexpected spikes.
|
||||||
|
|
||||||
|
By default, the cache saves compiled artifacts as binary files. If you would like to interact with the generated code for debugging purposes, set the field `compile_cache_save_format=unpacked` in the compilation config, or omit this and set the env variable `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked`.
|
||||||
|
|
||||||
## Python Code Compilation
|
## Python Code Compilation
|
||||||
|
|
||||||
In the very verbose logs, we can see:
|
In the very verbose logs, we can see:
|
||||||
|
|||||||
@ -51,7 +51,9 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
|
|||||||
and hasattr(torch._inductor, "standalone_compile")
|
and hasattr(torch._inductor, "standalone_compile")
|
||||||
):
|
):
|
||||||
logger.debug("Using InductorStandaloneAdaptor")
|
logger.debug("Using InductorStandaloneAdaptor")
|
||||||
return InductorStandaloneAdaptor()
|
return InductorStandaloneAdaptor(
|
||||||
|
compilation_config.compile_cache_save_format
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug("Using InductorAdaptor")
|
logger.debug("Using InductorAdaptor")
|
||||||
return InductorAdaptor()
|
return InductorAdaptor()
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import hashlib
|
|||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -175,6 +175,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
|||||||
|
|
||||||
name = "inductor_standalone"
|
name = "inductor_standalone"
|
||||||
|
|
||||||
|
def __init__(self, save_format: Literal["binary", "unpacked"]):
|
||||||
|
self.save_format = save_format
|
||||||
|
|
||||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||||
factors = get_inductor_factors()
|
factors = get_inductor_factors()
|
||||||
hash_str = hashlib.md5(
|
hash_str = hashlib.md5(
|
||||||
@ -220,7 +223,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
|||||||
assert key is not None
|
assert key is not None
|
||||||
path = os.path.join(self.cache_dir, key)
|
path = os.path.join(self.cache_dir, key)
|
||||||
if not envs.VLLM_DISABLE_COMPILE_CACHE:
|
if not envs.VLLM_DISABLE_COMPILE_CACHE:
|
||||||
compiled_graph.save(path=path, format="unpacked")
|
compiled_graph.save(path=path, format=self.save_format)
|
||||||
compilation_counter.num_compiled_artifacts_saved += 1
|
compilation_counter.num_compiled_artifacts_saved += 1
|
||||||
return compiled_graph, (key, path)
|
return compiled_graph, (key, path)
|
||||||
|
|
||||||
@ -237,7 +240,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
|||||||
assert isinstance(handle[1], str)
|
assert isinstance(handle[1], str)
|
||||||
path = handle[1]
|
path = handle[1]
|
||||||
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
|
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
|
||||||
path=path, format="unpacked"
|
path=path, format=self.save_format
|
||||||
)
|
)
|
||||||
from torch._inductor.compile_fx import graph_returns_tuple
|
from torch._inductor.compile_fx import graph_returns_tuple
|
||||||
|
|
||||||
|
|||||||
@ -7,11 +7,12 @@ from collections import Counter
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import asdict, field
|
from dataclasses import asdict, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, ClassVar
|
from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
||||||
|
|
||||||
from pydantic import TypeAdapter, field_validator
|
from pydantic import TypeAdapter, field_validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -208,6 +209,15 @@ class CompilationConfig:
|
|||||||
"""The directory to store the compiled graph, to accelerate Inductor
|
"""The directory to store the compiled graph, to accelerate Inductor
|
||||||
compilation. By default, it will use model-related information to generate
|
compilation. By default, it will use model-related information to generate
|
||||||
a cache directory."""
|
a cache directory."""
|
||||||
|
compile_cache_save_format: Literal["binary", "unpacked"] = field(
|
||||||
|
default_factory=lambda: envs.VLLM_COMPILE_CACHE_SAVE_FORMAT
|
||||||
|
)
|
||||||
|
"""Format for saving torch compile cache:\n
|
||||||
|
- "binary": saves as binary file (multiprocess safe)\n
|
||||||
|
- "unpacked": saves as directory structure for inspection/debugging
|
||||||
|
(NOT multiprocess safe)\n
|
||||||
|
Defaults to `VLLM_COMPILE_CACHE_SAVE_FORMAT` if not specified.
|
||||||
|
"""
|
||||||
backend: str = ""
|
backend: str = ""
|
||||||
"""The backend for compilation. It needs to be a string:
|
"""The backend for compilation. It needs to be a string:
|
||||||
|
|
||||||
@ -479,6 +489,7 @@ class CompilationConfig:
|
|||||||
factors.append(self.inductor_compile_config)
|
factors.append(self.inductor_compile_config)
|
||||||
factors.append(self.inductor_passes)
|
factors.append(self.inductor_passes)
|
||||||
factors.append(self.pass_config.uuid())
|
factors.append(self.pass_config.uuid())
|
||||||
|
factors.append(self.compile_cache_save_format)
|
||||||
return hashlib.sha256(str(factors).encode()).hexdigest()
|
return hashlib.sha256(str(factors).encode()).hexdigest()
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@ -520,6 +531,16 @@ class CompilationConfig:
|
|||||||
return CUDAGraphMode[value.upper()]
|
return CUDAGraphMode[value.upper()]
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@field_validator("compile_cache_save_format")
|
||||||
|
@classmethod
|
||||||
|
def validate_compile_cache_save_format(cls, value: str) -> str:
|
||||||
|
if value not in ("binary", "unpacked"):
|
||||||
|
raise ValueError(
|
||||||
|
f"compile_cache_save_format must be 'binary' or 'unpacked', "
|
||||||
|
f"got: {value}"
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.level is not None:
|
if self.level is not None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
10
vllm/envs.py
10
vllm/envs.py
@ -218,6 +218,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_USE_FBGEMM: bool = False
|
VLLM_USE_FBGEMM: bool = False
|
||||||
VLLM_GC_DEBUG: str = ""
|
VLLM_GC_DEBUG: str = ""
|
||||||
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
|
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
|
||||||
|
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
|
||||||
|
|
||||||
|
|
||||||
def get_default_cache_root():
|
def get_default_cache_root():
|
||||||
@ -1442,6 +1443,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: os.getenv(
|
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: os.getenv(
|
||||||
"VLLM_DISABLE_SHARED_EXPERTS_STREAM", False
|
"VLLM_DISABLE_SHARED_EXPERTS_STREAM", False
|
||||||
),
|
),
|
||||||
|
# Format for saving torch.compile cache artifacts
|
||||||
|
# - "binary": saves as binary file
|
||||||
|
# Safe for multiple vllm serve processes accessing the same torch compile cache.
|
||||||
|
# - "unpacked": saves as directory structure (for inspection/debugging)
|
||||||
|
# NOT multiprocess safe - race conditions may occur with multiple processes.
|
||||||
|
# Allows viewing and setting breakpoints in Inductor's code output files.
|
||||||
|
"VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
|
||||||
|
"VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
# --8<-- [end:env-vars-definition]
|
# --8<-- [end:env-vars-definition]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user