[torch.compile] fast inductor (#11108)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
parent c301616ed2
commit 88a412ed3d
@@ -1,6 +1,10 @@
import ast
import copy
import dataclasses
import os
import pprint
import time
from collections import defaultdict
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from unittest.mock import patch
@@ -21,6 +25,122 @@ from .pass_manager import PostGradPassManager

logger = init_logger(__name__)


class InductorHashCache:
    """
    Disk format: a Python list of tuples, where each tuple is
    (runtime_shape, graph_index, hash_str).
    We use a list of tuples for readability.

    In-memory format: a defaultdict of dicts, where the key is
    runtime_shape and the value is a dict mapping graph_index to hash_str.

    The data is essentially `Dict[Optional[int], Dict[int, str]]`.
    We don't use JSON here because JSON doesn't support int keys.

    TODO: is there a better off-the-shelf solution for serializing the data?
    """

    def __init__(self, cache_dir: str, disabled: bool = False):
        self.cache: defaultdict = defaultdict(dict)
        self.disabled = disabled
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir,
                                            "inductor_hash_cache.py")
        if disabled:
            return
        # set flags so that Inductor and Triton store their caches
        # in cache_dir; users then only need to copy cache_dir
        # to another machine to reuse the cache.
        inductor_cache = os.path.join(cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache
        if os.path.exists(self.cache_file_path):
            with open(self.cache_file_path) as f:
                self.deserialize(f.read())

    def deserialize(self, data: str):
        # we use ast.literal_eval to parse the data
        # because it is a safe way to parse Python literals.
        # do not use eval(), it is unsafe.
        list_data = ast.literal_eval(data)
        for runtime_shape, graph_index, hash_str in list_data:
            self.cache[runtime_shape][graph_index] = hash_str

    def serialize(self) -> str:
        data = []
        for runtime_shape, graph_index_to_hash_str in self.cache.items():
            for graph_index, hash_str in graph_index_to_hash_str.items():
                data.append((runtime_shape, graph_index, hash_str))
        printer = pprint.PrettyPrinter(indent=4)
        return printer.pformat(data)

    def save_to_file(self):
        if self.disabled:
            return
        with open(self.cache_file_path, "w") as f:
            f.write(self.serialize())

    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
        if self.disabled:
            return False
        runtime_shape, graph_index = key
        return runtime_shape in self.cache and graph_index in self.cache[
            runtime_shape]

    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
        if self.disabled:
            raise KeyError("cannot read from a disabled cache")
        runtime_shape, graph_index = key
        return self.cache[runtime_shape][graph_index]

    def __setitem__(self, key: Tuple[Optional[int], int], value: str):
        # setitem on a disabled cache is fine, because we
        # don't actually write to disk
        runtime_shape, graph_index = key
        self.cache[runtime_shape][graph_index] = value
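
# A standalone sketch (hypothetical hash strings, not real Inductor hashes)
# of the round-trip behind the disk format described above: pprint writes a
# Python literal, and ast.literal_eval parses it back safely, unlike eval().
_example_cache = {None: {0: "hash_general_0"}, 8: {0: "hash_bs8_0"}}
_example_rows = [(shape, idx, h) for shape, graphs in _example_cache.items()
                 for idx, h in graphs.items()]
assert ast.literal_eval(pprint.pformat(_example_rows)) == _example_rows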


class AlwaysHitShapeEnv:
    """
    Why do we need this class:

    For normal `torch.compile` usage, every compilation has
    one Dynamo bytecode compilation and one Inductor compilation.
    The Inductor compilation happens under the context of the
    Dynamo bytecode compilation, and that context is used to
    determine the dynamic shape information, etc.

    For our use case, we only run Dynamo bytecode compilation once,
    and run Inductor compilation multiple times with different shapes
    plus a general shape. The compilation for specific shapes happens
    outside of the context of the Dynamo bytecode compilation. At that
    time, we don't have a shape environment to provide to Inductor, so
    the Inductor code cache lookup would fail.

    By providing a dummy shape environment that always hits, we make
    the Inductor code cache lookup always hit, and we can
    compile the graph for different shapes as needed.

    The following dummy methods were obtained by trial and error
    until everything worked.
    """

    def __init__(self) -> None:
        self.guards: List[Any] = []

    def evaluate_guards_expression(self, *args, **kwargs):
        return True

    def get_pruned_guards(self, *args, **kwargs):
        return []

    def produce_guards_expression(self, *args, **kwargs):
        return ""


def wrap_inductor(graph,
                  example_inputs,
                  additional_inductor_config,

@@ -55,9 +175,93 @@ def wrap_inductor(graph,
    # inductor can in-place modify the graph, so we need to copy it
    # see https://github.com/pytorch/pytorch/issues/138980
    graph = copy.deepcopy(graph)
-    compiled_graph = compile_fx(graph,
-                                example_inputs,
-                                config_patches=current_config)

    cache_data = compilation_config.inductor_hash_cache
    if (runtime_shape, graph_index) in cache_data:
        # we compiled this graph before,
        # so we can directly look up the compiled graph via its hash
        hash_str = cache_data[(runtime_shape, graph_index)]
        if graph_index == 0:
            # add some info logging for the first graph
            logger.info(
                "Directly look up the graph for shape %s in the cache",
                str(runtime_shape))  # noqa
        logger.debug(
            "Directly look up the %s-th graph for shape %s via hash %s",
            graph_index, str(runtime_shape), hash_str)
        from torch._inductor.codecache import FxGraphCache
        with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                   lambda *args, **kwargs: AlwaysHitShapeEnv()):
            inductor_compiled_graph = FxGraphCache._lookup_graph(
                hash_str, example_inputs, True, False)
            assert inductor_compiled_graph is not None, (
                "Inductor cache lookup failed. Please remove "
                f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again."  # noqa
            )

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple
        returns_tuple = graph_returns_tuple(graph)

        # this is the graph we return to Dynamo to run
        def compiled_graph(*args):
            # convert args to a list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]
    else:
        # it's the first time we compile this graph.
        # the assumption is that we don't have nested Inductor compilation:
        # compiled_fx_graph_hash will only be called once, and we can hook
        # into it to get the hash of the compiled graph directly.
        from torch._inductor.codecache import compiled_fx_graph_hash

        def hijack_compiled_fx_graph_hash(*args, **kwargs):
            out = compiled_fx_graph_hash(*args, **kwargs)
            # store the hash in the cache
            nonlocal cache_data
            cache_data[(runtime_shape, graph_index)] = out[0]
            if graph_index == 0:
                # add some info logging for the first graph
                logger.info("Cache the graph of shape %s for later use",
                            str(runtime_shape))
            logger.debug("Store the %s-th graph for shape %s via hash %s",
                         graph_index, str(runtime_shape), out[0])
            return out

        def _check_can_cache(*args, **kwargs):
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of the Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221  # noqa
            return

        def _get_shape_env():
            return AlwaysHitShapeEnv()

        with patch(  # for hijacking the hash of the compiled graph
                "torch._inductor.codecache.compiled_fx_graph_hash",
                hijack_compiled_fx_graph_hash), \
             patch(  # for providing a dummy shape environment
                "torch._inductor.codecache.FxGraphCache._get_shape_env",
                _get_shape_env), \
             patch(  # for forcing the graph to be cached
                "torch._inductor.codecache.FxGraphCache._check_can_cache",
                _check_can_cache):
            compiled_graph = compile_fx(graph,
                                        example_inputs,
                                        config_patches=current_config)

    # after compiling the last graph, record the end time
    if graph_index == num_graphs - 1:
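
# A standalone sketch of the calling-convention shim above: an Inductor-
# compiled callable takes a single list and returns a tuple, while Dynamo
# passes positional args and may expect a bare value back. The two helpers
# here are illustrative placeholders, not vLLM or PyTorch APIs.
def _inductor_style(list_args):  # f(list) -> tuple
    return (sum(list_args), )

def _dynamo_style(*args):  # f(*args) -> Any
    graph_output = _inductor_style(list(args))
    return graph_output[0]  # unpack when the graph does not return a tuple

assert _dynamo_style(1, 2, 3) == 6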

@@ -457,6 +661,9 @@ class PiecewiseBackend:

        # finished compilations for all the required shapes
        if self.is_last_graph and not self.to_be_compiled_sizes:

            # save the hash of the inductor graph for the next run
            self.compilation_config.inductor_hash_cache.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

        if not entry.use_cudagraph:
vllm/config.py (415 changes)
@@ -3,6 +3,7 @@ import copy
import enum
import hashlib
import json
import os
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
@@ -162,6 +163,30 @@ class ModelConfig:
            which allows no processors.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: List[Any] = []
        factors.append(self.model)
        factors.append(self.dtype)
        factors.append(self.quantization)
        factors.append(self.quantization_param_path)
        factors.append(self.revision)
        factors.append(self.code_revision)
        factors.append(self.trust_remote_code)
        factors.append(self.rope_scaling)
        factors.append(self.rope_theta)
        return hashlib.sha256(str(factors).encode()).hexdigest()
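
    # A standalone sketch of the factors -> digest pattern shared by every
    # compute_hash() in this file (the factor values below are made up):
    # changing any listed factor changes str(factors), hence the digest,
    # hence the torch.compile cache directory derived from it.
    #
    #     _f1 = ["some-model", "float16", None]
    #     _f2 = ["some-model", "bfloat16", None]
    #     assert (hashlib.sha256(str(_f1).encode()).hexdigest()
    #             != hashlib.sha256(str(_f2).encode()).hexdigest())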

    def __init__(self,
                 model: str,
                 task: Union[TaskOption, Literal["draft"]],

@@ -203,6 +228,8 @@ class ModelConfig:
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta

        if hf_overrides is None:
            hf_overrides = {}
@@ -832,6 +859,24 @@ class CacheConfig:
        cpu_offload_gb: Size of the CPU offload buffer in GiB.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: List[Any] = []
        factors.append(self.cache_dtype)
        # `cpu_offload_gb` does not use `torch.compile` yet.
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def __init__(
        self,
        block_size: int,
@@ -928,6 +973,24 @@ class TokenizerPoolConfig:
    pool_type: Union[str, Type["BaseTokenizerGroup"]]
    extra_config: dict

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def __post_init__(self):
        if self.pool_type not in ("ray", ) and not isinstance(
                self.pool_type, type):
@@ -1010,6 +1073,24 @@ class LoadConfig:
                                                      default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def __post_init__(self):
        model_loader_extra_config = self.model_loader_extra_config or {}
        if isinstance(model_loader_extra_config, str):
@@ -1073,6 +1154,19 @@ class ParallelConfig:

    rank: int = 0

    def compute_hash(self):
        """
        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: List[Any] = []
        factors.append(self.pipeline_parallel_size)
        factors.append(self.tensor_parallel_size)
        return hashlib.sha256(str(factors).encode()).hexdigest()

    def __post_init__(self) -> None:
        self.world_size = self.pipeline_parallel_size * \
            self.tensor_parallel_size
@@ -1209,6 +1303,24 @@ class SchedulerConfig:

    chunked_prefill_enabled: bool = field(init=False)

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def __post_init__(self) -> None:
        if self.max_num_batched_tokens is None:
            if self.enable_chunked_prefill:
@@ -1286,6 +1398,25 @@ class DeviceConfig:
    device: Optional[torch.device]
    device_type: str

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # the device/platform information will be summarized
        # by torch/vllm automatically.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def __init__(self, device: str = "auto") -> None:
        if device == "auto":
            # Automated device type detection
@@ -1313,6 +1444,24 @@ class SpeculativeConfig:
        decoding with top-1 proposals.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # spec decode does not use `torch.compile` yet.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
@@ -1753,6 +1902,24 @@ class LoRAConfig:
    long_lora_scaling_factors: Optional[Tuple[float]] = None
    bias_enabled: bool = False

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # LoRA is not compatible with `torch.compile`.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def __post_init__(self):
        # Setting the maximum rank to 256 should be able to satisfy the vast
        # majority of applications.
@@ -1802,6 +1969,24 @@ class PromptAdapterConfig:
    max_cpu_prompt_adapters: Optional[int] = None
    prompt_adapter_dtype: Optional[torch.dtype] = None

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def __post_init__(self):

        if self.max_prompt_adapters < 1:
@@ -1830,6 +2015,24 @@ class MultiModalConfig:
        for each :class:`~vllm.multimodal.MultiModalPlugin`.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    # TODO: Add configs to init vision tower or not.
@@ -1869,6 +2072,24 @@ class PoolerConfig:
        ``math-shepherd-mistral-7b-prm`` model.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    @staticmethod
    def from_json(json_str: str) -> "PoolerConfig":
        return PoolerConfig(**json.loads(json_str))
@@ -2103,6 +2324,24 @@ class DecodingConfig:

    # 'outlines' / 'lm-format-enforcer' / 'xgrammar'
    guided_decoding_backend: str = 'xgrammar'

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def __post_init__(self):
        valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
        backend = self.guided_decoding_backend
@@ -2124,6 +2363,24 @@ class ObservabilityConfig:

    # If set, collects the model execute time for the request.
    collect_model_execute_time: bool = False

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def __post_init__(self):
        if not is_otel_available() and self.otlp_traces_endpoint is not None:
            raise ValueError(
@@ -2165,6 +2422,24 @@ class KVTransferConfig(BaseModel):

    # The KV connector port, used to build distributed connection
    kv_port: int = 14579

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    @classmethod
    def from_cli(cls, cli_value: str) -> "KVTransferConfig":
        """Parse the CLI value for the kv cache transfer config."""
@@ -2234,6 +2509,9 @@ class CompilationConfig(BaseModel):
        - 2: dynamo once.
        - 3: piecewise compilation.
    - debug_dump_path: the path to dump the debug information.
    - cache_dir: the directory to store the compiled graph, to
      accelerate Inductor compilation. By default, it will use
      model-related information to generate a cache directory.
    - backend: the backend for compilation. It needs to be a string:
        - "" (empty string): use the default backend.
        - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
@@ -2302,12 +2580,10 @@ class CompilationConfig(BaseModel):
    """  # noqa
    level: int = 0
    debug_dump_path: str = ""
    cache_dir: str = ""
    backend: str = ""
    custom_ops: List[str] = Field(default_factory=list)
-    splitting_ops: List[str] = Field(default_factory=lambda: [
-        "vllm.unified_attention",
-        "vllm.unified_attention_with_output",
-    ])
+    splitting_ops: List[str] = Field(default=None)  # type: ignore

    use_inductor: bool = True
    candidate_compile_sizes: Optional[List[int]] = Field(default=None)
@@ -2371,12 +2647,37 @@ class CompilationConfig(BaseModel):
    enabled_custom_ops: Counter[str] = PrivateAttr
    disabled_custom_ops: Counter[str] = PrivateAttr
    compilation_time: float = PrivateAttr
    # should be InductorHashCache, but Pydantic does not support it
    inductor_hash_cache: Any = PrivateAttr

    # Per-model forward context
    # Mainly used to store attention cls
    # Map from layer name to the attention cls
    static_forward_context: Dict[str, Any] = PrivateAttr

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: List[Any] = []
        factors.append(self.level)
        factors.append(self.backend)
        factors.append(self.custom_ops)
        factors.append(self.splitting_ops)
        factors.append(self.use_inductor)
        factors.append(self.inductor_compile_config)
        factors.append(self.inductor_passes)
        factors.append(self.pass_config.uuid())
        return hashlib.sha256(str(factors).encode()).hexdigest()

    def __repr__(self) -> str:
        exclude = {
            "static_forward_context",
@@ -2405,6 +2706,27 @@ class CompilationConfig(BaseModel):
        count_all = self.custom_ops.count("all")
        assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"

        if self.splitting_ops is None:
            if envs.VLLM_USE_V1:
                # v1 must split the graph on attention ops
                # for piecewise cudagraph
                self.splitting_ops = [
                    "vllm.unified_attention",
                    "vllm.unified_attention_with_output",
                ]
            else:
                # v0 can use full-graph compilation without splitting;
                # splitting is optional.
                # right now we still need it: the kv cache shape
                # will be included in the graph if we don't split
                # the graph.
                # TODO: hide the kv cache in the static forward context
                # so that Inductor does not see it.
                self.splitting_ops = [
                    "vllm.unified_attention",
                    "vllm.unified_attention_with_output",
                ]

        for k, v in self.inductor_passes.items():
            if not isinstance(v, str):
                assert callable(v), (
@@ -2444,6 +2766,30 @@ class CompilationConfig(BaseModel):
        # TODO: pass the user-specified backend to piecewise compilation
        # and merge it with the use_inductor config
        assert self.level == CompilationLevel.PIECEWISE

        if not self.cache_dir:
            # no cache dir provided: generate one based on the known factors
            # that affect the compilation. if none of the factors change,
            # the cache dir will be the same, so that we can reuse the
            # compiled graph.
            hash_key = vllm_config.compute_hash()
            cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
                f"rank_{vllm_config.parallel_config.rank}")
            os.makedirs(cache_dir, exist_ok=True)
            self.cache_dir = cache_dir

        disabled = envs.VLLM_DISABLE_COMPILE_CACHE
        from vllm.compilation.backends import InductorHashCache
        self.inductor_hash_cache: InductorHashCache = InductorHashCache(
            self.cache_dir, disabled=disabled)
        if disabled:
            logger.info("vLLM's torch.compile cache is disabled.")
        else:
            logger.info(
                "Using cache directory: %s for vLLM's torch.compile",
                self.cache_dir)

        from vllm.compilation.backends import VllmBackend
        return VllmBackend(vllm_config)
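
        # A sketch of the resulting per-rank layout, assuming the default
        # VLLM_CACHE_ROOT of ~/.cache/vllm; <hash_key> comes from
        # vllm_config.compute_hash() above:
        #
        #   ~/.cache/vllm/torch_compile_cache/<hash_key>/rank_0/
        #       inductor_hash_cache.py  # InductorHashCache.save_to_file()
        #       inductor_cache/         # TORCHINDUCTOR_CACHE_DIR
        #       triton_cache/           # TRITON_CACHE_DIR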

@@ -2520,6 +2866,67 @@ class VllmConfig:
                                              init=True)  # type: ignore
    instance_id: str = ""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: List[Any] = []
        # summarize system state
        from torch._inductor.codecache import CacheBase
        system_factors = CacheBase.get_system()
        factors.append(system_factors)

        # summarize pytorch state
        from torch._inductor.codecache import torch_key
        torch_factors = torch_key()
        factors.append(torch_factors)

        # summarize vllm config
        vllm_factors: List[Any] = []
        from vllm import __version__
        vllm_factors.append(__version__)
        if self.model_config:
            vllm_factors.append(self.model_config.compute_hash())
        if self.cache_config:
            vllm_factors.append(self.cache_config.compute_hash())
        if self.parallel_config:
            vllm_factors.append(self.parallel_config.compute_hash())
        if self.scheduler_config:
            vllm_factors.append(self.scheduler_config.compute_hash())
        if self.device_config:
            vllm_factors.append(self.device_config.compute_hash())
        if self.load_config:
            vllm_factors.append(self.load_config.compute_hash())
        if self.lora_config:
            vllm_factors.append(self.lora_config.compute_hash())
        if self.speculative_config:
            vllm_factors.append(self.speculative_config.compute_hash())
        if self.decoding_config:
            vllm_factors.append(self.decoding_config.compute_hash())
        if self.observability_config:
            vllm_factors.append(self.observability_config.compute_hash())
        if self.prompt_adapter_config:
            vllm_factors.append(self.prompt_adapter_config.compute_hash())
        if self.quant_config:
            pass  # should be captured by model_config.quantization
        if self.compilation_config:
            vllm_factors.append(self.compilation_config.compute_hash())
        if self.kv_transfer_config:
            vllm_factors.append(self.kv_transfer_config.compute_hash())

        factors.append(vllm_factors)

        hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
        return hash_str

    def pad_for_cudagraph(self, batch_size: int) -> int:
        # if batch_size > self.compilation_config.max_capture_size,
        # it should raise an IndexError.
@@ -71,6 +71,7 @@ if TYPE_CHECKING:
    VLLM_USE_V1: bool = False
    VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
    VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
    VLLM_DISABLE_COMPILE_CACHE: bool = False


def get_default_cache_root():

@@ -463,6 +464,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))),
    "VLLM_LOG_BATCHSIZE_INTERVAL":
    lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
    "VLLM_DISABLE_COMPILE_CACHE":
    lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),
}

# end-env-vars-definition
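
# Example usage of the new flag: the parser above is bool(int(...)), so "1"
# disables vLLM's torch.compile cache and "0" (the default) keeps it enabled.
# `my_script.py` is a placeholder for any vLLM entry point:
#
#   VLLM_DISABLE_COMPILE_CACHE=1 python my_script.py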