Add ability to use CUDAGraphs with use_inductor=False (#17345)
Signed-off-by: rzou <zou3519@gmail.com>
commit 26b4fa45be
parent 515b413ebf
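In practice, the new combination looks like the configuration the updated tests below use; a minimal sketch (import locations assumed to match the test files touched by this commit):

    from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

    # Piecewise CUDA graphs without Inductor: the capability added here.
    config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,    # still capture piecewise CUDA graphs
        use_inductor=False,    # new: skip Inductor, compile the graphs eagerly
    ))

With use_inductor=False, the piecewise graphs are handed to EagerAdaptor (see the counter instrumentation further down) while CUDA graph capture proceeds as before.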
@@ -74,11 +74,12 @@ class SillyModel(nn.Module):
         return x
 
 
-def test_simple_piecewise_compile():
+def _test_simple_piecewise_compile(*, use_inductor):
 
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.PIECEWISE,
         use_cudagraph=True,
+        use_inductor=use_inductor,
         splitting_ops=["silly.attention"],
         cudagraph_copy_inputs=True,
         cudagraph_capture_sizes=[1, 2],
@@ -108,3 +109,11 @@ def test_simple_piecewise_compile():
         output = model(input)
         assert global_counter == 2
         assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
+
+
+def test_simple_piecewise_compile_inductor():
+    _test_simple_piecewise_compile(use_inductor=True)
+
+
+def test_simple_piecewise_compile_no_inductor():
+    _test_simple_piecewise_compile(use_inductor=False)
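The commit exposes the inductor/no-inductor variants as two thin wrapper tests around a private helper. A hypothetical equivalent using pytest parametrization (not what the commit does, shown only for comparison):

    import pytest

    # Hypothetical alternative to the two wrapper functions above.
    @pytest.mark.parametrize("use_inductor", [True, False])
    def test_simple_piecewise_compile(use_inductor):
        _test_simple_piecewise_compile(use_inductor=use_inductor)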
@@ -261,12 +261,14 @@ def tractable_computation(input_ids: torch.Tensor,
 @torch.inference_mode
 def run_model(llama_config,
               use_compile: bool,
+              use_inductor: bool,
               split_attn: bool = False) -> torch.Tensor:
 
     if use_compile:
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
+            use_inductor=use_inductor,
             cudagraph_capture_sizes=[1, 2],
         )
         if split_attn:
@@ -304,7 +306,7 @@ def run_model(llama_config,
     return output.cpu()
 
 
-def test_toy_llama():
+def _test_toy_llama(*, use_inductor):
     # compare output with and without piecewise compilation
 
     llama_config = LlamaConfig(hidden_size=128,
@@ -326,8 +328,14 @@ def test_toy_llama():
             num_backend_compilations=0,
             num_cudagraph_caputured=0,
     ):
-        outputs.append(run_model(llama_config, use_compile=False))
-        run_model(tractable_config, use_compile=False)
+        outputs.append(
+            run_model(llama_config, use_inductor=False, use_compile=False))
+        run_model(tractable_config, use_inductor=False, use_compile=False)
 
+    if use_inductor:
+        kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
+    else:
+        kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
+
     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
@@ -336,9 +344,13 @@ def test_toy_llama():
             num_backend_compilations=1,  # num_piecewise_capturable_graphs_seen
             num_cudagraph_caputured=
             2,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+            **kwargs,
     ):
-        outputs.append(run_model(llama_config, use_compile=True))
-        run_model(tractable_config, use_compile=True)
+        outputs.append(
+            run_model(llama_config,
+                      use_inductor=use_inductor,
+                      use_compile=True))
+        run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
 
     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
@@ -353,13 +365,27 @@ def test_toy_llama():
             ),  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
         outputs.append(
-            run_model(llama_config, use_compile=True, split_attn=True))
-        run_model(tractable_config, use_compile=True, split_attn=True)
+            run_model(llama_config,
+                      use_inductor=use_inductor,
+                      use_compile=True,
+                      split_attn=True))
+        run_model(tractable_config,
+                  use_inductor=use_inductor,
+                  use_compile=True,
+                  split_attn=True)
 
     for i in range(1, len(outputs)):
         assert torch.allclose(outputs[0], outputs[i])
 
 
+def test_toy_llama_inductor():
+    _test_toy_llama(use_inductor=True)
+
+
+def test_toy_no_inductor():
+    _test_toy_llama(use_inductor=False)
+
+
 @torch.inference_mode
 def benchmark():
     from triton.testing import do_bench
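The only new branching in the toy-llama test chooses which backend counter is expected to move; restated as a standalone helper for clarity (an illustrative sketch, the test inlines this as the kwargs dict above):

    def expected_backend_compiles(use_inductor: bool) -> dict[str, int]:
        # With use_inductor=True every piecewise-capturable graph is compiled
        # by an Inductor adaptor; with use_inductor=False the EagerAdaptor
        # handles it instead, so exactly one of the two counters advances.
        if use_inductor:
            return {"num_inductor_compiles": 1, "num_eager_compiles": 0}
        return {"num_eager_compiles": 1, "num_inductor_compiles": 0}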
@@ -12,6 +12,7 @@ import torch._inductor.compile_fx
 import torch.fx as fx
 
 import vllm.envs as envs
+from vllm.compilation.counter import compilation_counter
 from vllm.config import VllmConfig
 from vllm.utils import is_torch_equal_or_newer
 
@@ -175,6 +176,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
     ) -> tuple[Optional[Callable], Optional[Any]]:
+        compilation_counter.num_inductor_compiles += 1
         current_config = {}
         if compiler_config is not None:
             current_config.update(compiler_config)
@@ -262,6 +264,7 @@ class InductorAdaptor(CompilerInterface):
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
     ) -> tuple[Optional[Callable], Optional[Any]]:
+        compilation_counter.num_inductor_compiles += 1
         from torch._inductor.compile_fx import compile_fx
         current_config = {}
         if compiler_config is not None:
@@ -528,6 +531,7 @@ class EagerAdaptor(CompilerInterface):
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
     ) -> tuple[Optional[Callable], Optional[Any]]:
+        compilation_counter.num_eager_compiles += 1
         # we don't need to compile the graph, just return the graph itself.
         # It does not support caching, return None for the handle.
         return graph, None
@@ -15,6 +15,10 @@ class CompilationCounter:
     num_piecewise_capturable_graphs_seen: int = 0
     num_backend_compilations: int = 0
     num_cudagraph_caputured: int = 0
+    # InductorAdapter.compile calls
+    num_inductor_compiles: int = 0
+    # EagerAdapter.compile calls
+    num_eager_compiles: int = 0
 
     def clone(self) -> "CompilationCounter":
         return copy.deepcopy(self)
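Since CompilationCounter already provides clone(), the two new fields can also be checked as deltas outside compilation_counter.expect(); a minimal sketch, where compile_fn is a hypothetical callable that triggers exactly one piecewise backend compile:

    from vllm.compilation.counter import compilation_counter

    def assert_compiled_eagerly(compile_fn):
        # compile_fn stands in for whatever forces compilation, e.g. the first
        # forward pass of a piecewise-compiled model with use_inductor=False.
        before = compilation_counter.clone()
        compile_fn()
        assert compilation_counter.num_eager_compiles - before.num_eager_compiles == 1
        assert compilation_counter.num_inductor_compiles == before.num_inductor_compiles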
@@ -4315,15 +4315,10 @@ class VllmConfig:
             self.compilation_config.custom_ops.append("+rms_norm")
         if envs.VLLM_USE_V1 and self.model_config is not None and \
             not self.model_config.enforce_eager:
-            # NOTE(woosuk): Currently, we use inductor because the piecewise
-            # CUDA graphs do not work properly with the custom CUDA kernels.
-            # FIXME(woosuk): Disable inductor to reduce the compilation time
-            # and avoid any potential issues with the inductor.
             # FIXME(rob): Add function to set all of these.
             if not self.compilation_config.custom_ops:
                 self.compilation_config.custom_ops = ["none"]
             self.compilation_config.use_cudagraph = True
-            self.compilation_config.use_inductor = True
             self.compilation_config.cudagraph_num_of_warmups = 1
             self.compilation_config.pass_config.enable_fusion = False
             self.compilation_config.pass_config.enable_noop = False
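Net effect of the config change: VllmConfig.__post_init__ still turns on CUDA graphs for V1 non-eager runs, but no longer overwrites use_inductor, so a user-supplied setting now survives engine initialization (a sketch, names taken from the hunk above):

    from vllm.config import CompilationConfig

    # Previously this was silently flipped back to True for V1 + non-eager
    # models; after this commit the engine respects it and routes the piecewise
    # graphs through EagerAdaptor instead of Inductor, while use_cudagraph is
    # still forced to True.
    compilation_config = CompilationConfig(use_inductor=False)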