Commit f075693da7 (parent f708bd4904), mirror of https://git.datalinker.icu/vllm-project/vllm.git
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -3,12 +3,11 @@
 import contextlib
 import os
 import weakref
-from dataclasses import dataclass
-from typing import Optional
 
 import pytest
 
 from tests.utils import wait_for_gpu_memory_to_clear
+from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
 from vllm.platforms import current_platform
@@ -33,89 +32,6 @@ def temporary_environ(env_vars):
             os.environ[k] = v
 
 
-@dataclass
-class BackendConfig:
-    name: str
-    env_vars: dict
-    comp_config: dict
-    specific_gpu_arch: Optional[tuple] = None
-
-
-# Define all backend configurations of full cudagraph to be tested
-backend_configs = {
-    # FA3 on Hopper
-    "FA3":
-    BackendConfig(name="FA3",
-                  env_vars={
-                      "VLLM_FLASH_ATTN_VERSION": "3",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FlashMLA on Hopper
-    "FlashMLA":
-    BackendConfig(name="FlashMLA",
-                  env_vars={
-                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FlashAttention MLA on Hopper
-    "FlashAttentionMLA":
-    BackendConfig(name="FlashAttentionMLA",
-                  env_vars={
-                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_DECODE_ONLY",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # Cutlass MLA on Blackwell
-    "CutlassMLA":
-    BackendConfig(
-        name="CutlassMLA",
-        env_vars={
-            "VLLM_USE_V1": "1",
-            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
-            "FORCE_NUM_KV_SPLITS":
-            "1",  # TODO: remove this when hang issue is fixed
-        },
-        comp_config={
-            "cudagraph_mode": "FULL_AND_PIECEWISE",
-            "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
-        },
-        specific_gpu_arch=(10, 0)),
-    # FA2
-    "FA2":
-    BackendConfig(name="FA2",
-                  env_vars={
-                      "VLLM_FLASH_ATTN_VERSION": "2",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  }),
-    # Triton Attention
-    "TritonAttn":
-    BackendConfig(name="TritonAttn",
-                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  }),
-    # FlashInfer
-    "FlashInfer":
-    BackendConfig(name="FlashInfer",
-                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  }),
-}
-
 test_params_full_cudagraph = []
 
 # deepseek-ai/DeepSeek-V2-Lite with MLA
@@ -4,7 +4,7 @@ import pytest
 
 import vllm
 from vllm.compilation.counter import compilation_counter
-from vllm.config import CompilationConfig, VllmConfig
+from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.utils import _is_torch_equal_or_newer
 
 
@@ -106,7 +106,6 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
 def test_no_compilation(vllm_runner, monkeypatch):
     # Disable multiprocessing so that the counter is in the same process
     monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
-
     with (
             compilation_counter.expect(num_graphs_seen=0,
                                        dynamo_as_is_count=0),
@@ -131,3 +130,67 @@ def test_enforce_eager(vllm_runner, monkeypatch):
                      enforce_eager=True,
                      gpu_memory_utilization=0.4) as _):
         pass
+
+
+def test_splitting_ops_dynamic():
+    # Default config
+    config = VllmConfig()
+    assert config.compilation_config.cudagraph_mode == \
+        CUDAGraphMode.FULL_AND_PIECEWISE
+    assert config.compilation_config.splitting_ops_contain_attention()
+
+    # When use_inductor_graph_partition=True
+    if _is_torch_equal_or_newer('2.9.0.dev'):
+        # inductor graph partition is only available in PyTorch 2.9+.
+        # this is a fast config check so we are not using pytest.skip.
+        config = VllmConfig(compilation_config=CompilationConfig(
+            use_inductor_graph_partition=True,
+            splitting_ops=["silly_attention"]))
+        # should ignore splitting_ops
+        assert config.compilation_config.splitting_ops == []
+
+    # When attn_fusion pass enabled.
+    config = VllmConfig(compilation_config=CompilationConfig(
+        pass_config={
+            "enable_attn_fusion": True,
+            "enable_noop": True
+        },
+        custom_ops=["+quant_fp8"],
+        cudagraph_mode=CUDAGraphMode.PIECEWISE,
+    ))
+    assert config.compilation_config.splitting_ops == []
+    # cudagraph mode also falls back to FULL
+    assert config.compilation_config.cudagraph_mode == \
+        CUDAGraphMode.FULL
+
+    # splitting_ops can not contain attention ops when attn_fusion
+    # pass enabled.
+    with pytest.raises(AssertionError):
+        config = VllmConfig(compilation_config=CompilationConfig(
+            pass_config={
+                "enable_attn_fusion": True,
+                "enable_noop": True
+            },
+            custom_ops=["+quant_fp8"],
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+            # work around for accessing all attention ops
+            splitting_ops=CompilationConfig()._attention_ops,
+        ))
+
+    # When both use_inductor_graph_partition and attn_fusion pass enabled.
+    if _is_torch_equal_or_newer('2.9.0.dev'):
+        config = VllmConfig(compilation_config=CompilationConfig(
+            use_inductor_graph_partition=True,
+            pass_config={
+                "enable_attn_fusion": True,
+                "enable_noop": True
+            },
+            custom_ops=["+quant_fp8"],
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        ))
+        assert config.compilation_config.splitting_ops == []
+        # enable_attn_fusion is directly supported under
+        # use_inductor_graph_partition=True, and cudagraph_mode
+        # is unchanged.
+        assert config.compilation_config.cudagraph_mode == \
+            CUDAGraphMode.PIECEWISE
@@ -3,7 +3,7 @@
 """Utility functions for attention-related v1 tests."""
 
 from dataclasses import dataclass
-from typing import Union
+from typing import Optional, Union
 
 import pytest
 import torch
@@ -260,3 +260,88 @@ def create_dummy_kv_cache(block_size: int,
                             dtype=dtype,
                             device=device)
     return kv_cache
+
+
+@dataclass
+class BackendConfig:
+    name: str
+    env_vars: dict
+    comp_config: dict  # compilation config
+    specific_gpu_arch: Optional[tuple] = None
+
+
+# Define all backend configurations of full cudagraph to be tested
+full_cg_backend_configs = {
+    # FA3 on Hopper
+    "FA3":
+    BackendConfig(name="FA3",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
+                      "VLLM_FLASH_ATTN_VERSION": "3",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # FlashMLA on Hopper
+    "FlashMLA":
+    BackendConfig(name="FlashMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # Cutlass MLA on Blackwell
+    "CutlassMLA":
+    BackendConfig(
+        name="CutlassMLA",
+        env_vars={
+            "VLLM_USE_V1": "1",
+            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
+            "FORCE_NUM_KV_SPLITS":
+            "1",  # TODO: remove this when hang issue is fixed
+        },
+        comp_config={
+            "cudagraph_mode": "FULL_AND_PIECEWISE",
+        },
+        specific_gpu_arch=(10, 0)),
+    # FlashAttention MLA on Hopper
+    "FlashAttentionMLA":
+    BackendConfig(name="FlashAttentionMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_DECODE_ONLY",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # FA2
+    "FA2":
+    BackendConfig(name="FA2",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
+                      "VLLM_FLASH_ATTN_VERSION": "2",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+    # Triton Attention
+    "TritonAttn":
+    BackendConfig(name="TritonAttn",
+                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+    # FlashInfer
+    "FlashInfer":
+    BackendConfig(name="FlashInfer",
+                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+}
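Editor's note: a minimal sketch of how a test module might consume the shared configs added above, matching the `full_cg_backend_configs as backend_configs` imports elsewhere in this commit. The test body is illustrative and not part of the diff; the model id is taken from the "deepseek-ai/DeepSeek-V2-Lite with MLA" comment in the existing tests.

    # Sketch only: turn the shared backend configs into pytest parameters.
    import pytest
    from tests.v1.attention.utils import full_cg_backend_configs as backend_configs

    MODEL = "deepseek-ai/DeepSeek-V2-Lite"  # illustrative placeholder model

    test_params_full_cudagraph = [
        pytest.param((MODEL, cfg), id=name)
        for name, cfg in backend_configs.items()
    ]

    @pytest.mark.parametrize("model_backend_config", test_params_full_cudagraph)
    def test_backend_config_shape(model_backend_config):
        model, cfg = model_backend_config
        # Each BackendConfig carries env_vars plus comp_config (compilation config).
        assert "cudagraph_mode" in cfg.comp_config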
@@ -45,38 +45,21 @@ def _create_vllm_config(compilation_config: CompilationConfig,
 class TestCudagraphDispatcher:
 
     @pytest.mark.parametrize(
-        "params",
+        "case_id,cudagraph_mode_str,compilation_level",
         [
             # Test case 0: Full CG for mixed batches, no separate routine
-            {
-                "case_id": 0,
-                "cudagraph_mode": "FULL",
-                "compilation_level": CompilationLevel.NO_COMPILATION,
-            },
+            (0, "FULL", CompilationLevel.NO_COMPILATION),
             # Test case 1: Full CG for uniform batches, piecewise for mixed
-            {
-                "case_id": 1,
-                "cudagraph_mode": "FULL_AND_PIECEWISE",
-                "compilation_level": CompilationLevel.PIECEWISE,
-            },
+            (1, "FULL_AND_PIECEWISE", CompilationLevel.NO_COMPILATION),
             # Test case 2: Full CG for uniform batches, no CG for mixed
-            {
-                "case_id": 2,
-                "cudagraph_mode": "FULL_DECODE_ONLY",
-                "compilation_level": CompilationLevel.NO_COMPILATION,
-            },
+            (2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION),
             # Test case 3: Piecewise for all
-            {
-                "case_id": 3,
-                "cudagraph_mode": "PIECEWISE",
-                "compilation_level": CompilationLevel.PIECEWISE,
-            },
+            (3, "PIECEWISE", CompilationLevel.PIECEWISE),
         ])
-    def test_dispatcher(self, params):
+    def test_dispatcher(self, cudagraph_mode_str, compilation_level):
         # Setup dispatcher
-        comp_config = CompilationConfig(
-            cudagraph_mode=params["cudagraph_mode"],
-            level=params["compilation_level"],
-            cudagraph_capture_sizes=[1, 8])
+        comp_config = CompilationConfig(cudagraph_mode=cudagraph_mode_str,
+                                        level=compilation_level,
+                                        cudagraph_capture_sizes=[1, 8])
 
         config = _create_vllm_config(comp_config, max_num_seqs=8)
@@ -86,11 +69,11 @@ class TestCudagraphDispatcher:
                                               uniform_decode_query_len=1)
 
         # Verify the key is initialized correctly
-        if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
+        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
         else:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
-        if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]:
+        if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
         else:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
@@ -99,10 +82,10 @@ class TestCudagraphDispatcher:
         # 1. non-uniform batch, size in cudagraph size list
         desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
         rt_mode, key = dispatcher.dispatch(desc_full_exact)
-        if params["cudagraph_mode"] == "FULL":
+        if cudagraph_mode_str == "FULL":
             assert rt_mode == CUDAGraphMode.FULL
             assert key == desc_full_exact
-        elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
+        elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
             assert rt_mode == CUDAGraphMode.PIECEWISE
             assert key == desc_full_exact
         else:
@@ -111,15 +94,13 @@ class TestCudagraphDispatcher:
         # 2. uniform decode batch, size in cudagraph size list
         desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
         rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
-        if params["cudagraph_mode"] == "FULL":
+        if cudagraph_mode_str == "FULL":
             assert rt_mode == CUDAGraphMode.FULL
             assert key == desc_uniform_exact.non_uniform
-        elif params["cudagraph_mode"] in [
-                "FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"
-        ]:
+        elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
             assert rt_mode == CUDAGraphMode.FULL
             assert key == desc_uniform_exact
-        elif params["cudagraph_mode"] == "PIECEWISE":
+        elif cudagraph_mode_str == "PIECEWISE":
             assert rt_mode == CUDAGraphMode.PIECEWISE
             assert key == desc_uniform_exact.non_uniform
         else:
@@ -131,6 +112,16 @@ class TestCudagraphDispatcher:
             assert rt_mode == CUDAGraphMode.NONE
             assert key is None
 
+        # 4. Cascade attention should have a fall back mode
+        desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
+        rt_mode, key = dispatcher.dispatch(desc_full_exact,
+                                           use_cascade_attn=True)
+        if "PIECEWISE" in cudagraph_mode_str:  # string contains check
+            assert rt_mode == CUDAGraphMode.PIECEWISE
+            assert key == desc_full_exact.non_uniform
+        else:
+            assert rt_mode == CUDAGraphMode.NONE
+
 
 @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
 class TestCUDAGraphWrapper:
@@ -4,12 +4,11 @@ import contextlib
 import os
 import weakref
 from contextlib import ExitStack
-from dataclasses import dataclass
-from typing import Optional
 
 import pytest
 
 from tests.utils import wait_for_gpu_memory_to_clear
+from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
 from vllm import LLM
 from vllm.config import CompilationConfig
 from vllm.platforms import current_platform
@@ -34,74 +33,6 @@ def temporary_environ(env_vars):
             os.environ[k] = v
 
 
-@dataclass
-class BackendConfig:
-    name: str
-    env_vars: dict
-    comp_config: dict
-    specific_gpu_arch: Optional[tuple] = None
-
-
-# Define all backend configurations of full cudagraph to be tested
-backend_configs = {
-    # FA3 on Hopper
-    "FA3":
-    BackendConfig(name="FA3",
-                  env_vars={
-                      "VLLM_FLASH_ATTN_VERSION": "3",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FlashMLA on Hopper
-    "FlashMLA":
-    BackendConfig(name="FlashMLA",
-                  env_vars={
-                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FlashAttention MLA on Hopper
-    "FlashAttentionMLA":
-    BackendConfig(name="FlashAttentionMLA",
-                  env_vars={
-                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_DECODE_ONLY",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FA2
-    "FA2":
-    BackendConfig(name="FA2",
-                  env_vars={
-                      "VLLM_FLASH_ATTN_VERSION": "2",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  }),
-    # Triton Attention
-    "TritonAttn":
-    BackendConfig(name="TritonAttn",
-                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  }),
-    # FlashInfer
-    "FlashInfer":
-    BackendConfig(name="FlashInfer",
-                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  }),
-}
-
 # test attention backend and cudagraph_mode combo
 # (backend_name, cudagraph_mode, supported)
 combo_cases_1 = [
@@ -114,9 +45,10 @@ combo_cases_1 = [
 ]
 
 
-@pytest.mark.parametrize("combo_case", combo_cases_1)
-def test_backend_and_cudagraph_mode_combo(combo_case):
-    backend_name, cudagraph_mode, supported = combo_case
+@pytest.mark.parametrize("backend_name, cudagraph_mode, supported",
+                         combo_cases_1)
+def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode,
+                                          supported):
     if backend_name == "FlashInfer":
         try:
             import flashinfer  # noqa: F401
@@ -142,7 +74,7 @@ def test_backend_and_cudagraph_mode_combo(combo_case):
               compilation_config=CompilationConfig(
                   level=3, cudagraph_mode=cudagraph_mode))
     llm.generate(["Hello, my name is"] * 10)
-
+    # when above code raises, `llm` may be undefined, so we need to catch that
     try:
         llm = weakref.proxy(llm)
         del llm
@@ -173,7 +105,8 @@ combo_cases_2 = [
 ]
 
 
-@pytest.mark.parametrize("combo_case", combo_cases_2)
+@pytest.mark.parametrize("backend_name,cudagraph_mode,compilation_level,"\
+                         "supported", combo_cases_2)
 def test_cudagraph_compilation_combo(combo_case):
     backend_name, cudagraph_mode, compilation_level, supported\
         = combo_case
@@ -192,6 +125,7 @@ def test_cudagraph_compilation_combo(combo_case):
               compilation_config=CompilationConfig(
                   level=compilation_level, cudagraph_mode=cudagraph_mode))
     llm.generate(["Hello, my name is"] * 10)
+    # when above code raises, `llm` may be undefined, so we need to catch that
     try:
         llm = weakref.proxy(llm)
         del llm
@@ -340,15 +340,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None)
             # Lazy import here to avoid circular import
-            from .cuda_piecewise_backend import PiecewiseBackend
+            from .piecewise_backend import PiecewiseBackend
 
             piecewise_backend = PiecewiseBackend(
                 submod, self.vllm_config, index,
                 len(self.compile_submod_names), sym_shape_indices,
                 compiled_graph_for_dynamic_shape, self.vllm_backend)
 
-            if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-                    and
+            if (self.compilation_config.cudagraph_mode.\
+                    has_piecewise_cudagraphs() and
                     not self.compilation_config.use_inductor_graph_partition):
                 # We're using Dynamo-based piecewise splitting, so we wrap
                 # the whole subgraph with a static graph wrapper.
@@ -336,7 +336,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
     from vllm.config import CUDAGraphMode
 
     compilation_config = vllm_config.compilation_config
-    if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+    if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
             and compilation_config.use_inductor_graph_partition):
         from torch._inductor.utils import CUDAGraphWrapperMetadata
 
@@ -365,7 +365,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
 
         yield
 
-    if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+    if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
            and compilation_config.use_inductor_graph_partition):
         torch._inductor.utils.set_customized_partition_wrappers(None)
 
@@ -459,15 +459,22 @@ class VllmConfig:
                     "to True to enable.")
         current_platform.check_and_update_config(self)
 
-        # final check of cudagraph mode after platform-specific update
+        # Do this after all the updates to compilation_config.level
+        if envs.VLLM_USE_V1 and \
+            self.compilation_config.level == CompilationLevel.PIECEWISE:
+            self.compilation_config.set_splitting_ops_for_v1()
+
+        # final check of cudagraph mode after all possible updates
         if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
-            if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
+            if self.compilation_config.cudagraph_mode.has_full_cudagraphs()\
                 and self.model_config is not None and \
-                not self.model_config.disable_cascade_attn:
-                logger.info("CUDAGraphMode.FULL is not supported with "
-                            "cascade attention currently. Disabling cascade"
-                            "attention.")
-                self.model_config.disable_cascade_attn = True
+                not self.model_config.disable_cascade_attn and\
+                not self.compilation_config.cudagraph_mode.\
+                has_piecewise_cudagraphs():
+                logger.warning_once(
+                    "No piecewise cudagraph for executing cascade attention."
+                    " Will fall back to eager execution if a batch runs "
+                    "into cascade attentions")
 
             if self.compilation_config.cudagraph_mode\
                 .requires_piecewise_compilation():
@@ -477,6 +484,12 @@ class VllmConfig:
                     "when cudagraph_mode piecewise cudagraphs is used, "\
                     f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
 
+            # final migrate the deprecated flags
+            self.compilation_config.use_cudagraph = self.compilation_config.\
+                cudagraph_mode!= CUDAGraphMode.NONE
+            self.compilation_config.full_cuda_graph = self.compilation_config.\
+                cudagraph_mode.has_full_cudagraphs()
+
         if self.parallel_config.enable_dbo:
             a2a_backend = envs.VLLM_ALL2ALL_BACKEND
             assert a2a_backend in \
@@ -487,14 +500,14 @@ class VllmConfig:
                 "variable to deepep_low_latency or deepep_high_throughput and "\
                 "install the DeepEP kernels."
 
+            if not self.model_config.disable_cascade_attn:
+                self.model_config.disable_cascade_attn = True
+                logger.warning_once(
+                    "Disabling cascade attention when DBO is enabled.")
+
         if not self.instance_id:
             self.instance_id = random_uuid()[:5]
 
-        # Do this after all the updates to compilation_config.level
-        if envs.VLLM_USE_V1 and \
-            self.compilation_config.level == CompilationLevel.PIECEWISE:
-            self.compilation_config.set_splitting_ops_for_v1()
-
         if (envs.VLLM_USE_V1
                 and not self.scheduler_config.disable_hybrid_kv_cache_manager):
             # logger should only print warning message for hybrid models. As we
@@ -61,9 +61,17 @@ class CUDAGraphMode(enum.Enum):
     def has_full_cudagraphs(self) -> bool:
         return self.max_cudagraph_mode() == CUDAGraphMode.FULL
 
+    def has_piecewise_cudagraphs(self) -> bool:
+        return self.requires_piecewise_compilation()
+
     def separate_routine(self) -> bool:
         return isinstance(self.value, tuple)
 
+    def valid_runtime_modes(self) -> bool:
+        return self in [
+            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
+        ]
+
 
 @config
 @dataclass
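Editor's note: a quick sanity sketch of what the two new helpers report for the modes exercised elsewhere in this commit, assuming the enum keeps its existing composite members (NONE, PIECEWISE, FULL, FULL_DECODE_ONLY, FULL_AND_PIECEWISE); the exact values below follow the dispatcher tests earlier in this diff.

    from vllm.config import CUDAGraphMode

    # Only the three plain modes are valid at runtime; composite modes are
    # config-level and must be resolved before dispatch.
    assert CUDAGraphMode.FULL.valid_runtime_modes()
    assert not CUDAGraphMode.FULL_AND_PIECEWISE.valid_runtime_modes()

    # has_piecewise_cudagraphs() aliases requires_piecewise_compilation():
    # roughly, true when some routine of the mode runs PIECEWISE cudagraphs.
    assert CUDAGraphMode.FULL_AND_PIECEWISE.has_piecewise_cudagraphs()
    assert not CUDAGraphMode.FULL_DECODE_ONLY.has_piecewise_cudagraphs()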
@@ -269,7 +277,8 @@ class CompilationConfig:
     Note that this is orthogonal to the cudagraph capture logic
     outside of compilation.
     Warning: This flag is deprecated and will be removed in the next major or
-    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
+    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE
+    instead.
     """
     cudagraph_num_of_warmups: int = 0
     """Number of warmup runs for cudagraph.
@@ -294,7 +303,8 @@ class CompilationConfig:
     flag cannot be used together with splitting_ops. This may provide
     performance benefits for smaller models.
     Warning: This flag is deprecated and will be removed in the next major or
-    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
+    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
+    FULL_AND_PIECEWISE instead.
     """
 
     use_inductor_graph_partition: bool = False
@@ -464,7 +474,8 @@ class CompilationConfig:
         if not self.use_cudagraph:
             logger.warning("use_cudagraph is deprecated, use "
                            "cudagraph_mode=NONE instead.")
-            if self.cudagraph_mode is not None:
+            if self.cudagraph_mode is not None and \
+                self.cudagraph_mode != CUDAGraphMode.NONE:
                 raise ValueError(
                     "use_cudagraph and cudagraph_mode are mutually"
                     " exclusive, prefer cudagraph_mode since "
@@ -473,7 +484,8 @@ class CompilationConfig:
         if self.full_cuda_graph:
             logger.warning("full_cuda_graph is deprecated, use "
                            "cudagraph_mode=FULL instead.")
-            if self.cudagraph_mode is not None:
+            if self.cudagraph_mode is not None and \
+                not self.cudagraph_mode.has_full_cudagraphs():
                 raise ValueError("full_cuda_graph and cudagraph_mode are "
                                  "mutually exclusive, prefer cudagraph_mode "
                                  "since full_cuda_graph is deprecated.")
@@ -570,49 +582,76 @@ class CompilationConfig:
                 "set_splitting_ops_for_v1 should only be called when "
                 "level is CompilationLevel.PIECEWISE")
 
-        use_inductor_graph_partition_msg = (
-            "When use_inductor_graph_partition=True, splitting_ops "
-            "are ignored and set to an empty list. Instead, "
-            "\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is "
-            "used to annotate custom ops for graph partition.")
+        if self.use_inductor_graph_partition:
+            self.set_splitting_ops_for_inductor_graph_partition()
+            return
+
+        if self.pass_config.enable_attn_fusion:
+            # here use_inductor_graph_partition is False
+            self.set_splitting_ops_for_attn_fusion()
+            return
 
         if self.splitting_ops is None:
-            if self.use_inductor_graph_partition:
-                # When using inductor graph partition, we set splitting_ops
-                # to be empty and rely on torch._C.Tag.cudagraph_unsafe to
-                # annotate custom ops as splitting ops.
-                logger.warning_once(use_inductor_graph_partition_msg)
-                self.splitting_ops = []
-            else:
-                # NOTE: When using full cudagraph, instead of setting an empty
-                # list and capture the full cudagraph inside the flattened fx
-                # graph, we keep the piecewise fx graph structure but capture
-                # the full cudagraph outside the fx graph. This reduces some
-                # cpu overhead when the runtime batch_size is not cudagraph
-                # captured. see https://github.com/vllm-project/vllm/pull/20059
-                # for details. make a copy to avoid mutating the class-level
-                # list via reference.
-                self.splitting_ops = list(self._attention_ops)
+            # NOTE: When using full cudagraph, instead of setting an empty
+            # list and capture the full cudagraph inside the flattened fx
+            # graph, we keep the piecewise fx graph structure but capture
+            # the full cudagraph outside the fx graph. This reduces some
+            # cpu overhead when the runtime batch_size is not cudagraph
+            # captured. see https://github.com/vllm-project/vllm/pull/20059
+            # for details. Make a copy to avoid mutating the class-level
+            # list via reference.
+            self.splitting_ops = list(self._attention_ops)
         elif len(self.splitting_ops) == 0:
             logger.warning_once(
-                "Using piecewise compilation with empty "
-                "splitting_ops and use_inductor_graph_partition"
-                f"={self.use_inductor_graph_partition}.")
-            if (self.cudagraph_mode == CUDAGraphMode.PIECEWISE
-                    and not self.use_inductor_graph_partition):
+                "Using piecewise compilation with empty splitting_ops")
+            if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
                 logger.warning_once(
-                    "When compilation level is piecewise with empty "
-                    "splitting_ops, PIECEWISE cudagraph_mode will be "
-                    "treated as FULL cudagraph_mode. Please ensure you are "
-                    "using attention backends that support cudagraph or set "
-                    "cudagraph_mode to NONE explicitly if encountering "
-                    "any problems.")
+                    "Piecewise compilation with empty splitting_ops do not" \
+                    "contains piecewise cudagraph. Setting cudagraph_"
+                    "mode to NONE. Hint: If you are using attention backends "
+                    "that support cudagraph, consider manually setting "
+                    "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
+                    "full cudagraphs.")
+                self.cudagraph_mode = CUDAGraphMode.NONE
+            elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
+                logger.warning_once(
+                    "Piecewise compilation with empty splitting_ops do not "
+                    "contains piecewise cudagraph. Setting cudagraph_mode "
+                    "to FULL.")
                 self.cudagraph_mode = CUDAGraphMode.FULL
             self.splitting_ops = []
-        elif self.use_inductor_graph_partition:
+
+    def set_splitting_ops_for_inductor_graph_partition(self):
+        assert self.use_inductor_graph_partition
+        use_inductor_graph_partition_msg = (
+            "When use_inductor_graph_partition=True, splitting_ops "
+            "are ignored and set to an empty list. Instead, "
+            "\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is "
+            "used to annotate custom ops for graph partition.")
+        if self.splitting_ops is not None and \
+            len(self.splitting_ops) > 0:
             logger.warning_once(use_inductor_graph_partition_msg)
-            self.splitting_ops = []
+        self.splitting_ops = []
+
+    def set_splitting_ops_for_attn_fusion(self):
+        assert self.pass_config.enable_attn_fusion
+        if self.splitting_ops is None:
+            self.splitting_ops = []
+            if self.cudagraph_mode.has_piecewise_cudagraphs():
+                logger.warning_once(
+                    "enable_attn_fusion is incompatible with piecewise "
+                    "cudagraph when use_inductor_graph_partition is off."
+                    "In this case, splitting_ops will be set to empty "
+                    "list, and cudagraph_mode will be set to FULL. "
+                    "Please ensure you are using attention backends that "
+                    "support cudagraph or set cudagraph_mode to NONE "
+                    "explicitly if encountering any problems.")
+                self.cudagraph_mode = CUDAGraphMode.FULL
+
+        assert not self.splitting_ops_contain_attention(), (
+            "attention ops should not be in splitting_ops "
+            "when enable_attn_fusion is True")
 
     def splitting_ops_contain_attention(self) -> bool:
         return self.splitting_ops is not None and all(
             op in self.splitting_ops for op in self._attention_ops)
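Editor's note: a small sketch of what splitting_ops_contain_attention() reports after the paths above run; _attention_ops is the class-level list that the default branch copies, and the values here are purely illustrative.

    from vllm.config import CompilationConfig

    cfg = CompilationConfig()
    # Default piecewise path: splitting_ops becomes a copy of the attention ops.
    cfg.splitting_ops = list(cfg._attention_ops)
    assert cfg.splitting_ops_contain_attention()

    # After the attn_fusion or inductor-partition paths, splitting_ops is [].
    cfg.splitting_ops = []
    assert not cfg.splitting_ops_contain_attention()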
@@ -246,8 +246,7 @@ class ForwardContext:
     ubatch_slices: Optional[UBatchSlices] = None
 
     def __post_init__(self):
-        assert self.cudagraph_runtime_mode in [
-            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
+        assert self.cudagraph_runtime_mode.valid_runtime_modes(), \
             f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
 
 
@@ -22,10 +22,10 @@ class CudagraphDispatcher:
 
     At runtime, the dispatch method generates the runtime cudagraph mode (FULL,
     PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor)
-    based on the input key. After dispatching (communicate via forward context),
-    the cudagraph wrappers will trust the dispatch key to do either capturing
-    or replaying (if mode matched), or pass through to the underlying runnable
-    without cudagraph (if mode no match or mode is NONE).
+    based on the input key. After dispatching (communicated via forward
+    context), the cudagraph wrappers will trust the dispatch key to either
+    capture or replay (if the mode matches), or pass through to the underlying
+    runnable without cudagraph (if the mode does not match or mode is NONE).
     """
 
     def __init__(self, vllm_config: VllmConfig):
@@ -57,19 +57,15 @@ class CudagraphDispatcher:
     def add_cudagraph_key(self, runtime_mode: CUDAGraphMode,
                           batch_descriptor: BatchDescriptor):
         assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
-            f"Invalid cudagraph runtime mode: {runtime_mode}"
+            f"Invalid cudagraph runtime mode for keys: {runtime_mode}"
         self.cudagraph_keys[runtime_mode].add(batch_descriptor)
 
     def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
                                   uniform_decode_query_len: int):
         # This should be called only after attention backend is initialized.
 
-        # Note: we create all valid keys possible for cudagraph but do not
-        # guarantee all keys would be used. For example, we create keys for
-        # piecewise cudagraphs when it is piecewise compilation, which is always
-        # valid, but for attention backend support unified routine, we may not
-        # trigger capturing/replaying the piecewise cudagraphs depending on
-        # CompilationConfig.cudagraph_mode. In addition, if we allow lazy
+        # Note: we create all valid keys for cudagraph here but do not
+        # guarantee all keys would be used. For example, if we allow lazy
         # capturing in future PR, some keys may never be triggered.
         if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
             for bs in self.compilation_config.cudagraph_capture_sizes:
@@ -94,10 +90,13 @@ class CudagraphDispatcher:
         self.keys_initialized = True
 
     def dispatch(
-        self, batch_descriptor: BatchDescriptor
+        self,
+        batch_descriptor: BatchDescriptor,
+        use_cascade_attn: bool = False
     ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
         """
-        Given a batch descriptor, dispatch to a cudagraph mode.
+        Given conditions(e.g.,batch descriptor and if using cascade attention),
+        dispatch to a cudagraph runtime mode and the valid batch descriptor.
         A new batch descriptor is returned as we might dispatch a uniform batch
         to a graph that supports a more general batch (uniform to non-uniform).
         """
@@ -107,12 +106,14 @@ class CudagraphDispatcher:
                 "initialized. No cudagraph will be used.")
             return CUDAGraphMode.NONE, None
 
+        non_uniform_key = batch_descriptor.non_uniform
+        # if a batch use cascade attention, bypass checking full cudagraphs
+        if not use_cascade_attn:
             # check if key exists for full cudagraph
             if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
                 return CUDAGraphMode.FULL, batch_descriptor
 
             # otherwise, check if non-uniform key exists
-        non_uniform_key = batch_descriptor.non_uniform
             if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
                 return CUDAGraphMode.FULL, non_uniform_key
 
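Editor's note: a usage sketch of the extended dispatch() signature, continuing from a dispatcher and BatchDescriptor set up as in TestCudagraphDispatcher earlier in this commit (not a standalone snippet; the expected outcomes mirror that test).

    # With a PIECEWISE-capable mode (e.g. FULL_AND_PIECEWISE), a cascade-attention
    # batch bypasses the FULL keys and lands on the piecewise, non-uniform key.
    desc = BatchDescriptor(num_tokens=8, uniform_decode=False)
    rt_mode, key = dispatcher.dispatch(desc, use_cascade_attn=True)
    assert rt_mode in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE)
    # PIECEWISE -> key == desc.non_uniform; NONE (e.g. FULL-only modes) -> eager fallback.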
@@ -923,11 +923,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     ) -> tuple[PerLayerAttnMetadata, torch.Tensor,
                Optional[SpecDecodeMetadata], np.ndarray,
                Optional[CommonAttentionMetadata], int, Optional[UBatchSlices],
-               Optional[torch.Tensor]]:
+               Optional[torch.Tensor], bool]:
         """
         :return: tuple[
             attn_metadata: layer-to-attention_metadata mapping,
-            logits_indices, spec_decode_metadata
+            logits_indices, spec_decode_metadata,
+            num_scheduled_tokens, spec_decode_common_attn_metadata,
+            max_num_scheduled_tokens, use_cascade_attn
         ]
         """
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
|||||||
attn_metadata: PerLayerAttnMetadata = {}
|
attn_metadata: PerLayerAttnMetadata = {}
|
||||||
if ubatch_slices is not None:
|
if ubatch_slices is not None:
|
||||||
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
|
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
|
||||||
|
use_cascade_attn = False
|
||||||
|
|
||||||
# Used in the below loop.
|
# Used in the below loop.
|
||||||
query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
|
query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
|
||||||
@ -1251,9 +1254,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
common_prefix_len=common_prefix_len,
|
common_prefix_len=common_prefix_len,
|
||||||
common_attn_metadata=common_attn_metadata,
|
common_attn_metadata=common_attn_metadata,
|
||||||
**extra_attn_metadata_args)
|
**extra_attn_metadata_args)
|
||||||
|
use_cascade_attn |= getattr(attn_metadata_i, "use_cascade",
|
||||||
|
False)
|
||||||
for layer_name in attn_group.layer_names:
|
for layer_name in attn_group.layer_names:
|
||||||
attn_metadata[layer_name] = attn_metadata_i
|
attn_metadata[layer_name] = attn_metadata_i
|
||||||
|
|
||||||
|
# disable cascade attention when DBO
|
||||||
|
if ubatch_slices is not None:
|
||||||
|
use_cascade_attn = False
|
||||||
|
|
||||||
# Hot-Swap lora model
|
# Hot-Swap lora model
|
||||||
if self.lora_config:
|
if self.lora_config:
|
||||||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||||
@ -1261,7 +1270,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
return (attn_metadata, logits_indices, spec_decode_metadata,
|
return (attn_metadata, logits_indices, spec_decode_metadata,
|
||||||
num_scheduled_tokens, spec_decode_common_attn_metadata,
|
num_scheduled_tokens, spec_decode_common_attn_metadata,
|
||||||
max_num_scheduled_tokens, ubatch_slices,
|
max_num_scheduled_tokens, ubatch_slices,
|
||||||
num_tokens_after_padding)
|
num_tokens_after_padding, use_cascade_attn)
|
||||||
|
|
||||||
def _compute_cascade_attn_prefix_len(
|
def _compute_cascade_attn_prefix_len(
|
||||||
self,
|
self,
|
||||||
@ -2251,8 +2260,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# Prepare the decoder inputs.
|
# Prepare the decoder inputs.
|
||||||
(attn_metadata, logits_indices, spec_decode_metadata,
|
(attn_metadata, logits_indices, spec_decode_metadata,
|
||||||
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
|
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
|
||||||
max_query_len, ubatch_slices, num_tokens_after_padding
|
max_query_len, ubatch_slices, num_tokens_after_padding,
|
||||||
) = self._prepare_inputs(scheduler_output)
|
use_cascade_attn) = self._prepare_inputs(scheduler_output)
|
||||||
|
|
||||||
(
|
(
|
||||||
num_scheduled_tokens,
|
num_scheduled_tokens,
|
||||||
@ -2273,7 +2282,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||||||
uniform_decode=uniform_decode)
|
uniform_decode=uniform_decode)
|
||||||
cudagraph_runtime_mode, batch_descriptor = \
|
cudagraph_runtime_mode, batch_descriptor = \
|
||||||
self.cudagraph_dispatcher.dispatch(batch_descriptor)
|
self.cudagraph_dispatcher.dispatch(batch_descriptor,
|
||||||
|
use_cascade_attn)
|
||||||
|
|
||||||
# This is currently to get around the assert in the DPMetadata
|
# This is currently to get around the assert in the DPMetadata
|
||||||
# where it wants `num_tokens_across_dp` to align with `num_tokens`
|
# where it wants `num_tokens_across_dp` to align with `num_tokens`
|
||||||
@ -2701,16 +2711,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
"Cannot reload weights before model is loaded."
|
"Cannot reload weights before model is loaded."
|
||||||
model_loader = get_model_loader(self.load_config)
|
model_loader = get_model_loader(self.load_config)
|
||||||
logger.info("Reloading weights inplace...")
|
logger.info("Reloading weights inplace...")
|
||||||
model = self.get_model()
|
model_loader.load_weights(self.get_model(),
|
||||||
model_loader.load_weights(model, model_config=self.model_config)
|
model_config=self.model_config)
|
||||||
|
|
||||||
def save_tensorized_model(
|
def save_tensorized_model(
|
||||||
self,
|
self,
|
||||||
tensorizer_config: "TensorizerConfig",
|
tensorizer_config: "TensorizerConfig",
|
||||||
) -> None:
|
) -> None:
|
||||||
model = self.get_model()
|
|
||||||
TensorizerLoader.save_model(
|
TensorizerLoader.save_model(
|
||||||
model,
|
self.get_model(),
|
||||||
tensorizer_config=tensorizer_config,
|
tensorizer_config=tensorizer_config,
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
)
|
)
|
||||||
@@ -2926,9 +2935,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
-        assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
-            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
-        }
+        assert cudagraph_runtime_mode is None or \
+            cudagraph_runtime_mode.valid_runtime_modes()
 
         # If cudagraph_mode.decode_mode() == FULL and
         # cudagraph_mode.separate_routine(). This means that we are using
|
|||||||
# filter out the valid batch descriptor
|
# filter out the valid batch descriptor
|
||||||
_cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
|
_cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
|
||||||
BatchDescriptor(num_tokens=num_tokens,
|
BatchDescriptor(num_tokens=num_tokens,
|
||||||
uniform_decode=uniform_decode))
|
uniform_decode=uniform_decode)) \
|
||||||
|
if not is_profile else (CUDAGraphMode.NONE, None)
|
||||||
if cudagraph_runtime_mode is not None:
|
if cudagraph_runtime_mode is not None:
|
||||||
# we allow forcing NONE when the dispatcher disagrees to support
|
# we allow forcing NONE when the dispatcher disagrees to support
|
||||||
# warm ups for cudagraph capture
|
# warm ups for cudagraph capture
|
||||||
@ -3453,8 +3462,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
cudagraph_runtime_mode: CUDAGraphMode,
|
cudagraph_runtime_mode: CUDAGraphMode,
|
||||||
uniform_decode: bool):
|
uniform_decode: bool):
|
||||||
assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \
|
assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \
|
||||||
cudagraph_runtime_mode in [CUDAGraphMode.FULL,
|
cudagraph_runtime_mode.valid_runtime_modes(), \
|
||||||
CUDAGraphMode.PIECEWISE]
|
f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"
|
||||||
|
|
||||||
# Only rank 0 should print progress bar during capture
|
# Only rank 0 should print progress bar during capture
|
||||||
if is_global_first_rank():
|
if is_global_first_rank():
|
||||||
@ -3585,6 +3594,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.calculate_reorder_batch_threshold()
|
self.calculate_reorder_batch_threshold()
|
||||||
|
|
||||||
def initialize_cudagraph_capture(self) -> None:
|
def initialize_cudagraph_capture(self) -> None:
|
||||||
|
"""
|
||||||
|
Resolve the cudagraph_mode when there are multiple attention
|
||||||
|
backends with potential conflicting CUDA graph support.
|
||||||
|
Then initialize the cudagraph_dispatcher based on the resolved
|
||||||
|
cudagraph_mode.
|
||||||
|
"""
|
||||||
min_cg_support = AttentionCGSupport.ALWAYS
|
min_cg_support = AttentionCGSupport.ALWAYS
|
||||||
min_cg_builder_name = None
|
min_cg_builder_name = None
|
||||||
|
|
||||||
|
|||||||