[6/N] torch.compile rollout to users (#10437)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-11-19 10:09:03 -08:00 (committed by GitHub)
parent fd9f124971
commit 803f37eaaa
15 changed files with 129 additions and 141 deletions

View File

@@ -1,5 +0,0 @@
-{
-    "use_cudagraph": true,
-    "non_cudagraph_ops": ["silly.attention"],
-    "cudagraph_copy_inputs": true
-}
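With the per-file JSON gone, the same settings now travel through a CompilationConfig object, as the updated test below constructs. A minimal sketch of the in-code equivalent of the deleted JSON (referenced as piecewise_compilation_config.json by the old test), with the level made explicit:

    from vllm.config import CompilationConfig, CompilationLevel

    # Equivalent of the deleted JSON config, now built directly in Python
    # instead of being pointed to via VLLM_TORCH_COMPILE_CONFIG.
    config = CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        non_cudagraph_ops=["silly.attention"],
        cudagraph_copy_inputs=True,
    )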

View File

@@ -2,7 +2,6 @@
 Test the piecewise compilation with a simple model so that we
 can exactly calculate the expected output and side effects.
 """
-import os

 import torch
 from torch import nn
@@ -11,7 +10,7 @@ from torch.library import Library
 from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.plugins import set_current_vllm_config
 from vllm.utils import direct_register_custom_op
@@ -77,12 +76,12 @@ class SillyModel(nn.Module):

 def test_simple_piecewise_compile():
-    directory = os.path.dirname(__file__)
-    config = os.path.join(directory, "piecewise_compilation_config.json")
-    os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
-    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
-
-    vllm_config = VllmConfig()
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        use_cudagraph=True,
+        non_cudagraph_ops=["silly.attention"],
+        cudagraph_copy_inputs=True,
+    ))
     with set_current_vllm_config(vllm_config):
         model = SillyModel(vllm_config=vllm_config, prefix='')
@@ -109,6 +108,3 @@ def test_simple_piecewise_compile():
     output = model(input)
     assert global_counter == 2
     assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
-
-    # clean up to avoid side effects for other tests
-    del os.environ["VLLM_TORCH_COMPILE_CONFIG"]

View File

@@ -6,7 +6,6 @@ This is a tractable model, the weights and computation are specially designed
 if the config `tractable_init` is set to True. Otherwise, the weights are
 initialized randomly with a fixed seed.
 """
-import os
 from dataclasses import dataclass
 from typing import Optional, Tuple
@@ -18,7 +17,7 @@ from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
-from vllm.plugins import set_compilation_config, set_current_vllm_config
+from vllm.plugins import set_current_vllm_config
 from vllm.utils import direct_register_custom_op

 # create a library to hold the custom op
@@ -254,23 +253,17 @@ def run_model(llama_config,
               split_attn: bool = False) -> torch.Tensor:
     if use_compile:
-        os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
-            CompilationLevel.PIECEWISE)
-
-        if split_attn:
-            set_compilation_config(
-                CompilationConfig(
-                    use_cudagraph=True,
-                    non_cudagraph_ops=["silly.attention"],
-                ))
-        else:
-            set_compilation_config(CompilationConfig(use_cudagraph=True, ))
+        compilation_config = CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
+            use_cudagraph=True,
+        )
+        if split_attn:
+            compilation_config.non_cudagraph_ops = ["silly.attention"]
     else:
-        os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
-            CompilationLevel.NO_COMPILATION)
-        set_compilation_config(None)
+        compilation_config = CompilationConfig(
+            level=CompilationLevel.NO_COMPILATION, )

-    vllm_config = VllmConfig()
+    vllm_config = VllmConfig(compilation_config=compilation_config)
     with set_current_vllm_config(vllm_config):
         model = LlamaModel(config=llama_config,
                            vllm_config=vllm_config,
@@ -288,10 +281,6 @@ def run_model(llama_config,
     input_ids[:2].zero_()
     output = model(input_ids[:2], positions[:2])

-    # manual cleanup
-    del os.environ["VLLM_TORCH_COMPILE_LEVEL"]
-    set_compilation_config(None)
-
     output = output.cpu()

     if llama_config.tractable_init:
@@ -361,7 +350,6 @@ def test_toy_llama():

 @torch.inference_mode
 def benchmark():
-    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
     from triton.testing import do_bench

     # similar to llama 3.1-8B
@@ -387,15 +375,16 @@ def benchmark():
     for piecewise in [False, True]:
         if piecewise:
-            set_compilation_config(
-                CompilationConfig(
-                    use_cudagraph=True,
-                    non_cudagraph_ops=["silly.attention"],
-                ))
+            compilation_config = CompilationConfig(
+                level=CompilationLevel.PIECEWISE,
+                use_cudagraph=True,
+                non_cudagraph_ops=["silly.attention"],
+            )
         else:
-            set_compilation_config(None)
+            compilation_config = CompilationConfig(
+                level=CompilationLevel.PIECEWISE, )

-        vllm_config = VllmConfig()
+        vllm_config = VllmConfig(compilation_config=compilation_config)
         with set_current_vllm_config(vllm_config):
             model = LlamaModel(config=llama_config,
                                vllm_config=vllm_config,

View File

@@ -96,31 +96,36 @@ def test_compile_correctness(test_setting: TestSetting):
     final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \
         ["-tp", str(tp_size)]

+    all_args: List[List[str]] = []
     all_envs: List[Optional[Dict[str, str]]] = []

     for level in [
             CompilationLevel.NO_COMPILATION,
             CompilationLevel.PIECEWISE,
     ]:
-        all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
+        all_args.append(final_args + ["-O", str(level)])
+        all_envs.append({})

     # inductor will change the output, so we only compare if the output
     # is close, not exactly the same.
     compare_all_settings(
-        model, [final_args] * 2,
+        model,
+        all_args,
         all_envs,
         method=method if method != "generate" else "generate_close")

     all_envs.clear()
+    all_args.clear()

     for level in [
             CompilationLevel.NO_COMPILATION,
             CompilationLevel.DYNAMO_AS_IS,
             CompilationLevel.DYNAMO_ONCE,
     ]:
-        all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
+        all_args.append(final_args + ["-O", str(level)])
+        all_envs.append({})
         if level != CompilationLevel.DYNAMO_ONCE and not fullgraph:
             # "DYNAMO_ONCE" will always use fullgraph
             all_envs[-1][
                 "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0"  # type: ignore

-    compare_all_settings(model, [final_args] * 3, all_envs, method=method)
+    compare_all_settings(model, all_args * 3, all_envs, method=method)
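To make the new test wiring concrete, here is a small sketch of the lists it now builds (values are illustrative; NO_COMPILATION is level 0 and PIECEWISE is level 3):

    # The compilation level now rides on the command line via "-O" instead of
    # the removed VLLM_TORCH_COMPILE_LEVEL environment variable, so each env
    # dict is empty and each argument list differs only in its "-O" value.
    final_args = ["--enforce-eager", "-pp", "1", "-tp", "1"]
    all_args = [final_args + ["-O", str(level)] for level in (0, 3)]
    all_envs = [{} for _ in all_args]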

View File

@@ -4,7 +4,7 @@ import torch

 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
-from vllm.config import CompilationLevel
+from vllm.config import CompilationConfig, CompilationLevel
 from vllm.platforms import current_platform

 TEST_MODELS = [
@@ -65,7 +65,6 @@ def check_full_graph_support(model,
                              optimization_level,
                              tp_size=1):
     # make sure these models can be captured in full graph mode
-    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
     os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

     # The base meta llama uses too much memory.
@@ -86,6 +85,7 @@ def check_full_graph_support(model,
               enforce_eager=True,
               tensor_parallel_size=tp_size,
               disable_custom_all_reduce=True,
+              compilation_config=CompilationConfig(level=optimization_level),
               **model_kwargs)

     outputs = llm.generate(prompts, sampling_params)
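The same keyword is part of the public LLM entry point, so this test exercises exactly what end users would write. A minimal sketch, reusing the google/gemma-2b model that other tests in this commit use:

    from vllm import LLM
    from vllm.config import CompilationConfig, CompilationLevel

    # Sketch: request piecewise compilation (level 3) when constructing the
    # engine, instead of exporting VLLM_TORCH_COMPILE_LEVEL beforehand.
    llm = LLM(model="google/gemma-2b",
              compilation_config=CompilationConfig(
                  level=CompilationLevel.PIECEWISE))
    outputs = llm.generate(["Hello, my name is"])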

View File

@@ -1,4 +1,3 @@
-import os
 from typing import List

 import pytest
@@ -53,9 +52,8 @@ class Relu3(ReLUSquaredActivation):
 ])
 def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
                      default_on: bool):
-    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        custom_ops=env.split(",")))
+        level=torch_level, custom_ops=env.split(",")))
     with set_current_vllm_config(vllm_config):
         assert CustomOp.default_on() == default_on

View File

@@ -1,24 +1,47 @@
 import glob
 import os
-import runpy
 import tempfile

 import depyf

-from vllm.config import CompilationLevel
-
-# disable custom dispatcher, let Dynamo takes over
-# all the control
-os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS)
+from vllm.config import CompilationConfig, CompilationLevel

 temp_dir = tempfile.mkdtemp()
 with depyf.prepare_debug(temp_dir):
-    cur_dir = os.path.dirname(__file__)
-    parent_dir = os.path.dirname(cur_dir)
-    root_dir = os.path.dirname(parent_dir)
-    example_file = os.path.join(root_dir, "examples",
-                                "offline_inference_tpu.py")
-    runpy.run_path(example_file)
+    from vllm import LLM, SamplingParams
+
+    prompts = [
+        "A robot may not injure a human being",
+        "It is only with the heart that one can see rightly;",
+        "The greatest glory in living lies not in never falling,",
+    ]
+    answers = [
+        " or, through inaction, allow a human being to come to harm.",
+        " what is essential is invisible to the eye.",
+        " but in rising every time we fall.",
+    ]
+    N = 1
+    # Currently, top-p sampling is disabled. `top_p` should be 1.0.
+    sampling_params = SamplingParams(temperature=0.7,
+                                     top_p=1.0,
+                                     n=N,
+                                     max_tokens=16)
+
+    # Set `enforce_eager=True` to avoid ahead-of-time compilation.
+    # In real workloads, `enforce_eager` should be `False`.
+    # Disable the custom dispatcher and let Dynamo take over all the control.
+    llm = LLM(model="google/gemma-2b",
+              enforce_eager=True,
+              compilation_config=CompilationConfig(
+                  level=CompilationLevel.DYNAMO_AS_IS))
+    outputs = llm.generate(prompts, sampling_params)
+    for output, answer in zip(outputs, answers):
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        assert generated_text.startswith(answer)

 compiled_code = sorted(
     glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))

View File

@@ -13,7 +13,9 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000"
 def test_custom_dispatcher():
     compare_two_settings(
         "google/gemma-2b",
-        arg1=["--enforce-eager"],
-        arg2=["--enforce-eager"],
-        env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)},
-        env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)})
+        arg1=["--enforce-eager", "-O",
+              str(CompilationLevel.DYNAMO_ONCE)],
+        arg2=["--enforce-eager", "-O",
+              str(CompilationLevel.DYNAMO_AS_IS)],
+        env1={},
+        env2={})

View File

@@ -2174,8 +2174,14 @@ class CompilationConfig(BaseModel):
     enabled_custom_ops: Counter[str] = PrivateAttr
     disabled_custom_ops: Counter[str] = PrivateAttr

+    @classmethod
+    def from_cli(cls, cli_value: str) -> "CompilationConfig":
+        """Parse the CLI value for the compilation config."""
+        if cli_value in ["0", "1", "2", "3"]:
+            return cls(level=int(cli_value))
+        return CompilationConfig.model_validate_json(cli_value)
+
     def model_post_init(self, __context: Any) -> None:
-        self.level = envs.VLLM_TORCH_COMPILE_LEVEL

         count_none = self.custom_ops.count("none")
         count_all = self.custom_ops.count("all")
@@ -2249,26 +2255,6 @@
                 "inductor_specialize_for_cudagraph_no_more_than is None")
             self.compile_sizes = self.inductor_compile_sizes

-    @staticmethod
-    def select_and_init_config() -> "CompilationConfig":
-        """The order of selecting config is:
-        1. Use the config specified in environment variable.
-        2. Use the config specified in plugins.
-        3. Use the default config.
-        """
-        config_path = envs.VLLM_TORCH_COMPILE_CONFIG
-        if config_path is not None:
-            with open(config_path) as json_file:
-                config = CompilationConfig.model_validate_json(
-                    json_file.read())
-        else:
-            from vllm.plugins import get_compilation_config
-            predefined_config = get_compilation_config()
-            config = predefined_config if predefined_config is not None else (
-                CompilationConfig())
-
-        return config
-

 @dataclass
 class VllmConfig:
@@ -2354,8 +2340,19 @@ class VllmConfig:
                 self.model_config, self.load_config)

         if self.compilation_config is None:
-            self.compilation_config = CompilationConfig.select_and_init_config(
-            )
+            self.compilation_config = CompilationConfig()
+            if envs.VLLM_USE_V1:
+                # NOTE(woosuk): Currently, we use inductor because the piecewise
+                # CUDA graphs do not work properly with the custom CUDA kernels.
+                # FIXME(woosuk): Disable inductor to reduce the compilation time
+                # and avoid any potential issues with the inductor.
+                self.compilation_config.custom_ops = ["none"]
+                self.compilation_config.use_cudagraph = True
+                self.compilation_config.non_cudagraph_ops = [
+                    "vllm.unified_v1_flash_attention"
+                ]
+                self.compilation_config.use_inductor = True
+                self.compilation_config.enable_fusion = False

         current_platform.check_and_update_config(self)
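The new from_cli classmethod above accepts either a bare level digit or a full JSON config; a quick sketch of both forms:

    from vllm.config import CompilationConfig

    # A bare digit 0-3 selects only the optimization level...
    cfg = CompilationConfig.from_cli("3")
    assert cfg.level == 3

    # ...while any other string is parsed as a full JSON config.
    cfg = CompilationConfig.from_cli('{"level": 3, "use_cudagraph": true}')
    assert cfg.level == 3 and cfg.use_cudagraph is True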

View File

@@ -8,12 +8,13 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
 import torch

 import vllm.envs as envs
-from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
-                         DeviceConfig, HfOverrides, LoadConfig, LoadFormat,
-                         LoRAConfig, ModelConfig, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
-                         SchedulerConfig, SpeculativeConfig, TaskOption,
-                         TokenizerPoolConfig, VllmConfig)
+from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
+                         DecodingConfig, DeviceConfig, HfOverrides, LoadConfig,
+                         LoadFormat, LoRAConfig, ModelConfig,
+                         ObservabilityConfig, ParallelConfig, PoolerConfig,
+                         PromptAdapterConfig, SchedulerConfig,
+                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
+                         VllmConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -189,6 +190,7 @@ class EngineArgs:
     override_neuron_config: Optional[Dict[str, Any]] = None
     override_pooler_config: Optional[PoolerConfig] = None
+    compilation_config: Optional[CompilationConfig] = None

     def __post_init__(self):
         if not self.tokenizer:
@@ -868,6 +870,20 @@ class EngineArgs:
             help="Override or set the pooling method in the embedding model. "
             "e.g. {\"pooling_type\": \"mean\", \"normalize\": false}.'")

+        parser.add_argument('--compilation-config',
+                            '-O',
+                            type=CompilationConfig.from_cli,
+                            default=None,
+                            help='torch.compile configuration for the model.'
+                            'When it is a number (0, 1, 2, 3), it will be '
+                            'interpreted as the optimization level.\n'
+                            'NOTE: level 0 is the default level without '
+                            'any optimization. level 1 and 2 are for internal '
+                            'testing only. level 3 is the recommended level '
+                            'for production.\n'
+                            'To specify the full compilation config, '
+                            'use a JSON string.')
+
         return parser

     @classmethod
@@ -1142,6 +1158,7 @@ class EngineArgs:
             decoding_config=decoding_config,
             observability_config=observability_config,
             prompt_adapter_config=prompt_adapter_config,
+            compilation_config=self.compilation_config,
         )
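For completeness, a hedged sketch of how the new flag flows through the engine argument parser (FlexibleArgumentParser is assumed here, as used elsewhere in vLLM; the model name is illustrative):

    from vllm.engine.arg_utils import EngineArgs
    from vllm.utils import FlexibleArgumentParser

    # "-O 3" is parsed by CompilationConfig.from_cli into a config with
    # level 3; "--compilation-config '{...}'" would accept a full JSON config.
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args(["--model", "google/gemma-2b", "-O", "3"])
    engine_args = EngineArgs.from_cli_args(args)
    print(engine_args.compilation_config.level)  # 3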

View File

@@ -262,7 +262,8 @@ class LLMEngine:
             "num_scheduler_steps=%d, chunked_prefill_enabled=%s "
             "multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
             "use_async_output_proc=%s, use_cached_outputs=%s, "
-            "mm_processor_kwargs=%s, pooler_config=%r)",
+            "mm_processor_kwargs=%s, pooler_config=%r,"
+            "compilation_config=%r",
             VLLM_VERSION,
             model_config.model,
             speculative_config,
@@ -297,6 +298,7 @@ class LLMEngine:
             use_cached_outputs,
             model_config.mm_processor_kwargs,
             model_config.pooler_config,
+            vllm_config.compilation_config,
         )
         # TODO(woosuk): Print more configs in debug mode.
         self.model_config = model_config

View File

@@ -67,8 +67,6 @@ if TYPE_CHECKING:
     VLLM_USE_TRITON_AWQ: bool = False
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
     VLLM_SKIP_P2P_CHECK: bool = False
-    VLLM_TORCH_COMPILE_LEVEL: int = 0
-    VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None
     VLLM_DISABLED_KERNELS: List[str] = []
     VLLM_USE_V1: bool = False
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
@@ -209,12 +207,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
     lambda: bool(
         os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),

-    "VLLM_TORCH_COMPILE_LEVEL":
-    lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),
-
-    # Path to the config file for torch compile
-    "VLLM_TORCH_COMPILE_CONFIG":
-    lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None),
-
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id

View File

@@ -1,4 +1,3 @@
-import os
 from typing import TYPE_CHECKING

 import torch
@@ -40,7 +39,8 @@ class TpuPlatform(Platform):
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         from vllm.config import CompilationLevel
         compilation_config = vllm_config.compilation_config
-        if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
+        if compilation_config.level == CompilationLevel.NO_COMPILATION:
+            # TPU does not support NO_COMPILATION
             compilation_config.level = CompilationLevel.DYNAMO_ONCE
         assert compilation_config.level < CompilationLevel.PIECEWISE,\
             "TPU does not support Inductor."

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
 import vllm.envs as envs

 if TYPE_CHECKING:
-    from vllm.config import CompilationConfig, VllmConfig
+    from vllm.config import VllmConfig

 logger = logging.getLogger(__name__)
@@ -54,18 +54,6 @@ def load_general_plugins():
             logger.exception("Failed to load plugin %s", plugin.name)

-_compilation_config: Optional["CompilationConfig"] = None
-
-
-def set_compilation_config(config: Optional["CompilationConfig"]):
-    global _compilation_config
-    _compilation_config = config
-
-
-def get_compilation_config() -> Optional["CompilationConfig"]:
-    return _compilation_config
-
-
 _current_vllm_config: Optional["VllmConfig"] = None

View File

@@ -8,13 +8,12 @@ import torch.distributed
 import torch.nn as nn

 from vllm.compilation.compile_context import set_compile_context
-from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
+from vllm.config import CompilationLevel, VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MultiModalKwargs
-from vllm.plugins import set_compilation_config
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
                         is_pin_memory_available)
@@ -508,20 +507,6 @@ class GPUModelRunner:
         return model_runner_output

     def load_model(self) -> None:
-        if self.use_cuda_graph:
-            # NOTE(woosuk): Currently, we use inductor because the piecewise
-            # CUDA graphs do not work properly with the custom CUDA kernels.
-            # FIXME(woosuk): Disable inductor to reduce the compilation time
-            # and avoid any potential issues with the inductor.
-            set_compilation_config(
-                CompilationConfig(
-                    custom_ops=["none"],
-                    use_cudagraph=True,
-                    non_cudagraph_ops=["vllm.unified_v1_flash_attention"],
-                    use_inductor=True,
-                    enable_fusion=False,
-                ))
-
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
@@ -562,9 +547,8 @@
     def capture_model(self) -> None:
         if not self.use_cuda_graph:
             logger.warning(
-                "Skipping CUDA graph capture. Please set "
-                "VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.",
-                CompilationLevel.PIECEWISE)
+                "Skipping CUDA graph capture. Please add "
+                "-O 3 to use CUDA graphs.", CompilationLevel.PIECEWISE)
             return

         start_time = time.perf_counter()