[V1] address post issues related to #20059 (part 1) (#23046)

Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
fhl2000 2025-09-27 03:58:19 +08:00 committed by GitHub
parent f708bd4904
commit f075693da7
13 changed files with 348 additions and 292 deletions

View File

@@ -3,12 +3,11 @@
import contextlib
import os
import weakref
-from dataclasses import dataclass
-from typing import Optional

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
+from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
@@ -33,89 +32,6 @@ def temporary_environ(env_vars):
            os.environ[k] = v


-@dataclass
-class BackendConfig:
-    name: str
-    env_vars: dict
-    comp_config: dict
-    specific_gpu_arch: Optional[tuple] = None
-
-
-# Define all backend configurations of full cudagraph to be tested
-backend_configs = {
-    # FA3 on Hopper
-    "FA3":
-    BackendConfig(name="FA3",
-                  env_vars={
-                      "VLLM_FLASH_ATTN_VERSION": "3",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FlashMLA on Hopper
-    "FlashMLA":
-    BackendConfig(name="FlashMLA",
-                  env_vars={
-                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FlashAttention MLA on Hopper
-    "FlashAttentionMLA":
-    BackendConfig(name="FlashAttentionMLA",
-                  env_vars={
-                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_DECODE_ONLY",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # Cutlass MLA on Blackwell
-    "CutlassMLA":
-    BackendConfig(
-        name="CutlassMLA",
-        env_vars={
-            "VLLM_USE_V1": "1",
-            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
-            "FORCE_NUM_KV_SPLITS":
-            "1",  # TODO: remove this when hang issue is fixed
-        },
-        comp_config={
-            "cudagraph_mode": "FULL_AND_PIECEWISE",
-            "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
-        },
-        specific_gpu_arch=(10, 0)),
-    # FA2
-    "FA2":
-    BackendConfig(name="FA2",
-                  env_vars={
-                      "VLLM_FLASH_ATTN_VERSION": "2",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  }),
-    # Triton Attention
-    "TritonAttn":
-    BackendConfig(name="TritonAttn",
-                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  }),
-    # FlashInfer
-    "FlashInfer":
-    BackendConfig(name="FlashInfer",
-                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  }),
-}
-
-
test_params_full_cudagraph = []

# deepseek-ai/DeepSeek-V2-Lite with MLA

View File

@@ -4,7 +4,7 @@ import pytest
import vllm
from vllm.compilation.counter import compilation_counter
-from vllm.config import CompilationConfig, VllmConfig
+from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.utils import _is_torch_equal_or_newer
@@ -106,7 +106,6 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
def test_no_compilation(vllm_runner, monkeypatch):
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
-
    with (
            compilation_counter.expect(num_graphs_seen=0,
                                       dynamo_as_is_count=0),
@@ -131,3 +130,67 @@ def test_enforce_eager(vllm_runner, monkeypatch):
                     enforce_eager=True,
                     gpu_memory_utilization=0.4) as _):
        pass
+
+
+def test_splitting_ops_dynamic():
+    # Default config
+    config = VllmConfig()
+    assert config.compilation_config.cudagraph_mode == \
+        CUDAGraphMode.FULL_AND_PIECEWISE
+    assert config.compilation_config.splitting_ops_contain_attention()
+
+    # When use_inductor_graph_partition=True
+    if _is_torch_equal_or_newer('2.9.0.dev'):
+        # inductor graph partition is only available in PyTorch 2.9+.
+        # this is a fast config check so we are not using pytest.skip.
+        config = VllmConfig(compilation_config=CompilationConfig(
+            use_inductor_graph_partition=True,
+            splitting_ops=["silly_attention"]))
+        # should ignore splitting_ops
+        assert config.compilation_config.splitting_ops == []
+
+    # When the attn_fusion pass is enabled.
+    config = VllmConfig(compilation_config=CompilationConfig(
+        pass_config={
+            "enable_attn_fusion": True,
+            "enable_noop": True
+        },
+        custom_ops=["+quant_fp8"],
+        cudagraph_mode=CUDAGraphMode.PIECEWISE,
+    ))
+    assert config.compilation_config.splitting_ops == []
+    # cudagraph mode also falls back to FULL
+    assert config.compilation_config.cudagraph_mode == \
+        CUDAGraphMode.FULL
+
+    # splitting_ops cannot contain attention ops when the attn_fusion
+    # pass is enabled.
+    with pytest.raises(AssertionError):
+        config = VllmConfig(compilation_config=CompilationConfig(
+            pass_config={
+                "enable_attn_fusion": True,
+                "enable_noop": True
+            },
+            custom_ops=["+quant_fp8"],
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+            # workaround for accessing all attention ops
+            splitting_ops=CompilationConfig()._attention_ops,
+        ))
+
+    # When both use_inductor_graph_partition and the attn_fusion pass
+    # are enabled.
+    if _is_torch_equal_or_newer('2.9.0.dev'):
+        config = VllmConfig(compilation_config=CompilationConfig(
+            use_inductor_graph_partition=True,
+            pass_config={
+                "enable_attn_fusion": True,
+                "enable_noop": True
+            },
+            custom_ops=["+quant_fp8"],
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        ))
+        assert config.compilation_config.splitting_ops == []
+        # enable_attn_fusion is directly supported under
+        # use_inductor_graph_partition=True, and cudagraph_mode
+        # is unchanged.
+        assert config.compilation_config.cudagraph_mode == \
+            CUDAGraphMode.PIECEWISE

View File

@@ -3,7 +3,7 @@
"""Utility functions for attention-related v1 tests."""

from dataclasses import dataclass
-from typing import Union
+from typing import Optional, Union

import pytest
import torch
@@ -260,3 +260,88 @@ def create_dummy_kv_cache(block_size: int,
                          dtype=dtype,
                          device=device)
    return kv_cache
+
+
+@dataclass
+class BackendConfig:
+    name: str
+    env_vars: dict
+    comp_config: dict  # compilation config
+    specific_gpu_arch: Optional[tuple] = None
+
+
+# Define all backend configurations of full cudagraph to be tested
+full_cg_backend_configs = {
+    # FA3 on Hopper
+    "FA3":
+    BackendConfig(name="FA3",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
+                      "VLLM_FLASH_ATTN_VERSION": "3",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # FlashMLA on Hopper
+    "FlashMLA":
+    BackendConfig(name="FlashMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # Cutlass MLA on Blackwell
+    "CutlassMLA":
+    BackendConfig(
+        name="CutlassMLA",
+        env_vars={
+            "VLLM_USE_V1": "1",
+            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
+            "FORCE_NUM_KV_SPLITS":
+            "1",  # TODO: remove this when hang issue is fixed
+        },
+        comp_config={
+            "cudagraph_mode": "FULL_AND_PIECEWISE",
+        },
+        specific_gpu_arch=(10, 0)),
+    # FlashAttention MLA on Hopper
+    "FlashAttentionMLA":
+    BackendConfig(name="FlashAttentionMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_DECODE_ONLY",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # FA2
+    "FA2":
+    BackendConfig(name="FA2",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
+                      "VLLM_FLASH_ATTN_VERSION": "2",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+    # Triton Attention
+    "TritonAttn":
+    BackendConfig(name="TritonAttn",
+                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+    # FlashInfer
+    "FlashInfer":
+    BackendConfig(name="FlashInfer",
+                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+}
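The dictionary above is now the single shared definition of the full-cudagraph test matrix. A minimal sketch of how a test module is expected to consume it; the parameter-building helper below is illustrative only and not part of this diff:

import pytest

from tests.v1.attention.utils import full_cg_backend_configs as backend_configs

# Illustrative only: turn the shared dict into pytest params keyed by backend
# name; arch-specific entries (specific_gpu_arch) can additionally be skipped
# by comparing against the current device capability before building the LLM.
test_params = [
    pytest.param(cfg, id=name) for name, cfg in backend_configs.items()
]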

View File

@@ -45,38 +45,21 @@ def _create_vllm_config(compilation_config: CompilationConfig,
class TestCudagraphDispatcher:

    @pytest.mark.parametrize(
-        "params",
+        "case_id,cudagraph_mode_str,compilation_level",
        [
            # Test case 0: Full CG for mixed batches, no separate routine
-            {
-                "case_id": 0,
-                "cudagraph_mode": "FULL",
-                "compilation_level": CompilationLevel.NO_COMPILATION,
-            },
+            (0, "FULL", CompilationLevel.NO_COMPILATION),
            # Test case 1: Full CG for uniform batches, piecewise for mixed
-            {
-                "case_id": 1,
-                "cudagraph_mode": "FULL_AND_PIECEWISE",
-                "compilation_level": CompilationLevel.PIECEWISE,
-            },
+            (1, "FULL_AND_PIECEWISE", CompilationLevel.NO_COMPILATION),
            # Test case 2: Full CG for uniform batches, no CG for mixed
-            {
-                "case_id": 2,
-                "cudagraph_mode": "FULL_DECODE_ONLY",
-                "compilation_level": CompilationLevel.NO_COMPILATION,
-            },
+            (2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION),
            # Test case 3: Piecewise for all
-            {
-                "case_id": 3,
-                "cudagraph_mode": "PIECEWISE",
-                "compilation_level": CompilationLevel.PIECEWISE,
-            },
+            (3, "PIECEWISE", CompilationLevel.PIECEWISE),
        ])
-    def test_dispatcher(self, params):
+    def test_dispatcher(self, case_id, cudagraph_mode_str, compilation_level):
        # Setup dispatcher
-        comp_config = CompilationConfig(
-            cudagraph_mode=params["cudagraph_mode"],
-            level=params["compilation_level"],
-            cudagraph_capture_sizes=[1, 8])
+        comp_config = CompilationConfig(cudagraph_mode=cudagraph_mode_str,
+                                        level=compilation_level,
+                                        cudagraph_capture_sizes=[1, 8])
        config = _create_vllm_config(comp_config, max_num_seqs=8)
@@ -86,11 +69,11 @@ class TestCudagraphDispatcher:
            uniform_decode_query_len=1)

        # Verify the key is initialized correctly
-        if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
+        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
-        if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]:
+        if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
@@ -99,10 +82,10 @@ class TestCudagraphDispatcher:
        # 1. non-uniform batch, size in cudagraph size list
        desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
        rt_mode, key = dispatcher.dispatch(desc_full_exact)
-        if params["cudagraph_mode"] == "FULL":
+        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_full_exact
-        elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
+        elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact
        else:
@@ -111,15 +94,13 @@ class TestCudagraphDispatcher:
        # 2. uniform decode batch, size in cudagraph size list
        desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
        rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
-        if params["cudagraph_mode"] == "FULL":
+        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact.non_uniform
-        elif params["cudagraph_mode"] in [
-                "FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"
-        ]:
+        elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact
-        elif params["cudagraph_mode"] == "PIECEWISE":
+        elif cudagraph_mode_str == "PIECEWISE":
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_uniform_exact.non_uniform
        else:
@@ -131,6 +112,16 @@ class TestCudagraphDispatcher:
            assert rt_mode == CUDAGraphMode.NONE
            assert key is None

+        # 4. Cascade attention should have a fall back mode
+        desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
+        rt_mode, key = dispatcher.dispatch(desc_full_exact,
+                                           use_cascade_attn=True)
+        if "PIECEWISE" in cudagraph_mode_str:  # string contains check
+            assert rt_mode == CUDAGraphMode.PIECEWISE
+            assert key == desc_full_exact.non_uniform
+        else:
+            assert rt_mode == CUDAGraphMode.NONE
+

@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:

View File

@@ -4,12 +4,11 @@ import contextlib
import os
import weakref
from contextlib import ExitStack
-from dataclasses import dataclass
-from typing import Optional

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
+from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
@@ -34,74 +33,6 @@ def temporary_environ(env_vars):
            os.environ[k] = v


-@dataclass
-class BackendConfig:
-    name: str
-    env_vars: dict
-    comp_config: dict
-    specific_gpu_arch: Optional[tuple] = None
-
-
-# Define all backend configurations of full cudagraph to be tested
-backend_configs = {
-    # FA3 on Hopper
-    "FA3":
-    BackendConfig(name="FA3",
-                  env_vars={
-                      "VLLM_FLASH_ATTN_VERSION": "3",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FlashMLA on Hopper
-    "FlashMLA":
-    BackendConfig(name="FlashMLA",
-                  env_vars={
-                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FlashAttention MLA on Hopper
-    "FlashAttentionMLA":
-    BackendConfig(name="FlashAttentionMLA",
-                  env_vars={
-                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_DECODE_ONLY",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FA2
-    "FA2":
-    BackendConfig(name="FA2",
-                  env_vars={
-                      "VLLM_FLASH_ATTN_VERSION": "2",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  }),
-    # Triton Attention
-    "TritonAttn":
-    BackendConfig(name="TritonAttn",
-                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  }),
-    # FlashInfer
-    "FlashInfer":
-    BackendConfig(name="FlashInfer",
-                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  }),
-}
-
-
# test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported)
combo_cases_1 = [
@@ -114,9 +45,10 @@ combo_cases_1 = [
]


-@pytest.mark.parametrize("combo_case", combo_cases_1)
-def test_backend_and_cudagraph_mode_combo(combo_case):
-    backend_name, cudagraph_mode, supported = combo_case
+@pytest.mark.parametrize("backend_name, cudagraph_mode, supported",
+                         combo_cases_1)
+def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode,
+                                          supported):
    if backend_name == "FlashInfer":
        try:
            import flashinfer  # noqa: F401
@@ -142,7 +74,7 @@ def test_backend_and_cudagraph_mode_combo(combo_case):
        compilation_config=CompilationConfig(
            level=3, cudagraph_mode=cudagraph_mode))
    llm.generate(["Hello, my name is"] * 10)
+    # when above code raises, `llm` may be undefined, so we need to catch that
    try:
        llm = weakref.proxy(llm)
        del llm
@@ -173,7 +105,8 @@ combo_cases_2 = [
]


-@pytest.mark.parametrize("combo_case", combo_cases_2)
+@pytest.mark.parametrize("backend_name,cudagraph_mode,compilation_level,"\
+                         "supported", combo_cases_2)
def test_cudagraph_compilation_combo(combo_case):
    backend_name, cudagraph_mode, compilation_level, supported\
        = combo_case
@@ -192,6 +125,7 @@ def test_cudagraph_compilation_combo(combo_case):
        compilation_config=CompilationConfig(
            level=compilation_level, cudagraph_mode=cudagraph_mode))
    llm.generate(["Hello, my name is"] * 10)
+    # when above code raises, `llm` may be undefined, so we need to catch that
    try:
        llm = weakref.proxy(llm)
        del llm

View File

@@ -340,15 +340,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
                num_graphs=len(self.compile_submod_names),
                runtime_shape=None)
            # Lazy import here to avoid circular import
-            from .cuda_piecewise_backend import PiecewiseBackend
+            from .piecewise_backend import PiecewiseBackend

            piecewise_backend = PiecewiseBackend(
                submod, self.vllm_config, index,
                len(self.compile_submod_names), sym_shape_indices,
                compiled_graph_for_dynamic_shape, self.vllm_backend)

-            if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-                    and
+            if (self.compilation_config.cudagraph_mode.\
+                    has_piecewise_cudagraphs() and
                    not self.compilation_config.use_inductor_graph_partition):
                # We're using Dynamo-based piecewise splitting, so we wrap
                # the whole subgraph with a static graph wrapper.

View File

@@ -336,7 +336,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
    from vllm.config import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
-    if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+    if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
            and compilation_config.use_inductor_graph_partition):
        from torch._inductor.utils import CUDAGraphWrapperMetadata
@@ -365,7 +365,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
    yield

-    if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+    if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
            and compilation_config.use_inductor_graph_partition):
        torch._inductor.utils.set_customized_partition_wrappers(None)

View File

@@ -459,15 +459,22 @@ class VllmConfig:
                    "to True to enable.")
        current_platform.check_and_update_config(self)

-        # final check of cudagraph mode after platform-specific update
+        # Do this after all the updates to compilation_config.level
+        if envs.VLLM_USE_V1 and \
+            self.compilation_config.level == CompilationLevel.PIECEWISE:
+            self.compilation_config.set_splitting_ops_for_v1()
+
+        # final check of cudagraph mode after all possible updates
        if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
-            if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
+            if self.compilation_config.cudagraph_mode.has_full_cudagraphs()\
                and self.model_config is not None and \
-                not self.model_config.disable_cascade_attn:
-                logger.info("CUDAGraphMode.FULL is not supported with "
-                            "cascade attention currently. Disabling cascade"
-                            "attention.")
-                self.model_config.disable_cascade_attn = True
+                not self.model_config.disable_cascade_attn and\
+                not self.compilation_config.cudagraph_mode.\
+                has_piecewise_cudagraphs():
+                logger.warning_once(
+                    "No piecewise cudagraph for executing cascade attention."
+                    " Will fall back to eager execution if a batch runs "
+                    "into cascade attention.")

            if self.compilation_config.cudagraph_mode\
                .requires_piecewise_compilation():
@@ -477,6 +484,12 @@ class VllmConfig:
                    "when cudagraph_mode piecewise cudagraphs is used, "\
                    f"cudagraph_mode={self.compilation_config.cudagraph_mode}"

+            # finally, migrate the deprecated flags
+            self.compilation_config.use_cudagraph = self.compilation_config.\
+                cudagraph_mode != CUDAGraphMode.NONE
+            self.compilation_config.full_cuda_graph = self.compilation_config.\
+                cudagraph_mode.has_full_cudagraphs()
+
        if self.parallel_config.enable_dbo:
            a2a_backend = envs.VLLM_ALL2ALL_BACKEND
            assert a2a_backend in \
@@ -487,14 +500,14 @@ class VllmConfig:
                "variable to deepep_low_latency or deepep_high_throughput and "\
                "install the DeepEP kernels."

+            if not self.model_config.disable_cascade_attn:
+                self.model_config.disable_cascade_attn = True
+                logger.warning_once(
+                    "Disabling cascade attention when DBO is enabled.")
+
        if not self.instance_id:
            self.instance_id = random_uuid()[:5]

-        # Do this after all the updates to compilation_config.level
-        if envs.VLLM_USE_V1 and \
-            self.compilation_config.level == CompilationLevel.PIECEWISE:
-            self.compilation_config.set_splitting_ops_for_v1()
-
        if (envs.VLLM_USE_V1
                and not self.scheduler_config.disable_hybrid_kv_cache_manager):
            # logger should only print warning message for hybrid models. As we

View File

@@ -61,9 +61,17 @@ class CUDAGraphMode(enum.Enum):
    def has_full_cudagraphs(self) -> bool:
        return self.max_cudagraph_mode() == CUDAGraphMode.FULL

+    def has_piecewise_cudagraphs(self) -> bool:
+        return self.requires_piecewise_compilation()
+
    def separate_routine(self) -> bool:
        return isinstance(self.value, tuple)

+    def valid_runtime_modes(self) -> bool:
+        return self in [
+            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
+        ]
+

@config
@dataclass
@@ -269,7 +277,8 @@ class CompilationConfig:
    Note that this is orthogonal to the cudagraph capture logic
    outside of compilation.
    Warning: This flag is deprecated and will be removed in the next major or
-    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
+    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE
+    instead.
    """
    cudagraph_num_of_warmups: int = 0
    """Number of warmup runs for cudagraph.
@@ -294,7 +303,8 @@ class CompilationConfig:
    flag cannot be used together with splitting_ops. This may provide
    performance benefits for smaller models.
    Warning: This flag is deprecated and will be removed in the next major or
-    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
+    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
+    FULL_AND_PIECEWISE instead.
    """

    use_inductor_graph_partition: bool = False
@@ -464,7 +474,8 @@ class CompilationConfig:
        if not self.use_cudagraph:
            logger.warning("use_cudagraph is deprecated, use "
                           "cudagraph_mode=NONE instead.")
-            if self.cudagraph_mode is not None:
+            if self.cudagraph_mode is not None and \
+                self.cudagraph_mode != CUDAGraphMode.NONE:
                raise ValueError(
                    "use_cudagraph and cudagraph_mode are mutually"
                    " exclusive, prefer cudagraph_mode since "
@@ -473,7 +484,8 @@ class CompilationConfig:
        if self.full_cuda_graph:
            logger.warning("full_cuda_graph is deprecated, use "
                           "cudagraph_mode=FULL instead.")
-            if self.cudagraph_mode is not None:
+            if self.cudagraph_mode is not None and \
+                not self.cudagraph_mode.has_full_cudagraphs():
                raise ValueError("full_cuda_graph and cudagraph_mode are "
                                 "mutually exclusive, prefer cudagraph_mode "
                                 "since full_cuda_graph is deprecated.")
@@ -570,49 +582,76 @@ class CompilationConfig:
            "set_splitting_ops_for_v1 should only be called when "
            "level is CompilationLevel.PIECEWISE")

-        use_inductor_graph_partition_msg = (
-            "When use_inductor_graph_partition=True, splitting_ops "
-            "are ignored and set to an empty list. Instead, "
-            "\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is "
-            "used to annotate custom ops for graph partition.")
+        if self.use_inductor_graph_partition:
+            self.set_splitting_ops_for_inductor_graph_partition()
+            return
+
+        if self.pass_config.enable_attn_fusion:
+            # here use_inductor_graph_partition is False
+            self.set_splitting_ops_for_attn_fusion()
+            return

        if self.splitting_ops is None:
-            if self.use_inductor_graph_partition:
-                # When using inductor graph partition, we set splitting_ops
-                # to be empty and rely on torch._C.Tag.cudagraph_unsafe to
-                # annotate custom ops as splitting ops.
-                logger.warning_once(use_inductor_graph_partition_msg)
-                self.splitting_ops = []
-            else:
-                # NOTE: When using full cudagraph, instead of setting an empty
-                # list and capture the full cudagraph inside the flattened fx
-                # graph, we keep the piecewise fx graph structure but capture
-                # the full cudagraph outside the fx graph. This reduces some
-                # cpu overhead when the runtime batch_size is not cudagraph
-                # captured. see https://github.com/vllm-project/vllm/pull/20059
-                # for details. make a copy to avoid mutating the class-level
-                # list via reference.
-                self.splitting_ops = list(self._attention_ops)
+            # NOTE: When using full cudagraph, instead of setting an empty
+            # list and capture the full cudagraph inside the flattened fx
+            # graph, we keep the piecewise fx graph structure but capture
+            # the full cudagraph outside the fx graph. This reduces some
+            # cpu overhead when the runtime batch_size is not cudagraph
+            # captured. see https://github.com/vllm-project/vllm/pull/20059
+            # for details. Make a copy to avoid mutating the class-level
+            # list via reference.
+            self.splitting_ops = list(self._attention_ops)
        elif len(self.splitting_ops) == 0:
            logger.warning_once(
-                "Using piecewise compilation with empty "
-                "splitting_ops and use_inductor_graph_partition"
-                f"={self.use_inductor_graph_partition}.")
-            if (self.cudagraph_mode == CUDAGraphMode.PIECEWISE
-                    and not self.use_inductor_graph_partition):
+                "Using piecewise compilation with empty splitting_ops")
+            if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
                logger.warning_once(
-                    "When compilation level is piecewise with empty "
-                    "splitting_ops, PIECEWISE cudagraph_mode will be "
-                    "treated as FULL cudagraph_mode. Please ensure you are "
-                    "using attention backends that support cudagraph or set "
-                    "cudagraph_mode to NONE explicitly if encountering "
-                    "any problems.")
+                    "Piecewise compilation with empty splitting_ops does "
+                    "not contain piecewise cudagraphs. Setting "
+                    "cudagraph_mode to NONE. Hint: if you are using "
+                    "attention backends that support cudagraph, consider "
+                    "manually setting cudagraph_mode to FULL or "
+                    "FULL_DECODE_ONLY to enable full cudagraphs.")
+                self.cudagraph_mode = CUDAGraphMode.NONE
+            elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
+                logger.warning_once(
+                    "Piecewise compilation with empty splitting_ops does "
+                    "not contain piecewise cudagraphs. Setting "
+                    "cudagraph_mode to FULL.")
                self.cudagraph_mode = CUDAGraphMode.FULL
            self.splitting_ops = []
-        elif self.use_inductor_graph_partition:
-            logger.warning_once(use_inductor_graph_partition_msg)
-            self.splitting_ops = []
+
+    def set_splitting_ops_for_inductor_graph_partition(self):
+        assert self.use_inductor_graph_partition
+        use_inductor_graph_partition_msg = (
+            "When use_inductor_graph_partition=True, splitting_ops "
+            "are ignored and set to an empty list. Instead, "
+            "\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is "
+            "used to annotate custom ops for graph partition.")
+        if self.splitting_ops is not None and \
+                len(self.splitting_ops) > 0:
+            logger.warning_once(use_inductor_graph_partition_msg)
+        self.splitting_ops = []
+
+    def set_splitting_ops_for_attn_fusion(self):
+        assert self.pass_config.enable_attn_fusion
+        if self.splitting_ops is None:
+            self.splitting_ops = []
+            if self.cudagraph_mode.has_piecewise_cudagraphs():
+                logger.warning_once(
+                    "enable_attn_fusion is incompatible with piecewise "
+                    "cudagraph when use_inductor_graph_partition is off. "
+                    "In this case, splitting_ops will be set to an empty "
+                    "list, and cudagraph_mode will be set to FULL. "
+                    "Please ensure you are using attention backends that "
+                    "support cudagraph or set cudagraph_mode to NONE "
+                    "explicitly if encountering any problems.")
+                self.cudagraph_mode = CUDAGraphMode.FULL
+
+        assert not self.splitting_ops_contain_attention(), (
+            "attention ops should not be in splitting_ops "
+            "when enable_attn_fusion is True")

    def splitting_ops_contain_attention(self) -> bool:
        return self.splitting_ops is not None and all(
            op in self.splitting_ops for op in self._attention_ops)
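Taken together with the VllmConfig changes above, the intended end state of the attn-fusion path can be checked with a short sketch that mirrors the unit test added earlier in this commit (same pass_config keys and custom_ops; illustrative, not part of the diff):

from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig

# enable_attn_fusion without inductor graph partition: attention ops are no
# longer used as splitting ops, and PIECEWISE falls back to FULL cudagraphs.
config = VllmConfig(compilation_config=CompilationConfig(
    pass_config={"enable_attn_fusion": True, "enable_noop": True},
    custom_ops=["+quant_fp8"],
    cudagraph_mode=CUDAGraphMode.PIECEWISE))
assert config.compilation_config.splitting_ops == []
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL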

View File

@@ -246,8 +246,7 @@ class ForwardContext:
    ubatch_slices: Optional[UBatchSlices] = None

    def __post_init__(self):
-        assert self.cudagraph_runtime_mode in [
-            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
+        assert self.cudagraph_runtime_mode.valid_runtime_modes(), \
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"

View File

@@ -22,10 +22,10 @@ class CudagraphDispatcher:
    At runtime, the dispatch method generates the runtime cudagraph mode (FULL,
    PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor)
-    based on the input key. After dispatching (communicate via forward context),
-    the cudagraph wrappers will trust the dispatch key to do either capturing
-    or replaying (if mode matched), or pass through to the underlying runnable
-    without cudagraph (if mode no match or mode is NONE).
+    based on the input key. After dispatching (communicated via forward
+    context), the cudagraph wrappers will trust the dispatch key to either
+    capture or replay (if the mode matches), or pass through to the underlying
+    runnable without cudagraph (if the mode does not match or mode is NONE).
    """

    def __init__(self, vllm_config: VllmConfig):
@@ -57,19 +57,15 @@ class CudagraphDispatcher:
    def add_cudagraph_key(self, runtime_mode: CUDAGraphMode,
                          batch_descriptor: BatchDescriptor):
        assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
-            f"Invalid cudagraph runtime mode: {runtime_mode}"
+            f"Invalid cudagraph runtime mode for keys: {runtime_mode}"
        self.cudagraph_keys[runtime_mode].add(batch_descriptor)

    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
                                  uniform_decode_query_len: int):
        # This should be called only after attention backend is initialized.
-        # Note: we create all valid keys possible for cudagraph but do not
-        # guarantee all keys would be used. For example, we create keys for
-        # piecewise cudagraphs when it is piecewise compilation, which is always
-        # valid, but for attention backend support unified routine, we may not
-        # trigger capturing/replaying the piecewise cudagraphs depending on
-        # CompilationConfig.cudagraph_mode. In addition, if we allow lazy
+        # Note: we create all valid keys for cudagraph here but do not
+        # guarantee all keys would be used. For example, if we allow lazy
        # capturing in future PR, some keys may never be triggered.
        if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
            for bs in self.compilation_config.cudagraph_capture_sizes:
@@ -94,10 +90,13 @@ class CudagraphDispatcher:
        self.keys_initialized = True

    def dispatch(
-        self, batch_descriptor: BatchDescriptor
+        self,
+        batch_descriptor: BatchDescriptor,
+        use_cascade_attn: bool = False
    ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
        """
-        Given a batch descriptor, dispatch to a cudagraph mode.
+        Given the conditions (e.g., the batch descriptor and whether cascade
+        attention is used), dispatch to a cudagraph runtime mode and the
+        valid batch descriptor.
        A new batch descriptor is returned as we might dispatch a uniform batch
        to a graph that supports a more general batch (uniform to non-uniform).
        """
@@ -107,12 +106,14 @@ class CudagraphDispatcher:
                "initialized. No cudagraph will be used.")
            return CUDAGraphMode.NONE, None

+        non_uniform_key = batch_descriptor.non_uniform
+        # if a batch uses cascade attention, bypass checking full cudagraphs
+        if not use_cascade_attn:
            # check if key exists for full cudagraph
            if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
                return CUDAGraphMode.FULL, batch_descriptor

            # otherwise, check if non-uniform key exists
-            non_uniform_key = batch_descriptor.non_uniform
            if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
                return CUDAGraphMode.FULL, non_uniform_key
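A condensed sketch of the new dispatch contract, based on the dispatcher test earlier in this commit; `dispatcher` stands for an already-initialized CudagraphDispatcher, and the BatchDescriptor import path is assumed:

from vllm.config import CUDAGraphMode
from vllm.forward_context import BatchDescriptor  # import path assumed

# dispatcher: CudagraphDispatcher whose keys were initialized for
# cudagraph_mode=FULL with cudagraph_capture_sizes containing 8.
desc = BatchDescriptor(num_tokens=8, uniform_decode=False)

mode, key = dispatcher.dispatch(desc)
assert mode == CUDAGraphMode.FULL    # normal path: replay the full cudagraph

mode, key = dispatcher.dispatch(desc, use_cascade_attn=True)
assert mode == CUDAGraphMode.NONE    # cascade attention bypasses full graphs;
                                     # with no piecewise graphs, run eagerly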

View File

@@ -923,11 +923,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
    ) -> tuple[PerLayerAttnMetadata, torch.Tensor,
               Optional[SpecDecodeMetadata], np.ndarray,
               Optional[CommonAttentionMetadata], int, Optional[UBatchSlices],
-               Optional[torch.Tensor]]:
+               Optional[torch.Tensor], bool]:
        """
        :return: tuple[
            attn_metadata: layer-to-attention_metadata mapping,
-            logits_indices, spec_decode_metadata
+            logits_indices, spec_decode_metadata,
+            num_scheduled_tokens, spec_decode_common_attn_metadata,
+            max_num_scheduled_tokens, use_cascade_attn
        ]
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -1135,6 +1137,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        attn_metadata: PerLayerAttnMetadata = {}
        if ubatch_slices is not None:
            attn_metadata = [dict() for _ in range(len(ubatch_slices))]
+        use_cascade_attn = False

        # Used in the below loop.
        query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
@@ -1251,9 +1254,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                    common_prefix_len=common_prefix_len,
                    common_attn_metadata=common_attn_metadata,
                    **extra_attn_metadata_args)
+                use_cascade_attn |= getattr(attn_metadata_i, "use_cascade",
+                                            False)

                for layer_name in attn_group.layer_names:
                    attn_metadata[layer_name] = attn_metadata_i

+        # disable cascade attention when DBO
+        if ubatch_slices is not None:
+            use_cascade_attn = False
+
        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)
@@ -1261,7 +1270,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        return (attn_metadata, logits_indices, spec_decode_metadata,
                num_scheduled_tokens, spec_decode_common_attn_metadata,
                max_num_scheduled_tokens, ubatch_slices,
-                num_tokens_after_padding)
+                num_tokens_after_padding, use_cascade_attn)

    def _compute_cascade_attn_prefix_len(
        self,
@@ -2251,8 +2260,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # Prepare the decoder inputs.
        (attn_metadata, logits_indices, spec_decode_metadata,
         num_scheduled_tokens_np, spec_decode_common_attn_metadata,
-         max_query_len, ubatch_slices, num_tokens_after_padding
-         ) = self._prepare_inputs(scheduler_output)
+         max_query_len, ubatch_slices, num_tokens_after_padding,
+         use_cascade_attn) = self._prepare_inputs(scheduler_output)

        (
            num_scheduled_tokens,
@@ -2273,7 +2282,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                               uniform_decode=uniform_decode)
            cudagraph_runtime_mode, batch_descriptor = \
-                self.cudagraph_dispatcher.dispatch(batch_descriptor)
+                self.cudagraph_dispatcher.dispatch(batch_descriptor,
+                                                   use_cascade_attn)

        # This is currently to get around the assert in the DPMetadata
        # where it wants `num_tokens_across_dp` to align with `num_tokens`
@@ -2701,16 +2711,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            "Cannot reload weights before model is loaded."
        model_loader = get_model_loader(self.load_config)
        logger.info("Reloading weights inplace...")
-        model = self.get_model()
-        model_loader.load_weights(model, model_config=self.model_config)
+        model_loader.load_weights(self.get_model(),
+                                  model_config=self.model_config)

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
-        model = self.get_model()
        TensorizerLoader.save_model(
-            model,
+            self.get_model(),
            tensorizer_config=tensorizer_config,
            model_config=self.model_config,
        )
@@ -2926,9 +2935,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            (1 token) and prefill (multiple tokens) requests.
            remove_lora: If False, dummy LoRAs are not destroyed after the run
        """
-        assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
-            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
-        }
+        assert cudagraph_runtime_mode is None or \
+            cudagraph_runtime_mode.valid_runtime_modes()

        # If cudagraph_mode.decode_mode() == FULL and
        # cudagraph_mode.separate_routine(). This means that we are using
@@ -3113,7 +3121,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # filter out the valid batch descriptor
        _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
            BatchDescriptor(num_tokens=num_tokens,
-                            uniform_decode=uniform_decode))
+                            uniform_decode=uniform_decode)) \
+            if not is_profile else (CUDAGraphMode.NONE, None)
        if cudagraph_runtime_mode is not None:
            # we allow forcing NONE when the dispatcher disagrees to support
            # warm ups for cudagraph capture
@@ -3453,8 +3462,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                 cudagraph_runtime_mode: CUDAGraphMode,
                                 uniform_decode: bool):
        assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \
-            cudagraph_runtime_mode in [CUDAGraphMode.FULL,
-                                       CUDAGraphMode.PIECEWISE]
+            cudagraph_runtime_mode.valid_runtime_modes(), \
+            f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"

        # Only rank 0 should print progress bar during capture
        if is_global_first_rank():
@@ -3585,6 +3594,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            self.calculate_reorder_batch_threshold()

    def initialize_cudagraph_capture(self) -> None:
+        """
+        Resolve the cudagraph_mode when there are multiple attention
+        backends with potentially conflicting CUDA graph support.
+        Then initialize the cudagraph_dispatcher based on the resolved
+        cudagraph_mode.
+        """
        min_cg_support = AttentionCGSupport.ALWAYS
        min_cg_builder_name = None