Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 22:55:57 +08:00)
commit 2dcd12d357 (parent 579d2e5458)

[torch.compile] Fix tests for torch==2.9 inductor partition (#26116)

Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
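Before the diff: the change threads a single boolean, use_inductor_graph_partition, through the compile tests and gates the new cases on the PyTorch version. A minimal sketch of that pattern, assuming a standalone pytest module (the test name is illustrative; the imports and config fields are the ones used in the hunks below):

import pytest

from vllm.config import CompilationConfig, CompilationLevel
from vllm.utils import is_torch_equal_or_newer


@pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
def test_partition_flag(use_inductor_graph_partition: bool):
    # Inductor-level graph partitioning only exists in torch >= 2.9;
    # older torch versions skip the new parametrization instead of failing.
    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition only supported in torch>=2.9")

    # The flag rides along on CompilationConfig, next to the existing
    # piecewise-compilation knobs exercised by the tests below.
    config = CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        splitting_ops=["silly::attention"],
        cudagraph_capture_sizes=[1, 2],
        use_inductor_graph_partition=use_inductor_graph_partition,
    )
    assert config.use_inductor_graph_partition == use_inductor_graph_partition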
@@ -11,6 +11,7 @@ from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
 from vllm.platforms import current_platform
+from vllm.utils import is_torch_equal_or_newer


 @contextlib.contextmanager
@@ -32,13 +33,13 @@ def temporary_environ(env_vars):
                 os.environ[k] = v


-test_params_full_cudagraph = []
+model_backends_full_cudagraph = []

 # deepseek-ai/DeepSeek-V2-Lite with MLA
 MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
 for mla_backend in MLA_backends:
-    test_params_full_cudagraph.append(
-        pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))
+    model_backends_full_cudagraph.append(
+        ("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])
     )

 # Qwen/Qwen2-1.5B-Instruct with other backends
@@ -46,14 +47,18 @@ other_backend_configs = [
     backend_configs[c] for c in backend_configs if c not in MLA_backends
 ]
 for backend_config in other_backend_configs:
-    test_params_full_cudagraph.append(
-        pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))
-    )
+    model_backends_full_cudagraph.append(("Qwen/Qwen2-1.5B-Instruct", backend_config))


 @pytest.fixture(scope="class")
 def llm_pair(request):
-    model, backend_config = request.param
+    model, backend_config, use_inductor_graph_partition = request.param
+    backend_config.comp_config["use_inductor_graph_partition"] = (
+        use_inductor_graph_partition
+    )
+
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition only supported in torch>=2.9")

     # Dynamically skip test if GPU capability is not met
     if (
@@ -104,7 +109,15 @@ def llm_pair(request):
     )


-@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
+@pytest.mark.parametrize(
+    "llm_pair",
+    [
+        pytest.param((model, backend_config, use_inductor_graph_partition))
+        for model, backend_config in model_backends_full_cudagraph
+        for use_inductor_graph_partition in [True, False]
+    ],
+    indirect=True,
+)
 class TestFullCUDAGraph:
     """
     Use a class such that an llm pair is constructed once for all
@@ -5,6 +5,7 @@ Test (piecewise) compilation with a simple model where multiple submodules
 are compiled and graph captured separately.
 """

+import pytest
 import torch
 from torch import nn

@@ -190,7 +191,12 @@ def run_model(
     return output.cpu()


-def test_multi_graph_piecewise_compile_outputs_equal():
+@pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
+def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
+    if use_inductor_graph_partition:
+        # FIXME(luka/boyuan): this currently fails
+        pytest.skip("Inductor graph partition not supported with multi-graph")
+
     outputs = []

     # piecewise compile
@@ -200,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
+            use_inductor_graph_partition=use_inductor_graph_partition,
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -220,16 +227,24 @@ def test_multi_graph_piecewise_compile_outputs_equal():
     # static tensor addresses
     inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()

-    with compilation_counter.expect(
-        num_graphs_seen=2,  # two graphs for the model
-        num_piecewise_graphs_seen=6,
+    if use_inductor_graph_partition:
+        # Splitting happens at Inductor lowering level,
+        # total piecewise fx graphs is equal to total graphs
+        num_piecewise_fx = 2
+        num_piecewise_capturable_fx = 2
+    else:
         # attn_one, attn_two each has 3 piecewise graphs
         # (pre attn, post attn, silly_attention) each
-        num_piecewise_capturable_graphs_seen=4,
+        num_piecewise_fx = 6
         # attn_one, attn_two has pre attn and post attn each, total=4
-        num_backend_compilations=4,  # num_piecewise_capturable_graphs_seen
-        num_cudagraph_captured=8,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_piecewise_capturable_fx = 4
+
+    with compilation_counter.expect(
+        num_graphs_seen=2,  # two graphs for the model
+        num_piecewise_graphs_seen=num_piecewise_fx,
+        num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
+        num_backend_compilations=num_piecewise_capturable_fx,
+        num_cudagraph_captured=8,  # num_cudagraph_sizes * num_partitions
     ):
         outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))

@@ -268,6 +283,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=False,
             splitting_ops=["silly::attention"],
+            use_inductor_graph_partition=use_inductor_graph_partition,
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -286,9 +302,9 @@ def test_multi_graph_piecewise_compile_outputs_equal():

     with compilation_counter.expect(
         num_graphs_seen=2,
-        num_piecewise_graphs_seen=6,
-        num_piecewise_capturable_graphs_seen=4,
-        num_backend_compilations=4,
+        num_piecewise_graphs_seen=num_piecewise_fx,
+        num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
+        num_backend_compilations=num_piecewise_capturable_fx,
         num_cudagraph_captured=0,  # no cudagraph captured
     ):
         outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
@@ -9,6 +9,7 @@ if the config `tractable_init` is set to True. Otherwise, the weights are
 initialized randomly with a fixed seed.
 """

+from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any

@@ -26,6 +27,7 @@ from vllm.config import (
     set_current_vllm_config,
 )
 from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.utils import is_torch_equal_or_newer

 # This import automatically registers `torch.ops.silly.attention`
 from .. import silly_attention  # noqa: F401
@@ -257,27 +259,13 @@ def tractable_computation(


 @torch.inference_mode
-def run_model(
-    llama_config, use_compile: bool, backend: str, split_attn: bool = False
-) -> torch.Tensor:
-    if use_compile:
-        compilation_config = CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
-            use_cudagraph=True,
-            backend=backend,
-            cudagraph_capture_sizes=[1, 2],
-        )
-        if split_attn:
-            compilation_config.splitting_ops = ["silly::attention"]
-        cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
-    else:
-        compilation_config = CompilationConfig(
-            level=CompilationLevel.NO_COMPILATION,
-        )
-        cudagraph_runtime_mode = CUDAGraphMode.NONE
+def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
+    # Start with a fresh copy to make sure there's no cache dir sharing
+    compile_config = deepcopy(compile_config)
+    cudagraph_runtime_mode = compile_config.cudagraph_mode

     vllm_config = VllmConfig(
-        compilation_config=compilation_config, additional_config=llama_config
+        compilation_config=compile_config, additional_config=llama_config
     )
     with set_current_vllm_config(vllm_config):
         model = (
@@ -338,8 +326,25 @@ def run_model(
     return output.cpu()


-@pytest.mark.parametrize("backend", ["inductor", "eager"])
-def test_toy_llama(backend: str):
+@pytest.mark.parametrize(
+    "backend, use_inductor_graph_partition",
+    [
+        ("eager", False),  # No inductor
+        ("inductor", False),  # Inductor, Dynamo partition
+        ("inductor", True),  # Inductor, Inductor partition
+    ],
+)
+def test_toy_llama(
+    backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
+):
+    # We disable the vLLM compile cache into a new tmp dir for 2 reasons:
+    # 1. To make sure we can properly track the number of Inductor compilations.
+    # 2. Inductor partitioning does not play nicely with Autograd cache (below)
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition only supported in torch>=2.9")
+
     # compare output with and without piecewise compilation

     llama_config = LlamaConfig(
@@ -350,6 +355,32 @@ def test_toy_llama(backend: str):
         hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
     )

+    compile_config_no_compile = CompilationConfig(
+        level=CompilationLevel.NO_COMPILATION,
+        cudagraph_mode=CUDAGraphMode.NONE,
+        backend="eager",
+    )
+
+    compile_config_no_split = CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        use_inductor_graph_partition=use_inductor_graph_partition,
+        cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        backend=backend,
+        cudagraph_capture_sizes=[1, 2],
+    )
+
+    # FIXME(luka/boyuan): the graph from the previous test case
+    # (no inductor partition) gets cached by AotAutograd so then the
+    # compilation with inductor partitioning incorrectly loads an unpartitioned
+    # graph and never partitions. I think this is a bug with custom inductor
+    # partitioning but does not affect vLLM more generally as vLLM uses its own
+    # cache (which takes inductor partitioning into account).
+    if use_inductor_graph_partition:
+        compile_config_no_split.inductor_compile_config["force_disable_caches"] = True
+
+    compile_config_split = deepcopy(compile_config_no_split)
+    compile_config_split.splitting_ops = ["silly::attention"]
+
     outputs = []
     with compilation_counter.expect(
         num_graphs_seen=0,
@@ -358,8 +389,9 @@ def test_toy_llama(backend: str):
         num_backend_compilations=0,
         num_cudagraph_captured=0,
     ):
-        outputs.append(run_model(llama_config, backend="eager", use_compile=False))
-    run_model(tractable_config, backend="eager", use_compile=False)
+        outputs.append(run_model(llama_config, compile_config_no_compile))
+
+    run_model(tractable_config, compile_config_no_compile)

     if backend == "inductor":
         kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
@@ -367,35 +399,34 @@ def test_toy_llama(backend: str):
         kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}

     with compilation_counter.expect(
-        # One graph for the model
-        num_graphs_seen=1,
+        num_graphs_seen=1,  # one graph for the model
         num_piecewise_graphs_seen=1,
         num_piecewise_capturable_graphs_seen=1,
-        # num_piecewise_capturable_graphs_seen
-        num_backend_compilations=1,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_backend_compilations=1,  # num_piecewise_capturable_graphs_seen
         num_cudagraph_captured=2,
         **kwargs,
     ):
-        outputs.append(run_model(llama_config, backend=backend, use_compile=True))
-    run_model(tractable_config, backend=backend, use_compile=True)
+        outputs.append(run_model(llama_config, compile_config_no_split))
+
+    run_model(tractable_config, compile_config_no_split)

+    if use_inductor_graph_partition:
+        num_piecewise_fx = 1
+        num_piecewise_capturable_fx = 1
+    else:
+        num_piecewise_fx = 2 * llama_config.num_layers + 1
+        num_piecewise_capturable_fx = 1 + llama_config.num_layers
+
     with compilation_counter.expect(
         num_graphs_seen=1,  # one graph for the model
-        num_piecewise_graphs_seen=2 * llama_config.num_layers + 1,  # 2 * num_layers + 1
-        num_piecewise_capturable_graphs_seen=1
-        + llama_config.num_layers,  # 1 + num_layers
-        num_backend_compilations=1
-        + llama_config.num_layers,  # num_piecewise_capturable_graphs_seen
-        num_cudagraph_captured=2
-        * (
-            1 + llama_config.num_layers
-        ),  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_piecewise_graphs_seen=num_piecewise_fx,
+        num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
+        num_backend_compilations=num_piecewise_capturable_fx,
+        # num_cudagraph_sizes * num_partitions
+        num_cudagraph_captured=2 * (1 + llama_config.num_layers),
     ):
-        outputs.append(
-            run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
-        )
-    run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
+        outputs.append(run_model(llama_config, compile_config_split))
+    run_model(tractable_config, compile_config_split)

     for i in range(1, len(outputs)):
         assert torch.allclose(outputs[0], outputs[i])
@@ -62,5 +62,4 @@ direct_register_custom_op(
     mutates_args=["out"],
     fake_impl=silly_attention_fake,
     target_lib=silly_lib,
-    tags=(torch._C.Tag.cudagraph_unsafe,),
 )
@@ -73,6 +73,7 @@ def test_ignore_torch_compile_decorator():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
+            use_inductor_graph_partition=False,  # TODO test both?
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -188,6 +189,7 @@ def test_conditional_compile_enable_if():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
+            use_inductor_graph_partition=False,  # TODO test both
         ),
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -220,6 +222,7 @@ def test_conditional_compile_enable_if():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
+            use_inductor_graph_partition=False,  # TODO test both?
         ),
     )

@@ -38,10 +38,6 @@ from vllm.utils import GiB_bytes, direct_register_custom_op

 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None
-try:
-    tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe,)
-except AttributeError:
-    tag_cudagraph_unsafe = ()  # type: ignore[assignment]


 def check_xformers_availability():
@@ -879,7 +875,6 @@ direct_register_custom_op(
     op_name="unified_attention",
     op_func=unified_attention,
     fake_impl=unified_attention_fake,
-    tags=tag_cudagraph_unsafe,
 )


@@ -931,7 +926,6 @@ direct_register_custom_op(
     op_func=unified_attention_with_output,
     mutates_args=["output", "output_block_scale"],
     fake_impl=unified_attention_with_output_fake,
-    tags=tag_cudagraph_unsafe,
 )


@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import contextlib
+import logging
 from typing import TYPE_CHECKING

 from torch._library.utils import lookup_op
@@ -38,8 +39,16 @@ def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
             resolved.append(lookup_op(op_name))
         except Exception:
             # Skip operators that don't exist (e.g., model-specific ops)
-            logger.warning(
-                "Failed to resolve operator for Inductor partition: %s", op_name
+            # Do not warn for attention ops, warn for others
+            # (most likely manually specified)
+            from vllm.config import CompilationConfig
+
+            logger.log(
+                logging.DEBUG
+                if op_name in CompilationConfig._attention_ops
+                else logging.WARNING,
+                "Failed to resolve operator for CUDAGraph partition: %s",
+                op_name,
             )
             continue

@@ -201,7 +201,7 @@ class CompilationConfig:
     (it sees a part of the graph). The backend can not be custom for compilation
     level 3, i.e. the backend must be either eager or inductor. Furthermore,
     compilation is only piecewise if splitting ops is set accordingly and
-    use_inductor_cudagraphs_partition is off. Note that the default options for
+    use_inductor_graph_partition is off. Note that the default options for
     splitting ops are sufficient for piecewise compilation.
     """
     custom_ops: list[str] = field(default_factory=list)
@@ -431,6 +431,7 @@ class CompilationConfig:
         factors.append(self.custom_ops)
         factors.append(self.splitting_ops)
         factors.append(self.use_inductor)
+        factors.append(self.use_inductor_graph_partition)
         factors.append(self.inductor_compile_config)
         factors.append(self.inductor_passes)
         factors.append(self.pass_config.uuid())
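A back-of-the-envelope check of the counter arithmetic the updated toy-llama test encodes (illustrative only, not part of the commit; it assumes num_layers=2 and cudagraph_capture_sizes=[1, 2] as in the test):

# Dynamo-level splitting: each silly::attention call splits the FX graph,
# so 2 * num_layers + 1 piecewise graphs are seen and 1 + num_layers of
# them are cudagraph-capturable. With Inductor-level partition the FX
# graph stays whole, so both counts collapse to 1.
num_layers = 2
num_cudagraph_sizes = 2  # cudagraph_capture_sizes=[1, 2]

num_piecewise_fx_dynamo = 2 * num_layers + 1   # 5
num_capturable_fx_dynamo = 1 + num_layers      # 3
num_piecewise_fx_inductor = 1
num_capturable_fx_inductor = 1

# CUDA graphs captured are the same in both modes:
# num_cudagraph_sizes * num_partitions
num_cudagraph_captured = num_cudagraph_sizes * (1 + num_layers)
assert num_cudagraph_captured == 6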