[Bugfix] Re-enable use_cudagraph in vLLM v1 (#19299)

Signed-off-by: Richard Zou <zou3519@gmail.com>

This commit is contained in:
parent d77f7fb871
commit eaa2e51088

In summary, the change renames the misspelled counter field `num_cudagraph_caputured` to `num_cudagraph_captured`, makes `CompilationConfig.use_cudagraph` default to `envs.VLLM_USE_V1`, stops V1 from force-overriding a user-supplied value, and adds tests/compile/test_config.py to cover both settings.
@@ -95,7 +95,7 @@ def _test_simple_piecewise_compile(*, use_inductor):
             num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
             num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
             num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
-            num_cudagraph_caputured=
+            num_cudagraph_captured=
             6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
@@ -327,7 +327,7 @@ def _test_toy_llama(*, use_inductor):
             num_piecewise_graphs_seen=0,
             num_piecewise_capturable_graphs_seen=0,
             num_backend_compilations=0,
-            num_cudagraph_caputured=0,
+            num_cudagraph_captured=0,
     ):
         outputs.append(
             run_model(llama_config, use_inductor=False, use_compile=False))
@@ -343,7 +343,7 @@ def _test_toy_llama(*, use_inductor):
             num_piecewise_graphs_seen=1,
             num_piecewise_capturable_graphs_seen=1,
             num_backend_compilations=1,  # num_piecewise_capturable_graphs_seen
-            num_cudagraph_caputured=
+            num_cudagraph_captured=
             2,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
             **kwargs,
     ):
@@ -361,7 +361,7 @@ def _test_toy_llama(*, use_inductor):
             llama_config.num_layers,  # 1 + num_layers
             num_backend_compilations=1 +
             llama_config.num_layers,  # num_piecewise_capturable_graphs_seen
-            num_cudagraph_caputured=2 *
+            num_cudagraph_captured=2 *
             (1 + llama_config.num_layers
             ),  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
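As a quick check of the expected values in the hunks above (num_layers = 2 and num_cudagraph_sizes = 2 are inferred from the inline comments; the test setup itself is not shown in this diff):

# Sanity check of the counter arithmetic used in the test expectations.
num_layers = 2            # inferred from 5 == 2 * num_layers + 1
num_cudagraph_sizes = 2   # inferred from 6 == num_cudagraph_sizes * 3

num_piecewise_graphs_seen = 2 * num_layers + 1            # 5
num_piecewise_capturable_graphs_seen = 1 + num_layers     # 3
num_backend_compilations = num_piecewise_capturable_graphs_seen
num_cudagraph_captured = (num_cudagraph_sizes *
                          num_piecewise_capturable_graphs_seen)

assert num_piecewise_graphs_seen == 5
assert num_cudagraph_captured == 6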
tests/compile/test_config.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+
+import vllm
+from vllm.compilation.counter import compilation_counter
+from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
+                         set_current_vllm_config)
+
+from .piecewise.test_simple import SillyModel
+
+
+@pytest.fixture(scope="function", autouse=True)
+def use_v1(monkeypatch):
+    """
+    TODO(rzou): The rest of tests/compile runs VLLM_USE_V1=0 right now,
+    I'll switch them over later.
+    """
+    monkeypatch.setenv('VLLM_USE_V1', '1')
+
+
+@pytest.mark.parametrize("enabled", [True, False])
+def test_use_cudagraphs(enabled):
+    assert vllm.envs.VLLM_USE_V1
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        use_cudagraph=enabled,
+        cudagraph_capture_sizes=[100],
+    ))
+    with set_current_vllm_config(vllm_config):
+        model = SillyModel(vllm_config=vllm_config, prefix='')
+
+    inputs = torch.randn(100, device="cuda")
+
+    with compilation_counter.expect(
+            num_graphs_seen=1,  # one graph for the model
+            num_cudagraph_captured=1 if enabled else 0,
+    ):
+        # first run is warmup
+        model(inputs)
+        # second run does CUDAGraphs recording (if enabled)
+        model(inputs)
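Note that the model is constructed inside set_current_vllm_config(vllm_config): the compilation machinery apparently reads the ambient config while the model is being built, so constructing it outside the context would not pick up use_cudagraph. A minimal sketch of that ambient-config pattern (illustrative names, not vLLM's implementation):

from contextlib import contextmanager

_current_config = None

@contextmanager
def set_current_config(cfg):
    # Illustrative only: stash the config in a module-level slot so code
    # that runs during model construction can read it, then restore.
    global _current_config
    prev, _current_config = _current_config, cfg
    try:
        yield
    finally:
        _current_config = prev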
@@ -15,7 +15,7 @@ class CompilationCounter:
     # not including the splitting ops
     num_piecewise_capturable_graphs_seen: int = 0
     num_backend_compilations: int = 0
-    num_cudagraph_caputured: int = 0
+    num_cudagraph_captured: int = 0
    # InductorAdapter.compile calls
     num_inductor_compiles: int = 0
     # EagerAdapter.compile calls
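For context, compilation_counter.expect(...) in the tests asserts on deltas of these fields across the with block. A minimal sketch of that delta-check pattern, assuming the counter is a plain dataclass; this illustrates the idea, not vLLM's exact implementation:

import copy
from contextlib import contextmanager
from dataclasses import dataclass

@dataclass
class Counter:
    num_cudagraph_captured: int = 0

    @contextmanager
    def expect(self, **deltas):
        # Snapshot the fields, run the block, then assert each named
        # field moved by exactly the expected amount.
        before = copy.deepcopy(self)
        yield
        for name, expected in deltas.items():
            actual = getattr(self, name) - getattr(before, name)
            assert actual == expected, f"{name}: got {actual}, want {expected}"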
@@ -193,7 +193,7 @@ class CUDAPiecewiseBackend:
             entry.output = weak_ref_tensors(output)
             entry.cudagraph = cudagraph

-            compilation_counter.num_cudagraph_caputured += 1
+            compilation_counter.num_cudagraph_captured += 1

             # important: we need to return the output, rather than
             # the weak ref of the output, so that pytorch can correctly
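The renamed counter bump sits at the end of the capture path. A hedged sketch of the capture step it belongs to, using PyTorch's public CUDA graph API; the function and variable names here are illustrative, not vLLM's exact code:

import torch

def capture_cudagraph(runnable, *args):
    # Record one execution of `runnable` into a CUDA graph and return
    # (graph, output). Replays reuse the same memory, which is why all
    # input buffers must have fixed addresses (see the use_cudagraph
    # docstring in the config hunk below).
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        output = runnable(*args)
    return graph, output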
@@ -3918,12 +3918,14 @@ class CompilationConfig:
     constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""

     # CudaGraph compilation
-    use_cudagraph: bool = False
+    use_cudagraph: bool = envs.VLLM_USE_V1
     """Whether to use cudagraph inside compilation.
     - False: cudagraph inside compilation is not used.
     - True: cudagraph inside compilation is used. It requires
         that all input buffers have fixed addresses, and all
         splitting ops write their outputs to input buffers.
+    In the vLLM V1 Engine, this flag only applies for
+    CompilationLevel.PIECEWISE (aka -O3).
     Note that this is orthogonal to the cudagraph capture logic
     outside of compilation.
     TODO: move outside cudagraph logic into compilation.
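With this hunk the default tracks the engine version instead of being hard-coded to False. A hedged illustration (assumes CompilationConfig is constructible standalone, as the new test above does):

import vllm.envs as envs
from vllm.config import CompilationConfig

cfg = CompilationConfig()
# Under V1 the default is now True; under V0 it stays False.
assert cfg.use_cudagraph == envs.VLLM_USE_V1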
@@ -4425,7 +4427,6 @@ class VllmConfig:
             # FIXME(rob): Add function to set all of these.
             if not self.compilation_config.custom_ops:
                 self.compilation_config.custom_ops = ["none"]
-            self.compilation_config.use_cudagraph = True
             self.compilation_config.cudagraph_num_of_warmups = 1
             self.compilation_config.pass_config.enable_fusion = False
             self.compilation_config.pass_config.enable_noop = False
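The deleted line is the core of the bugfix: V1 previously forced use_cudagraph = True during post-init, silently clobbering a user-supplied False. A hedged illustration of the behavior after this change:

from vllm.config import CompilationConfig, VllmConfig

# An explicit opt-out now survives VllmConfig post-init under V1.
config = VllmConfig(compilation_config=CompilationConfig(use_cudagraph=False))
assert config.compilation_config.use_cudagraph is False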