[BugFix] Fix use_cudagraph=False (#19612)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou 2025-06-18 20:23:12 -04:00 committed by GitHub
parent d49adea1f9
commit ed33349738
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 33 additions and 27 deletions

View File

@ -1,14 +1,10 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
import torch
import vllm import vllm
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, from vllm.config import VllmConfig
set_current_vllm_config)
from .piecewise.test_simple import SillyModel
def test_use_cudagraphs_dynamic(monkeypatch): def test_use_cudagraphs_dynamic(monkeypatch):
@ -22,23 +18,24 @@ def test_use_cudagraphs_dynamic(monkeypatch):
@pytest.mark.parametrize("enabled", [True, False]) @pytest.mark.parametrize("enabled", [True, False])
def test_use_cudagraphs(enabled): def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
assert vllm.envs.VLLM_USE_V1 assert vllm.envs.VLLM_USE_V1
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=enabled,
cudagraph_capture_sizes=[100],
))
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')
inputs = torch.randn(100, device="cuda") # Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
with compilation_counter.expect( compilation_config = {
num_graphs_seen=1, # one graph for the model "cudagraph_capture_sizes": [100],
num_cudagraph_captured=1 if enabled else 0, "use_cudagraph": enabled,
): }
# first run is warmup with (
model(inputs) compilation_counter.expect(
# second run does CUDAGraphs recording (if enabled) num_graphs_seen=1,
model(inputs) num_gpu_runner_capture_triggers=1 if enabled else 0,
num_cudagraph_captured=13 if enabled else 0,
),
# loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m',
compilation_config=compilation_config,
gpu_memory_utilization=0.4) as _):
pass

View File

@ -15,6 +15,9 @@ class CompilationCounter:
# not including the splitting ops # not including the splitting ops
num_piecewise_capturable_graphs_seen: int = 0 num_piecewise_capturable_graphs_seen: int = 0
num_backend_compilations: int = 0 num_backend_compilations: int = 0
# Number of gpu_model_runner attempts to trigger CUDAGraphs capture
num_gpu_runner_capture_triggers: int = 0
# Number of CUDAGraphs captured
num_cudagraph_captured: int = 0 num_cudagraph_captured: int = 0
# InductorAdapter.compile calls # InductorAdapter.compile calls
num_inductor_compiles: int = 0 num_inductor_compiles: int = 0

View File

@ -18,6 +18,7 @@ import vllm.envs as envs
from vllm.attention import AttentionType, get_attn_backend from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationLevel, VllmConfig, from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config)
from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@ -200,9 +201,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_sizes=[self.cache_config.block_size], block_sizes=[self.cache_config.block_size],
) )
self.use_cuda_graph = (self.compilation_config.level self.use_cuda_graph = (
== CompilationLevel.PIECEWISE self.vllm_config.compilation_config.level
and not self.model_config.enforce_eager) == CompilationLevel.PIECEWISE
and self.vllm_config.compilation_config.use_cudagraph
and not self.model_config.enforce_eager)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size. # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different. # The convention is different.
# self.cudagraph_batch_sizes sorts in ascending order. # self.cudagraph_batch_sizes sorts in ascending order.
@ -2058,10 +2061,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def capture_model(self) -> None: def capture_model(self) -> None:
if not self.use_cuda_graph: if not self.use_cuda_graph:
logger.warning( logger.warning(
"Skipping CUDA graph capture. Please add " "Skipping CUDA graph capture. To turn on CUDA graph capture, "
"-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE) "set -O %s and ensure `use_cudagraph` was not manually set to "
"False", CompilationLevel.PIECEWISE)
return return
compilation_counter.num_gpu_runner_capture_triggers += 1
start_time = time.perf_counter() start_time = time.perf_counter()
start_free_gpu_memory = torch.cuda.mem_get_info()[0] start_free_gpu_memory = torch.cuda.mem_get_info()[0]