[BugFix] Fix use_cudagraph=False (#19612)
Signed-off-by: Richard Zou <zou3519@gmail.com>
commit ed33349738
parent d49adea1f9
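Note: the bug fixed here is that a user-supplied `use_cudagraph=False` was silently ignored, because GPUModelRunner derived its use_cuda_graph flag from the compilation level alone and still captured CUDA graphs. A minimal sketch of the affected configuration (settings borrowed from the updated test below; assumes a CUDA GPU and the vllm LLM entrypoint):

    from vllm import LLM

    # Before this fix, CUDA graph capture still ran with this config;
    # after it, capture_model() skips capture entirely.
    llm = LLM(
        "facebook/opt-125m",
        compilation_config={
            "use_cudagraph": False,
            "cudagraph_capture_sizes": [100],
        },
        gpu_memory_utilization=0.4,
    )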
@@ -1,14 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
-import torch
 
 import vllm
 from vllm.compilation.counter import compilation_counter
-from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
-                         set_current_vllm_config)
-
-from .piecewise.test_simple import SillyModel
+from vllm.config import VllmConfig
 
 
 def test_use_cudagraphs_dynamic(monkeypatch):
@@ -22,23 +18,24 @@ def test_use_cudagraphs_dynamic(monkeypatch):
 
 
 @pytest.mark.parametrize("enabled", [True, False])
-def test_use_cudagraphs(enabled):
+def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
     assert vllm.envs.VLLM_USE_V1
-    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
-        use_cudagraph=enabled,
-        cudagraph_capture_sizes=[100],
-    ))
-    with set_current_vllm_config(vllm_config):
-        model = SillyModel(vllm_config=vllm_config, prefix='')
 
-    inputs = torch.randn(100, device="cuda")
+    # Disable multiprocessing so that the counter is in the same process
+    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
 
-    with compilation_counter.expect(
-            num_graphs_seen=1,  # one graph for the model
-            num_cudagraph_captured=1 if enabled else 0,
-    ):
-        # first run is warmup
-        model(inputs)
-        # second run does CUDAGraphs recording (if enabled)
-        model(inputs)
+    compilation_config = {
+        "cudagraph_capture_sizes": [100],
+        "use_cudagraph": enabled,
+    }
+    with (
+            compilation_counter.expect(
+                num_graphs_seen=1,
+                num_gpu_runner_capture_triggers=1 if enabled else 0,
+                num_cudagraph_captured=13 if enabled else 0,
+            ),
+            # loading the model causes compilation (if enabled) to happen
+            vllm_runner('facebook/opt-125m',
+                        compilation_config=compilation_config,
+                        gpu_memory_utilization=0.4) as _):
+        pass
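Why 13 captures when enabled? A back-of-the-envelope reading (an assumption about piecewise capture, not something stated in this diff): piecewise compilation splits the model graph at each of opt-125m's 12 attention ops, yielding 13 capturable pieces, and each piece is captured once per entry in cudagraph_capture_sizes:

    # Hypothetical arithmetic behind num_cudagraph_captured=13 above.
    num_attention_layers = 12                 # facebook/opt-125m
    num_pieces = num_attention_layers + 1     # splitting at each attention op
    cudagraph_capture_sizes = [100]
    assert num_pieces * len(cudagraph_capture_sizes) == 13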
@@ -15,6 +15,9 @@ class CompilationCounter:
     # not including the splitting ops
     num_piecewise_capturable_graphs_seen: int = 0
     num_backend_compilations: int = 0
+    # Number of gpu_model_runner attempts to trigger CUDAGraphs capture
+    num_gpu_runner_capture_triggers: int = 0
+    # Number of CUDAGraphs captured
     num_cudagraph_captured: int = 0
     # InductorAdapter.compile calls
     num_inductor_compiles: int = 0
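The tests rely on compilation_counter.expect(...), whose implementation is not part of this diff. A plausible self-contained sketch of such a delta-asserting context manager (illustrative names, not vllm's actual code):

    import contextlib
    import dataclasses

    @dataclasses.dataclass
    class DemoCounter:
        num_graphs_seen: int = 0
        num_gpu_runner_capture_triggers: int = 0
        num_cudagraph_captured: int = 0

        @contextlib.contextmanager
        def expect(self, **expected_deltas: int):
            # Snapshot all fields on entry; assert the observed deltas on exit.
            before = dataclasses.asdict(self)
            yield
            after = dataclasses.asdict(self)
            for name, expected in expected_deltas.items():
                actual = after[name] - before[name]
                assert actual == expected, (
                    f"{name}: expected +{expected}, got +{actual}")

    counter = DemoCounter()
    with counter.expect(num_cudagraph_captured=1):
        counter.num_cudagraph_captured += 1   # simulate one capture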
@@ -18,6 +18,7 @@ import vllm.envs as envs
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.layer import Attention
+from vllm.compilation.counter import compilation_counter
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -200,9 +201,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             block_sizes=[self.cache_config.block_size],
         )
 
-        self.use_cuda_graph = (self.compilation_config.level
-                               == CompilationLevel.PIECEWISE
-                               and not self.model_config.enforce_eager)
+        self.use_cuda_graph = (
+            self.vllm_config.compilation_config.level
+            == CompilationLevel.PIECEWISE
+            and self.vllm_config.compilation_config.use_cudagraph
+            and not self.model_config.enforce_eager)
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
         # self.cudagraph_batch_sizes sorts in ascending order.
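The fix itself is the extra use_cudagraph conjunct above. Condensed into a standalone predicate (an illustrative restatement, not the actual method; CompilationLevel.PIECEWISE is what -O 3 selects):

    from vllm.config import CompilationLevel

    def should_use_cuda_graph(level: int, use_cudagraph: bool,
                              enforce_eager: bool) -> bool:
        # All three must hold for the GPU model runner to use CUDA graphs.
        return (level == CompilationLevel.PIECEWISE
                and use_cudagraph
                and not enforce_eager)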
@@ -2058,10 +2061,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def capture_model(self) -> None:
         if not self.use_cuda_graph:
             logger.warning(
-                "Skipping CUDA graph capture. Please add "
-                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
+                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
+                "set -O %s and ensure `use_cudagraph` was not manually set to "
+                "False", CompilationLevel.PIECEWISE)
             return
 
+        compilation_counter.num_gpu_runner_capture_triggers += 1
+
         start_time = time.perf_counter()
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
 
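Note the placement of the new counter bump: it comes after the early return, so num_gpu_runner_capture_triggers only counts runs where capture actually proceeds, which is exactly what the test's `1 if enabled else 0` expectation checks. Two ways to land in the skip branch (a sketch; either line alone suffices):

    from vllm import LLM

    # Either configuration leaves use_cuda_graph False, so capture_model()
    # logs the warning above and returns early.
    llm = LLM("facebook/opt-125m", enforce_eager=True)
    # llm = LLM("facebook/opt-125m",
    #           compilation_config={"use_cudagraph": False})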