[torch.compile] Support conditional torch.compile per module (#22269)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
commit dfd2382039
parent 3b11b26b50
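
This PR adds an optional `enable_if` argument to the `@support_torch_compile` decorator: a callable that receives the module's `VllmConfig` and returns whether that particular module should be compiled. A minimal usage sketch based on the tests added below (the module name is illustrative; the condition mirrors tests/compile/test_decorator.py):

# Minimal usage sketch of conditional compilation per module (this PR).
# `ToyBlock` is an illustrative module name; the enable_if condition mirrors
# the one exercised in tests/compile/test_decorator.py below.
import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile(
    enable_if=lambda cfg: not cfg.cache_config.kv_sharing_fast_prefill)
class ToyBlock(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
        super().__init__()
        self.proj = nn.Linear(128, 128)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)
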
@@ -328,6 +328,7 @@ steps:
   - pytest -v -s compile/test_sequence_parallelism.py
   - pytest -v -s compile/test_async_tp.py
   - pytest -v -s compile/test_fusion_all_reduce.py
+  - pytest -v -s compile/test_decorator.py
 
 - label: PyTorch Fullgraph Smoke Test # 9min
   mirror_hardwares: [amdexperimental]
@@ -341,6 +342,7 @@ steps:
   - pytest -v -s compile/piecewise/test_simple.py
   - pytest -v -s compile/piecewise/test_toy_llama.py
   - pytest -v -s compile/piecewise/test_full_cudagraph.py
+  - pytest -v -s compile/piecewise/test_multiple_graphs.py
 
 - label: PyTorch Fullgraph Test # 18min
   mirror_hardwares: [amdexperimental]
@@ -12,10 +12,9 @@ from vllm.compilation.backends import set_model_tag
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import (ignore_torch_compile,
                                          support_torch_compile)
-from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
-                         set_current_vllm_config)
-from vllm.envs import VLLM_USE_V1
-from vllm.forward_context import set_forward_context
+from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
+                         VllmConfig, set_current_vllm_config)
+from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils import direct_register_custom_op
 
 # create a library to hold the custom op
@@ -164,103 +163,33 @@ class SimpleModelWithTwoGraphs(ParentModel):
         return x
 
 
-def test_ignore_torch_compile_decorator():
-    assert VLLM_USE_V1
-
-    # piecewise
-    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
-        use_cudagraph=True,
-        splitting_ops=["silly.attention"],
-        cudagraph_capture_sizes=[1, 2],
-    ))
-
-    @support_torch_compile
-    class A(nn.Module):
-
-        def __init__(self,
-                     *,
-                     vllm_config: VllmConfig,
-                     prefix: str = '',
-                     **kwargs) -> None:
-            super().__init__()
-
-        def forward(self, x: torch.Tensor) -> torch.Tensor:
-            x = x + x
-            attn_output = torch.empty_like(x)
-            torch.ops.silly.attention(x, x, x, attn_output)
-            x = attn_output
-            x = x * 3
-            return x
-
-    @ignore_torch_compile
-    class B(A):
-        ...
-
-    @support_torch_compile
-    class C(B):
-        ...
-
-    with set_current_vllm_config(vllm_config):
-        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
-
-    # A has support_torch_compile
-    with compilation_counter.expect(
-            num_graphs_seen=1,
-            num_piecewise_graphs_seen=3,
-            num_piecewise_capturable_graphs_seen=2,
-            num_backend_compilations=2,
-            num_cudagraph_captured=4,
-            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        # first run is for compile
-        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        # run cudagraph captured sizes
-        mod_A(torch.randn(2, MLP_SIZE).cuda())
-        mod_A(torch.randn(1, MLP_SIZE).cuda())
-
-    with set_current_vllm_config(vllm_config):
-        mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
-
-    # B's ignore_torch_compile should override A's support_torch_compile
-    with compilation_counter.expect(
-            num_graphs_seen=0,
-            num_piecewise_graphs_seen=0,
-            num_piecewise_capturable_graphs_seen=0,
-            num_backend_compilations=0,
-            num_cudagraph_captured=0,
-    ), set_forward_context({}, vllm_config=vllm_config):
-        mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        mod_B(torch.randn(2, MLP_SIZE).cuda())
-        mod_B(torch.randn(1, MLP_SIZE).cuda())
-
-    with set_current_vllm_config(vllm_config):
-        mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
-
-    # C's support_torch_compile should override B's ignore_torch_compile
-    with compilation_counter.expect(
-            num_graphs_seen=1,
-            num_piecewise_graphs_seen=3,
-            num_piecewise_capturable_graphs_seen=2,
-            num_backend_compilations=2,
-            num_cudagraph_captured=4,
-            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        mod_C(torch.randn(2, MLP_SIZE).cuda())
-        mod_C(torch.randn(1, MLP_SIZE).cuda())
-
-
 @torch.inference_mode
-def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor):
+def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor,
+              cudagraph_runtime_mode: CUDAGraphMode):
     with set_forward_context({}, vllm_config=vllm_config):
-        # First run is for compile
+        # warmup for the model with cudagraph_mode NONE
         model(inputs)
 
-        # Run CUDAGraph captured sizes
+    # simulate cudagraphs capturing
+    with set_forward_context({},
+                             vllm_config=vllm_config,
+                             cudagraph_runtime_mode=cudagraph_runtime_mode,
+                             batch_descriptor=BatchDescriptor(
+                                 num_tokens=2, )):
         model(inputs[:2])
+    with set_forward_context({},
+                             vllm_config=vllm_config,
+                             cudagraph_runtime_mode=cudagraph_runtime_mode,
+                             batch_descriptor=BatchDescriptor(
+                                 num_tokens=1, )):
         model(inputs[:1])
 
+    # simulate cudagraphs replay
+    with set_forward_context({},
+                             vllm_config=vllm_config,
+                             cudagraph_runtime_mode=cudagraph_runtime_mode,
+                             batch_descriptor=BatchDescriptor(
+                                 num_tokens=2, )):
         output = model(inputs[:2])
 
     output = output.cpu()
@@ -277,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         splitting_ops=["silly.attention"],
         cudagraph_capture_sizes=[1, 2],
     ))
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
     with set_current_vllm_config(vllm_config):
         model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
@@ -299,11 +229,13 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             num_cudagraph_captured=8,
             # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
-        outputs.append(run_model(vllm_config, model, inputs))
+        outputs.append(
+            run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
 
     # no compile or cudagraph
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.NO_COMPILATION, ))
+    cudagraph_runtime_mode = CUDAGraphMode.NONE
 
     with set_current_vllm_config(vllm_config):
         model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
@@ -318,7 +250,8 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             num_backend_compilations=0,
             num_cudagraph_captured=0,
     ):
-        outputs.append(run_model(vllm_config, model, inputs))
+        outputs.append(
+            run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
 
     # piecewise compile without CUDA graph
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
@@ -326,6 +259,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         use_cudagraph=False,
         splitting_ops=["silly.attention"],
     ))
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
     with set_current_vllm_config(vllm_config):
         model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
@@ -340,7 +274,8 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             num_backend_compilations=4,
             num_cudagraph_captured=0,  # no cudagraph captured
     ):
-        outputs.append(run_model(vllm_config, model, inputs))
+        outputs.append(
+            run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
 
     # Generally don't expect outputs with and without inductor
     # to be bitwise equivalent
tests/compile/test_decorator.py (new file, 251 lines)
@@ -0,0 +1,251 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import nn
from torch.library import Library

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile,
                                         support_torch_compile)
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
                         CUDAGraphMode, VllmConfig, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT")  # noqa

BATCH_SIZE = 32
MLP_SIZE = 128


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    out.copy_(q)
    out += k
    out += v


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    return


direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    target_lib=silly_lib,
)


@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module,
              cudagraph_runtime_mode: CUDAGraphMode):
    with set_forward_context({}, vllm_config=vllm_config):
        # warmup for the model with cudagraph_mode NONE
        model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())

    # simulate cudagraphs capturing
    with set_forward_context({},
                             vllm_config=vllm_config,
                             cudagraph_runtime_mode=cudagraph_runtime_mode,
                             batch_descriptor=BatchDescriptor(
                                 num_tokens=2, )):
        model(torch.randn(2, MLP_SIZE).cuda())
    with set_forward_context({},
                             vllm_config=vllm_config,
                             cudagraph_runtime_mode=cudagraph_runtime_mode,
                             batch_descriptor=BatchDescriptor(
                                 num_tokens=1, )):
        model(torch.randn(1, MLP_SIZE).cuda())

    # simulate cudagraphs replay
    with set_forward_context({},
                             vllm_config=vllm_config,
                             cudagraph_runtime_mode=cudagraph_runtime_mode,
                             batch_descriptor=BatchDescriptor(
                                 num_tokens=2, )):
        output = model(torch.randn(2, MLP_SIZE).cuda())

    output = output.cpu()
    return output.cpu()


def test_ignore_torch_compile_decorator():
    # piecewise
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        splitting_ops=["silly.attention"],
        cudagraph_capture_sizes=[1, 2],
    ))
    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE

    @support_torch_compile
    class A(nn.Module):

        def __init__(self,
                     *,
                     vllm_config: VllmConfig,
                     prefix: str = '',
                     **kwargs) -> None:
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = x + x
            attn_output = torch.empty_like(x)
            torch.ops.silly.attention(x, x, x, attn_output)
            x = attn_output
            x = x * 3
            return x

    @ignore_torch_compile
    class B(A):
        ...

    @support_torch_compile
    class C(B):
        ...

    with set_current_vllm_config(vllm_config):
        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()

    # A has support_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=3,
            num_piecewise_capturable_graphs_seen=2,
            num_backend_compilations=2,
            num_cudagraph_captured=4,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        run_model(vllm_config, mod_A, cudagraph_runtime_mode)

    with set_current_vllm_config(vllm_config):
        mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()

    # B's ignore_torch_compile should override A's support_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=0,
            num_piecewise_graphs_seen=0,
            num_piecewise_capturable_graphs_seen=0,
            num_backend_compilations=0,
            num_cudagraph_captured=0,
    ):
        run_model(vllm_config, mod_B, cudagraph_runtime_mode)

    with set_current_vllm_config(vllm_config):
        mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()

    # C's support_torch_compile should override B's ignore_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=3,
            num_piecewise_capturable_graphs_seen=2,
            num_backend_compilations=2,
            num_cudagraph_captured=4,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        run_model(vllm_config, mod_C, cudagraph_runtime_mode)


# Only enable torch.compile if
# vllm_config.cache_config.kv_sharing_fast_prefill=True
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
                       kv_sharing_fast_prefill)
class B(nn.Module):

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + x
        attn_output = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, attn_output)
        x = attn_output
        x = x + x
        return x


# Only enable torch.compile if
# vllm_config.cache_config.kv_sharing_fast_prefill=False
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
                       cache_config.kv_sharing_fast_prefill)
class A(nn.Module):

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        super().__init__()
        self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
        self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mod1(x)
        attn_output = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, attn_output)
        x = attn_output
        x = self.mod2(x)
        return x


def test_conditional_compile_enable_if():
    vllm_config = VllmConfig(cache_config=CacheConfig(
        kv_sharing_fast_prefill=True, ),
                             compilation_config=CompilationConfig(
                                 level=CompilationLevel.PIECEWISE,
                                 use_cudagraph=True,
                                 splitting_ops=["silly.attention"],
                                 cudagraph_capture_sizes=[1, 2],
                             ))
    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE

    with set_current_vllm_config(vllm_config):
        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()

    # A has support_torch_compile but enable_if fn returns False
    # enable_if will be True for B, so we expect mod1 and mod2
    # to be compiled
    with compilation_counter.expect(
            num_graphs_seen=2,
            num_piecewise_graphs_seen=6,
            # 3 piecewise graphs per instance of B()
            num_piecewise_capturable_graphs_seen=4,
            num_backend_compilations=4,
            num_cudagraph_captured=8,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        run_model(vllm_config, mod_A, cudagraph_runtime_mode)

    # Set kv_sharing_fast_prefill=False
    # which will cause A to be compiled and B to not be compiled
    vllm_config = VllmConfig(cache_config=CacheConfig(
        kv_sharing_fast_prefill=False, ),
                             compilation_config=CompilationConfig(
                                 level=CompilationLevel.PIECEWISE,
                                 use_cudagraph=True,
                                 splitting_ops=["silly.attention"],
                                 cudagraph_capture_sizes=[1, 2],
                             ))

    with set_current_vllm_config(vllm_config):
        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()

    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=7,
            # 3 attn ops and 4 non-attn ops
            num_piecewise_capturable_graphs_seen=4,
            num_backend_compilations=4,
            num_cudagraph_captured=8,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        run_model(vllm_config, mod_A, cudagraph_runtime_mode)
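
The counter expectations above follow from the graph structure: each compiled module is split into one piecewise graph per `silly.attention` op plus the surrounding capturable pieces, and each capturable piece is captured once per entry in `cudagraph_capture_sizes`. A quick sanity check of that arithmetic (plain Python, not part of the diff; the formulas restate the comments in the test):

# Sanity check of the counter arithmetic in test_conditional_compile_enable_if.
# The formulas are inferred from the comments in the test, not an official API.

# Case 1: kv_sharing_fast_prefill=True -> only the two B submodules compile.
attention_ops_per_B = 1
piecewise_graphs_per_B = 2 * attention_ops_per_B + 1  # split around each attn op -> 3
capturable_per_B = piecewise_graphs_per_B - attention_ops_per_B  # 2
num_B_instances = 2
cudagraph_sizes = 2  # cudagraph_capture_sizes=[1, 2]

assert num_B_instances * piecewise_graphs_per_B == 6  # num_piecewise_graphs_seen
assert num_B_instances * capturable_per_B == 4  # num_piecewise_capturable_graphs_seen
assert cudagraph_sizes * num_B_instances * capturable_per_B == 8  # num_cudagraph_captured

# Case 2: kv_sharing_fast_prefill=False -> only the outer A compiles, as one graph
# containing 3 attention ops (1 in A plus 1 in each inlined B).
attention_ops_in_A = 3
piecewise_graphs_in_A = 2 * attention_ops_in_A + 1  # 7: "3 attn ops and 4 non-attn ops"
capturable_in_A = piecewise_graphs_in_A - attention_ops_in_A  # 4
assert cudagraph_sizes * capturable_in_A == 8  # num_cudagraph_captured
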
@@ -52,6 +52,14 @@ def _should_ignore_torch_compile(cls) -> bool:
     return getattr(cls, IGNORE_COMPILE_KEY, False)
 
 
+@overload
+def support_torch_compile(
+    *,
+    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
+) -> Callable[[_T], _T]:
+    ...
+
+
 @overload
 def support_torch_compile(
     *,
@@ -69,6 +77,7 @@ def support_torch_compile(
     cls: Optional[_T] = None,
     *,
     dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
+    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
 ) -> Union[Callable[[_T], _T], _T]:
     """
     A decorator to add support for compiling the forward method of a class.
@@ -118,6 +127,11 @@ def support_torch_compile(
     NOTE: if an argument is `None`, it should always be passed as `None` during
     the lifetime of the model, otherwise, it cannot be captured as a single
     computation graph.
+
+    `enable_if` is a function that takes a `VllmConfig` object as input and
+    returns a boolean value indicating whether to compile the model or not.
+    This is useful if you want to compile the model only when certain
+    conditions are met.
     """
 
     def cls_decorator_helper(cls: _T) -> _T:
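
For reference, `enable_if` accepts any `Callable[[VllmConfig], bool]`, so a named predicate works the same way as the lambdas used in the tests. A short sketch with a hypothetical condition (the predicate and its condition are illustrative, not part of this PR):

# Illustrative only: enable_if accepts any Callable[[VllmConfig], bool],
# e.g. a named predicate instead of the lambdas used in the tests.
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


def compile_when_cudagraphs_enabled(vllm_config: VllmConfig) -> bool:
    # Hypothetical condition; use_cudagraph is the CompilationConfig field
    # set in the tests above.
    return bool(vllm_config.compilation_config.use_cudagraph)


# Applied the same way as in the tests:
# @support_torch_compile(enable_if=compile_when_cudagraphs_enabled)
# class MyLayer(nn.Module): ...
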
@@ -149,7 +163,8 @@ def support_torch_compile(
             if k not in sig.parameters:
                 raise ValueError(
                     f"Argument {k} not found in the forward method of {cls}")
-        return _support_torch_compile(cls, inferred_dynamic_arg_dims)
+        return _support_torch_compile(cls, inferred_dynamic_arg_dims,
+                                      enable_if)
 
     if cls is not None:
         # use `support_torch_compile` as a decorator without arguments
@@ -162,6 +177,7 @@ def support_torch_compile(
 def _support_torch_compile(
     cls: _T,
     dynamic_arg_dims: dict[str, Union[int, list[int]]],
+    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
 ) -> _T:
     """
     A decorator to add support for compiling the forward method of a class.
@@ -182,13 +198,14 @@ def _support_torch_compile(
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
         old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
         self.vllm_config = vllm_config
+        enable_compile = enable_if is None or enable_if(vllm_config)
         # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
         # will handle the compilation, so we don't need to do anything here.
         self.do_not_compile = \
             vllm_config.compilation_config.level in [
             CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
         ] or not supports_dynamo() or _should_ignore_torch_compile(
-            self.__class__)
+            self.__class__) or not enable_compile
         if self.do_not_compile:
             return
 
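
Net effect of the decorator change: compilation is skipped when any of the existing opt-outs fire (compilation level, dynamo support, `@ignore_torch_compile`) or when `enable_if` returns False. A condensed restatement of that decision as a standalone helper (a sketch only; the real logic lives inline in `_support_torch_compile`'s patched `__init__`):

# Condensed sketch of the do_not_compile decision after this PR.
# This is illustrative pseudocode, not the actual vLLM implementation; the
# dynamo_supported and class_marked_ignore flags stand in for the
# supports_dynamo() and _should_ignore_torch_compile(cls) helpers.
from typing import Callable, Optional

from vllm.config import CompilationLevel, VllmConfig


def should_skip_compile(
        vllm_config: VllmConfig,
        dynamo_supported: bool,
        class_marked_ignore: bool,
        enable_if: Optional[Callable[[VllmConfig], bool]] = None) -> bool:
    # enable_if defaults to "always compile" when not provided
    enable_compile = enable_if is None or enable_if(vllm_config)
    return (vllm_config.compilation_config.level
            in (CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS)
            or not dynamo_supported
            or class_marked_ignore
            or not enable_compile)
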