Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent f708bd4904
commit f075693da7
@@ -3,12 +3,11 @@
import contextlib
import os
import weakref
from dataclasses import dataclass
from typing import Optional

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
@@ -33,89 +32,6 @@ def temporary_environ(env_vars):
                os.environ[k] = v


@dataclass
class BackendConfig:
    name: str
    env_vars: dict
    comp_config: dict
    specific_gpu_arch: Optional[tuple] = None


# Define all backend configurations of full cudagraph to be tested
backend_configs = {
    # FA3 on Hopper
    "FA3":
    BackendConfig(name="FA3",
                  env_vars={
                      "VLLM_FLASH_ATTN_VERSION": "3",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL",
                  },
                  specific_gpu_arch=(9, 0)),
    # FlashMLA on Hopper
    "FlashMLA":
    BackendConfig(name="FlashMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  },
                  specific_gpu_arch=(9, 0)),
    # FlashAttention MLA on Hopper
    "FlashAttentionMLA":
    BackendConfig(name="FlashAttentionMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_DECODE_ONLY",
                  },
                  specific_gpu_arch=(9, 0)),
    # Cutlass MLA on Blackwell
    "CutlassMLA":
    BackendConfig(
        name="CutlassMLA",
        env_vars={
            "VLLM_USE_V1": "1",
            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
            "FORCE_NUM_KV_SPLITS":
            "1",  # TODO: remove this when hang issue is fixed
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
            "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
        },
        specific_gpu_arch=(10, 0)),
    # FA2
    "FA2":
    BackendConfig(name="FA2",
                  env_vars={
                      "VLLM_FLASH_ATTN_VERSION": "2",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL",
                  }),
    # Triton Attention
    "TritonAttn":
    BackendConfig(name="TritonAttn",
                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
                  comp_config={
                      "cudagraph_mode": "FULL",
                  }),
    # FlashInfer
    "FlashInfer":
    BackendConfig(name="FlashInfer",
                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
}

test_params_full_cudagraph = []

# deepseek-ai/DeepSeek-V2-Lite with MLA

@@ -4,7 +4,7 @@ import pytest

import vllm
from vllm.compilation.counter import compilation_counter
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.utils import _is_torch_equal_or_newer

@@ -106,7 +106,6 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
def test_no_compilation(vllm_runner, monkeypatch):
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')

    with (
            compilation_counter.expect(num_graphs_seen=0,
                                       dynamo_as_is_count=0),
@@ -131,3 +130,67 @@ def test_enforce_eager(vllm_runner, monkeypatch):
                     enforce_eager=True,
                     gpu_memory_utilization=0.4) as _):
        pass


def test_splitting_ops_dynamic():
    # Default config
    config = VllmConfig()
    assert config.compilation_config.cudagraph_mode == \
        CUDAGraphMode.FULL_AND_PIECEWISE
    assert config.compilation_config.splitting_ops_contain_attention()

    # When use_inductor_graph_partition=True
    if _is_torch_equal_or_newer('2.9.0.dev'):
        # Inductor graph partition is only available in PyTorch 2.9+.
        # This is a fast config check, so we are not using pytest.skip.
        config = VllmConfig(compilation_config=CompilationConfig(
            use_inductor_graph_partition=True,
            splitting_ops=["silly_attention"]))
        # should ignore splitting_ops
        assert config.compilation_config.splitting_ops == []

    # When the attn_fusion pass is enabled.
    config = VllmConfig(compilation_config=CompilationConfig(
        pass_config={
            "enable_attn_fusion": True,
            "enable_noop": True
        },
        custom_ops=["+quant_fp8"],
        cudagraph_mode=CUDAGraphMode.PIECEWISE,
    ))
    assert config.compilation_config.splitting_ops == []
    # cudagraph mode also falls back to FULL
    assert config.compilation_config.cudagraph_mode == \
        CUDAGraphMode.FULL

    # splitting_ops cannot contain attention ops when the attn_fusion
    # pass is enabled.
    with pytest.raises(AssertionError):
        config = VllmConfig(compilation_config=CompilationConfig(
            pass_config={
                "enable_attn_fusion": True,
                "enable_noop": True
            },
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
            # workaround for accessing all attention ops
            splitting_ops=CompilationConfig()._attention_ops,
        ))

    # When both use_inductor_graph_partition and the attn_fusion pass
    # are enabled.
    if _is_torch_equal_or_newer('2.9.0.dev'):
        config = VllmConfig(compilation_config=CompilationConfig(
            use_inductor_graph_partition=True,
            pass_config={
                "enable_attn_fusion": True,
                "enable_noop": True
            },
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        ))
        assert config.compilation_config.splitting_ops == []
        # enable_attn_fusion is directly supported under
        # use_inductor_graph_partition=True, and cudagraph_mode
        # is unchanged.
        assert config.compilation_config.cudagraph_mode == \
            CUDAGraphMode.PIECEWISE

@@ -3,7 +3,7 @@
"""Utility functions for attention-related v1 tests."""

from dataclasses import dataclass
from typing import Union
from typing import Optional, Union

import pytest
import torch
@@ -260,3 +260,88 @@ def create_dummy_kv_cache(block_size: int,
                           dtype=dtype,
                           device=device)
    return kv_cache


@dataclass
class BackendConfig:
    name: str
    env_vars: dict
    comp_config: dict  # compilation config
    specific_gpu_arch: Optional[tuple] = None


# Define all backend configurations of full cudagraph to be tested
full_cg_backend_configs = {
    # FA3 on Hopper
    "FA3":
    BackendConfig(name="FA3",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
                      "VLLM_FLASH_ATTN_VERSION": "3",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL",
                  },
                  specific_gpu_arch=(9, 0)),
    # FlashMLA on Hopper
    "FlashMLA":
    BackendConfig(name="FlashMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  },
                  specific_gpu_arch=(9, 0)),
    # Cutlass MLA on Blackwell
    "CutlassMLA":
    BackendConfig(
        name="CutlassMLA",
        env_vars={
            "VLLM_USE_V1": "1",
            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
            "FORCE_NUM_KV_SPLITS":
            "1",  # TODO: remove this when hang issue is fixed
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
        specific_gpu_arch=(10, 0)),
    # FlashAttention MLA on Hopper
    "FlashAttentionMLA":
    BackendConfig(name="FlashAttentionMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_DECODE_ONLY",
                  },
                  specific_gpu_arch=(9, 0)),
    # FA2
    "FA2":
    BackendConfig(name="FA2",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
                      "VLLM_FLASH_ATTN_VERSION": "2",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
    # Triton Attention
    "TritonAttn":
    BackendConfig(name="TritonAttn",
                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
    # FlashInfer
    "FlashInfer":
    BackendConfig(name="FlashInfer",
                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
}

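These shared configs are what the cudagraph test modules above import (aliased as `backend_configs`). The sketch below shows one plausible way to consume an entry; it is illustrative only: `_temporary_environ` is a simplified stand-in for the tests' own `temporary_environ` helper, and the model name is a placeholder, not something this diff prescribes.

# Sketch: consume one BackendConfig entry by exporting its env_vars and
# building an LLM with the matching compilation config. Illustrative only.
import contextlib
import os

from tests.v1.attention.utils import full_cg_backend_configs
from vllm import LLM
from vllm.config import CompilationConfig


@contextlib.contextmanager
def _temporary_environ(env_vars: dict):
    # simplified stand-in for the temporary_environ helper in the tests above
    original = {k: os.environ.get(k) for k in env_vars}
    os.environ.update(env_vars)
    try:
        yield
    finally:
        for k, v in original.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v


def build_llm(backend_name: str) -> LLM:
    cfg = full_cg_backend_configs[backend_name]
    with _temporary_environ(cfg.env_vars):  # select the attention backend
        return LLM(model="facebook/opt-125m",  # placeholder model
                   compilation_config=CompilationConfig(**cfg.comp_config),
                   gpu_memory_utilization=0.4)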
@@ -45,39 +45,22 @@ def _create_vllm_config(compilation_config: CompilationConfig,
class TestCudagraphDispatcher:

    @pytest.mark.parametrize(
        "params",
        "case_id,cudagraph_mode_str,compilation_level",
        [
            # Test case 0: Full CG for mixed batches, no separate routine
            {
                "case_id": 0,
                "cudagraph_mode": "FULL",
                "compilation_level": CompilationLevel.NO_COMPILATION,
            },
            (0, "FULL", CompilationLevel.NO_COMPILATION),
            # Test case 1: Full CG for uniform batches, piecewise for mixed
            {
                "case_id": 1,
                "cudagraph_mode": "FULL_AND_PIECEWISE",
                "compilation_level": CompilationLevel.PIECEWISE,
            },
            (1, "FULL_AND_PIECEWISE", CompilationLevel.NO_COMPILATION),
            # Test case 2: Full CG for uniform batches, no CG for mixed
            {
                "case_id": 2,
                "cudagraph_mode": "FULL_DECODE_ONLY",
                "compilation_level": CompilationLevel.NO_COMPILATION,
            },
            (2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION),
            # Test case 3: Piecewise for all
            {
                "case_id": 3,
                "cudagraph_mode": "PIECEWISE",
                "compilation_level": CompilationLevel.PIECEWISE,
            },
            (3, "PIECEWISE", CompilationLevel.PIECEWISE),
        ])
    def test_dispatcher(self, params):
    def test_dispatcher(self, cudagraph_mode_str, compilation_level):
        # Setup dispatcher
        comp_config = CompilationConfig(
            cudagraph_mode=params["cudagraph_mode"],
            level=params["compilation_level"],
            cudagraph_capture_sizes=[1, 8])
        comp_config = CompilationConfig(cudagraph_mode=cudagraph_mode_str,
                                        level=compilation_level,
                                        cudagraph_capture_sizes=[1, 8])

        config = _create_vllm_config(comp_config, max_num_seqs=8)
        dispatcher = CudagraphDispatcher(config)
@@ -86,11 +69,11 @@ class TestCudagraphDispatcher:
            uniform_decode_query_len=1)

        # Verify the key is initialized correctly
        if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
        if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]:
        if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
@@ -99,10 +82,10 @@ class TestCudagraphDispatcher:
        # 1. non-uniform batch, size in cudagraph size list
        desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
        rt_mode, key = dispatcher.dispatch(desc_full_exact)
        if params["cudagraph_mode"] == "FULL":
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_full_exact
        elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
        elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact
        else:
@@ -111,15 +94,13 @@ class TestCudagraphDispatcher:
        # 2. uniform decode batch, size in cudagraph size list
        desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
        rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
        if params["cudagraph_mode"] == "FULL":
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact.non_uniform
        elif params["cudagraph_mode"] in [
                "FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"
        ]:
        elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact
        elif params["cudagraph_mode"] == "PIECEWISE":
        elif cudagraph_mode_str == "PIECEWISE":
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_uniform_exact.non_uniform
        else:
@@ -131,6 +112,16 @@ class TestCudagraphDispatcher:
            assert rt_mode == CUDAGraphMode.NONE
            assert key is None

        # 4. Cascade attention should have a fall back mode
        desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
        rt_mode, key = dispatcher.dispatch(desc_full_exact,
                                           use_cascade_attn=True)
        if "PIECEWISE" in cudagraph_mode_str:  # string contains check
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact.non_uniform
        else:
            assert rt_mode == CUDAGraphMode.NONE


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:

@@ -4,12 +4,11 @@ import contextlib
import os
import weakref
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Optional

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
@@ -34,74 +33,6 @@ def temporary_environ(env_vars):
                os.environ[k] = v


@dataclass
class BackendConfig:
    name: str
    env_vars: dict
    comp_config: dict
    specific_gpu_arch: Optional[tuple] = None


# Define all backend configurations of full cudagraph to be tested
backend_configs = {
    # FA3 on Hopper
    "FA3":
    BackendConfig(name="FA3",
                  env_vars={
                      "VLLM_FLASH_ATTN_VERSION": "3",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL",
                  },
                  specific_gpu_arch=(9, 0)),
    # FlashMLA on Hopper
    "FlashMLA":
    BackendConfig(name="FlashMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  },
                  specific_gpu_arch=(9, 0)),
    # FlashAttention MLA on Hopper
    "FlashAttentionMLA":
    BackendConfig(name="FlashAttentionMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_DECODE_ONLY",
                  },
                  specific_gpu_arch=(9, 0)),
    # FA2
    "FA2":
    BackendConfig(name="FA2",
                  env_vars={
                      "VLLM_FLASH_ATTN_VERSION": "2",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
    # Triton Attention
    "TritonAttn":
    BackendConfig(name="TritonAttn",
                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
    # FlashInfer
    "FlashInfer":
    BackendConfig(name="FlashInfer",
                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
}

# test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported)
combo_cases_1 = [
@@ -114,9 +45,10 @@ combo_cases_1 = [
]


@pytest.mark.parametrize("combo_case", combo_cases_1)
def test_backend_and_cudagraph_mode_combo(combo_case):
    backend_name, cudagraph_mode, supported = combo_case
@pytest.mark.parametrize("backend_name, cudagraph_mode, supported",
                         combo_cases_1)
def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode,
                                          supported):
    if backend_name == "FlashInfer":
        try:
            import flashinfer  # noqa: F401
@@ -142,7 +74,7 @@ def test_backend_and_cudagraph_mode_combo(combo_case):
              compilation_config=CompilationConfig(
                  level=3, cudagraph_mode=cudagraph_mode))
    llm.generate(["Hello, my name is"] * 10)

    # when above code raises, `llm` may be undefined, so we need to catch that
    try:
        llm = weakref.proxy(llm)
        del llm
@@ -173,7 +105,8 @@ combo_cases_2 = [
]


@pytest.mark.parametrize("combo_case", combo_cases_2)
@pytest.mark.parametrize("backend_name,cudagraph_mode,compilation_level,"\
                         "supported", combo_cases_2)
def test_cudagraph_compilation_combo(combo_case):
    backend_name, cudagraph_mode, compilation_level, supported\
        = combo_case
@@ -192,6 +125,7 @@ def test_cudagraph_compilation_combo(combo_case):
              compilation_config=CompilationConfig(
                  level=compilation_level, cudagraph_mode=cudagraph_mode))
    llm.generate(["Hello, my name is"] * 10)
    # when above code raises, `llm` may be undefined, so we need to catch that
    try:
        llm = weakref.proxy(llm)
        del llm

@@ -340,15 +340,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
            num_graphs=len(self.compile_submod_names),
            runtime_shape=None)
        # Lazy import here to avoid circular import
        from .cuda_piecewise_backend import PiecewiseBackend
        from .piecewise_backend import PiecewiseBackend

        piecewise_backend = PiecewiseBackend(
            submod, self.vllm_config, index,
            len(self.compile_submod_names), sym_shape_indices,
            compiled_graph_for_dynamic_shape, self.vllm_backend)

        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
                and
        if (self.compilation_config.cudagraph_mode.\
                has_piecewise_cudagraphs() and
                not self.compilation_config.use_inductor_graph_partition):
            # We're using Dynamo-based piecewise splitting, so we wrap
            # the whole subgraph with a static graph wrapper.

@@ -336,7 +336,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
    from vllm.config import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
    if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
    if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
            and compilation_config.use_inductor_graph_partition):
        from torch._inductor.utils import CUDAGraphWrapperMetadata

@@ -365,7 +365,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):

    yield

    if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
    if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
            and compilation_config.use_inductor_graph_partition):
        torch._inductor.utils.set_customized_partition_wrappers(None)

@@ -459,15 +459,22 @@ class VllmConfig:
                "to True to enable.")
        current_platform.check_and_update_config(self)

        # final check of cudagraph mode after platform-specific update
        # Do this after all the updates to compilation_config.level
        if envs.VLLM_USE_V1 and \
            self.compilation_config.level == CompilationLevel.PIECEWISE:
            self.compilation_config.set_splitting_ops_for_v1()

        # final check of cudagraph mode after all possible updates
        if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
            if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
            if self.compilation_config.cudagraph_mode.has_full_cudagraphs()\
                and self.model_config is not None and \
                not self.model_config.disable_cascade_attn:
                logger.info("CUDAGraphMode.FULL is not supported with "
                            "cascade attention currently. Disabling cascade"
                            "attention.")
                self.model_config.disable_cascade_attn = True
                not self.model_config.disable_cascade_attn and\
                not self.compilation_config.cudagraph_mode.\
                    has_piecewise_cudagraphs():
                logger.warning_once(
                    "No piecewise cudagraph for executing cascade attention."
                    " Will fall back to eager execution if a batch runs "
                    "into cascade attentions")

            if self.compilation_config.cudagraph_mode\
                .requires_piecewise_compilation():
@@ -477,6 +484,12 @@ class VllmConfig:
                "when cudagraph_mode piecewise cudagraphs is used, "\
                f"cudagraph_mode={self.compilation_config.cudagraph_mode}"

            # finally, migrate the deprecated flags
            self.compilation_config.use_cudagraph = self.compilation_config.\
                cudagraph_mode != CUDAGraphMode.NONE
            self.compilation_config.full_cuda_graph = self.compilation_config.\
                cudagraph_mode.has_full_cudagraphs()

        if self.parallel_config.enable_dbo:
            a2a_backend = envs.VLLM_ALL2ALL_BACKEND
            assert a2a_backend in \
@@ -487,14 +500,14 @@ class VllmConfig:
                "variable to deepep_low_latency or deepep_high_throughput and "\
                "install the DeepEP kernels."

            if not self.model_config.disable_cascade_attn:
                self.model_config.disable_cascade_attn = True
                logger.warning_once(
                    "Disabling cascade attention when DBO is enabled.")

        if not self.instance_id:
            self.instance_id = random_uuid()[:5]

        # Do this after all the updates to compilation_config.level
        if envs.VLLM_USE_V1 and \
            self.compilation_config.level == CompilationLevel.PIECEWISE:
            self.compilation_config.set_splitting_ops_for_v1()

        if (envs.VLLM_USE_V1
                and not self.scheduler_config.disable_hybrid_kv_cache_manager):
            # logger should only print warning message for hybrid models. As we

@@ -61,9 +61,17 @@ class CUDAGraphMode(enum.Enum):
    def has_full_cudagraphs(self) -> bool:
        return self.max_cudagraph_mode() == CUDAGraphMode.FULL

    def has_piecewise_cudagraphs(self) -> bool:
        return self.requires_piecewise_compilation()

    def separate_routine(self) -> bool:
        return isinstance(self.value, tuple)

    def valid_runtime_modes(self) -> bool:
        return self in [
            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
        ]


@config
@dataclass
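As a quick reference for the helpers added above, the sketch below enumerates the modes that appear in this diff and queries each predicate. The commented expectations are inferred from how the rest of the diff uses these helpers, not verified against the enum definition.

# Sketch: probing the new CUDAGraphMode helper predicates.
from vllm.config import CUDAGraphMode

modes = (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL,
         CUDAGraphMode.FULL_DECODE_ONLY, CUDAGraphMode.FULL_AND_PIECEWISE)
for mode in modes:
    print(mode,
          mode.has_full_cudagraphs(),       # expected True for the FULL* modes
          mode.has_piecewise_cudagraphs(),  # True when piecewise graphs are built
          mode.valid_runtime_modes())       # only NONE/PIECEWISE/FULL at runtime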
@@ -269,7 +277,8 @@ class CompilationConfig:
    Note that this is orthogonal to the cudagraph capture logic
    outside of compilation.
    Warning: This flag is deprecated and will be removed in the next major or
    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE
    instead.
    """
    cudagraph_num_of_warmups: int = 0
    """Number of warmup runs for cudagraph.
@@ -294,7 +303,8 @@ class CompilationConfig:
    flag cannot be used together with splitting_ops. This may provide
    performance benefits for smaller models.
    Warning: This flag is deprecated and will be removed in the next major or
    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
    FULL_AND_PIECEWISE instead.
    """

    use_inductor_graph_partition: bool = False
@@ -464,7 +474,8 @@ class CompilationConfig:
            if not self.use_cudagraph:
                logger.warning("use_cudagraph is deprecated, use "
                               "cudagraph_mode=NONE instead.")
                if self.cudagraph_mode is not None:
                if self.cudagraph_mode is not None and \
                    self.cudagraph_mode != CUDAGraphMode.NONE:
                    raise ValueError(
                        "use_cudagraph and cudagraph_mode are mutually"
                        " exclusive, prefer cudagraph_mode since "
@@ -473,7 +484,8 @@ class CompilationConfig:
            if self.full_cuda_graph:
                logger.warning("full_cuda_graph is deprecated, use "
                               "cudagraph_mode=FULL instead.")
                if self.cudagraph_mode is not None:
                if self.cudagraph_mode is not None and \
                    not self.cudagraph_mode.has_full_cudagraphs():
                    raise ValueError("full_cuda_graph and cudagraph_mode are "
                                     "mutually exclusive, prefer cudagraph_mode "
                                     "since full_cuda_graph is deprecated.")
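Taken together, the two deprecation messages above imply the following migration for user configs; a minimal sketch, assuming string values are accepted for cudagraph_mode as they are in the tests earlier in this diff.

# Sketch: replacing the deprecated boolean flags with cudagraph_mode.
from vllm.config import CompilationConfig

# Before (deprecated):
#   CompilationConfig(use_cudagraph=True)
#   CompilationConfig(full_cuda_graph=True)
# After:
piecewise_only = CompilationConfig(cudagraph_mode="PIECEWISE")
full_and_piecewise = CompilationConfig(cudagraph_mode="FULL_AND_PIECEWISE")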
@@ -570,48 +582,75 @@ class CompilationConfig:
            "set_splitting_ops_for_v1 should only be called when "
            "level is CompilationLevel.PIECEWISE")

        if self.use_inductor_graph_partition:
            self.set_splitting_ops_for_inductor_graph_partition()
            return

        if self.pass_config.enable_attn_fusion:
            # here use_inductor_graph_partition is False
            self.set_splitting_ops_for_attn_fusion()
            return

        if self.splitting_ops is None:
            # NOTE: When using full cudagraph, instead of setting an empty
            # list and capture the full cudagraph inside the flattened fx
            # graph, we keep the piecewise fx graph structure but capture
            # the full cudagraph outside the fx graph. This reduces some
            # cpu overhead when the runtime batch_size is not cudagraph
            # captured. see https://github.com/vllm-project/vllm/pull/20059
            # for details. Make a copy to avoid mutating the class-level
            # list via reference.
            self.splitting_ops = list(self._attention_ops)
        elif len(self.splitting_ops) == 0:
            logger.warning_once(
                "Using piecewise compilation with empty splitting_ops")
            if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
                logger.warning_once(
                    "Piecewise compilation with empty splitting_ops does "
                    "not contain piecewise cudagraph. Setting cudagraph_"
                    "mode to NONE. Hint: If you are using attention backends "
                    "that support cudagraph, consider manually setting "
                    "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
                    "full cudagraphs.")
                self.cudagraph_mode = CUDAGraphMode.NONE
            elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
                logger.warning_once(
                    "Piecewise compilation with empty splitting_ops does "
                    "not contain piecewise cudagraph. Setting cudagraph_mode "
                    "to FULL.")
                self.cudagraph_mode = CUDAGraphMode.FULL
            self.splitting_ops = []

    def set_splitting_ops_for_inductor_graph_partition(self):
        assert self.use_inductor_graph_partition
        use_inductor_graph_partition_msg = (
            "When use_inductor_graph_partition=True, splitting_ops "
            "are ignored and set to an empty list. Instead, "
            "\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is "
            "used to annotate custom ops for graph partition.")

        if self.splitting_ops is None:
            if self.use_inductor_graph_partition:
                # When using inductor graph partition, we set splitting_ops
                # to be empty and rely on torch._C.Tag.cudagraph_unsafe to
                # annotate custom ops as splitting ops.
                logger.warning_once(use_inductor_graph_partition_msg)
                self.splitting_ops = []
            else:
                # NOTE: When using full cudagraph, instead of setting an empty
                # list and capture the full cudagraph inside the flattened fx
                # graph, we keep the piecewise fx graph structure but capture
                # the full cudagraph outside the fx graph. This reduces some
                # cpu overhead when the runtime batch_size is not cudagraph
                # captured. see https://github.com/vllm-project/vllm/pull/20059
                # for details. make a copy to avoid mutating the class-level
                # list via reference.
                self.splitting_ops = list(self._attention_ops)
        elif len(self.splitting_ops) == 0:
            logger.warning_once(
                "Using piecewise compilation with empty "
                "splitting_ops and use_inductor_graph_partition"
                f"={self.use_inductor_graph_partition}.")
            if (self.cudagraph_mode == CUDAGraphMode.PIECEWISE
                    and not self.use_inductor_graph_partition):
                logger.warning_once(
                    "When compilation level is piecewise with empty "
                    "splitting_ops, PIECEWISE cudagraph_mode will be "
                    "treated as FULL cudagraph_mode. Please ensure you are "
                    "using attention backends that support cudagraph or set "
                    "cudagraph_mode to NONE explicitly if encountering "
                    "any problems.")
                self.cudagraph_mode = CUDAGraphMode.FULL
            self.splitting_ops = []
        elif self.use_inductor_graph_partition:
            if self.splitting_ops is not None and \
                len(self.splitting_ops) > 0:
                logger.warning_once(use_inductor_graph_partition_msg)
            self.splitting_ops = []

    def set_splitting_ops_for_attn_fusion(self):
        assert self.pass_config.enable_attn_fusion
        if self.splitting_ops is None:
            self.splitting_ops = []
            if self.cudagraph_mode.has_piecewise_cudagraphs():
                logger.warning_once(
                    "enable_attn_fusion is incompatible with piecewise "
                    "cudagraph when use_inductor_graph_partition is off. "
                    "In this case, splitting_ops will be set to empty "
                    "list, and cudagraph_mode will be set to FULL. "
                    "Please ensure you are using attention backends that "
                    "support cudagraph or set cudagraph_mode to NONE "
                    "explicitly if encountering any problems.")
                self.cudagraph_mode = CUDAGraphMode.FULL

        assert not self.splitting_ops_contain_attention(), (
            "attention ops should not be in splitting_ops "
            "when enable_attn_fusion is True")

    def splitting_ops_contain_attention(self) -> bool:
        return self.splitting_ops is not None and all(

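The warnings above describe a concrete fallback: with piecewise compilation and an explicitly empty splitting_ops list, FULL_AND_PIECEWISE is downgraded to FULL (and PIECEWISE to NONE). A minimal sketch of that path follows, assuming set_splitting_ops_for_v1() may be invoked directly, as the VllmConfig hunk earlier in this diff does during config finalization.

# Sketch: empty splitting_ops leaves no piecewise graphs, so the mode falls
# back from FULL_AND_PIECEWISE to FULL. Illustrative; not part of this diff.
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode

cfg = CompilationConfig(level=CompilationLevel.PIECEWISE,
                        cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
                        splitting_ops=[])
cfg.set_splitting_ops_for_v1()
assert cfg.cudagraph_mode == CUDAGraphMode.FULL
assert cfg.splitting_ops == []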
@@ -246,8 +246,7 @@ class ForwardContext:
    ubatch_slices: Optional[UBatchSlices] = None

    def __post_init__(self):
        assert self.cudagraph_runtime_mode in [
            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
        assert self.cudagraph_runtime_mode.valid_runtime_modes(), \
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"

@@ -22,10 +22,10 @@ class CudagraphDispatcher:

    At runtime, the dispatch method generates the runtime cudagraph mode (FULL,
    PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor)
    based on the input key. After dispatching (communicate via forward context),
    the cudagraph wrappers will trust the dispatch key to do either capturing
    or replaying (if mode matched), or pass through to the underlying runnable
    without cudagraph (if mode no match or mode is NONE).
    based on the input key. After dispatching (communicated via forward
    context), the cudagraph wrappers will trust the dispatch key to either
    capture or replay (if the mode matches), or pass through to the underlying
    runnable without cudagraph (if the mode does not match or mode is NONE).
    """

    def __init__(self, vllm_config: VllmConfig):
@@ -57,19 +57,15 @@ class CudagraphDispatcher:
    def add_cudagraph_key(self, runtime_mode: CUDAGraphMode,
                          batch_descriptor: BatchDescriptor):
        assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
            f"Invalid cudagraph runtime mode: {runtime_mode}"
            f"Invalid cudagraph runtime mode for keys: {runtime_mode}"
        self.cudagraph_keys[runtime_mode].add(batch_descriptor)

    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
                                  uniform_decode_query_len: int):
        # This should be called only after attention backend is initialized.

        # Note: we create all valid keys possible for cudagraph but do not
        # guarantee all keys would be used. For example, we create keys for
        # piecewise cudagraphs when it is piecewise compilation, which is always
        # valid, but for attention backend support unified routine, we may not
        # trigger capturing/replaying the piecewise cudagraphs depending on
        # CompilationConfig.cudagraph_mode. In addition, if we allow lazy
        # Note: we create all valid keys for cudagraph here but do not
        # guarantee all keys would be used. For example, if we allow lazy
        # capturing in future PR, some keys may never be triggered.
        if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
            for bs in self.compilation_config.cudagraph_capture_sizes:
@@ -94,10 +90,13 @@ class CudagraphDispatcher:
        self.keys_initialized = True

    def dispatch(
        self, batch_descriptor: BatchDescriptor
        self,
        batch_descriptor: BatchDescriptor,
        use_cascade_attn: bool = False
    ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
        """
        Given a batch descriptor, dispatch to a cudagraph mode.
        Given conditions (e.g., batch descriptor and if using cascade attention),
        dispatch to a cudagraph runtime mode and the valid batch descriptor.
        A new batch descriptor is returned as we might dispatch a uniform batch
        to a graph that supports a more general batch (uniform to non-uniform).
        """
@@ -107,14 +106,16 @@ class CudagraphDispatcher:
                "initialized. No cudagraph will be used.")
            return CUDAGraphMode.NONE, None

        # check if key exists for full cudagraph
        if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
            return CUDAGraphMode.FULL, batch_descriptor

        # otherwise, check if non-uniform key exists
        non_uniform_key = batch_descriptor.non_uniform
        if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
            return CUDAGraphMode.FULL, non_uniform_key
        # if a batch uses cascade attention, bypass checking full cudagraphs
        if not use_cascade_attn:
            # check if key exists for full cudagraph
            if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
                return CUDAGraphMode.FULL, batch_descriptor

            # otherwise, check if non-uniform key exists
            if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
                return CUDAGraphMode.FULL, non_uniform_key

        # also check if non-uniform key exists for more "general"
        # piecewise cudagraph
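To connect this to the dispatcher test added earlier in the diff, here is a hedged sketch of the cascade-attention fallback. The import locations for CudagraphDispatcher and BatchDescriptor, and the _create_vllm_config helper, are assumptions borrowed from that test file rather than facts established by this diff.

# Sketch: a batch that uses cascade attention bypasses the FULL-graph keys
# and falls back to the more general piecewise graph (or to NONE).
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode
from vllm.forward_context import BatchDescriptor              # assumed location
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher  # assumed location

comp_config = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
                                level=CompilationLevel.PIECEWISE,
                                cudagraph_capture_sizes=[1, 8])
config = _create_vllm_config(comp_config, max_num_seqs=8)  # test helper, assumed
dispatcher = CudagraphDispatcher(config)
dispatcher.initialize_cudagraph_keys(comp_config.cudagraph_mode,
                                     uniform_decode_query_len=1)

desc = BatchDescriptor(num_tokens=8, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc, use_cascade_attn=True)
assert rt_mode == CUDAGraphMode.PIECEWISE  # piecewise keys exist in this setup
assert key == desc.non_uniform             # dispatched to the general key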
@@ -923,11 +923,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
    ) -> tuple[PerLayerAttnMetadata, torch.Tensor,
               Optional[SpecDecodeMetadata], np.ndarray,
               Optional[CommonAttentionMetadata], int, Optional[UBatchSlices],
               Optional[torch.Tensor]]:
               Optional[torch.Tensor], bool]:
        """
        :return: tuple[
            attn_metadata: layer-to-attention_metadata mapping,
            logits_indices, spec_decode_metadata
            logits_indices, spec_decode_metadata,
            num_scheduled_tokens, spec_decode_common_attn_metadata,
            max_num_scheduled_tokens, use_cascade_attn
        ]
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -1135,6 +1137,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        attn_metadata: PerLayerAttnMetadata = {}
        if ubatch_slices is not None:
            attn_metadata = [dict() for _ in range(len(ubatch_slices))]
        use_cascade_attn = False

        # Used in the below loop.
        query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
@@ -1251,9 +1254,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                    common_prefix_len=common_prefix_len,
                    common_attn_metadata=common_attn_metadata,
                    **extra_attn_metadata_args)
                use_cascade_attn |= getattr(attn_metadata_i, "use_cascade",
                                            False)
                for layer_name in attn_group.layer_names:
                    attn_metadata[layer_name] = attn_metadata_i

        # disable cascade attention when DBO
        if ubatch_slices is not None:
            use_cascade_attn = False

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)
@@ -1261,7 +1270,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        return (attn_metadata, logits_indices, spec_decode_metadata,
                num_scheduled_tokens, spec_decode_common_attn_metadata,
                max_num_scheduled_tokens, ubatch_slices,
                num_tokens_after_padding)
                num_tokens_after_padding, use_cascade_attn)

    def _compute_cascade_attn_prefix_len(
        self,
@@ -2251,8 +2260,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # Prepare the decoder inputs.
        (attn_metadata, logits_indices, spec_decode_metadata,
         num_scheduled_tokens_np, spec_decode_common_attn_metadata,
         max_query_len, ubatch_slices, num_tokens_after_padding
         ) = self._prepare_inputs(scheduler_output)
         max_query_len, ubatch_slices, num_tokens_after_padding,
         use_cascade_attn) = self._prepare_inputs(scheduler_output)

        (
            num_scheduled_tokens,
@@ -2273,7 +2282,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                               uniform_decode=uniform_decode)
            cudagraph_runtime_mode, batch_descriptor = \
                self.cudagraph_dispatcher.dispatch(batch_descriptor)
                self.cudagraph_dispatcher.dispatch(batch_descriptor,
                                                   use_cascade_attn)

        # This is currently to get around the assert in the DPMetadata
        # where it wants `num_tokens_across_dp` to align with `num_tokens`
@@ -2701,16 +2711,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            "Cannot reload weights before model is loaded."
        model_loader = get_model_loader(self.load_config)
        logger.info("Reloading weights inplace...")
        model = self.get_model()
        model_loader.load_weights(model, model_config=self.model_config)
        model_loader.load_weights(self.get_model(),
                                  model_config=self.model_config)

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        model = self.get_model()
        TensorizerLoader.save_model(
            model,
            self.get_model(),
            tensorizer_config=tensorizer_config,
            model_config=self.model_config,
        )
@@ -2926,9 +2935,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            (1 token) and prefill (multiple tokens) requests.
            remove_lora: If False, dummy LoRAs are not destroyed after the run
        """
        assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
        }
        assert cudagraph_runtime_mode is None or \
            cudagraph_runtime_mode.valid_runtime_modes()

        # If cudagraph_mode.decode_mode() == FULL and
        # cudagraph_mode.separate_routine(). This means that we are using
@@ -3113,7 +3121,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            # filter out the valid batch descriptor
            _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
                BatchDescriptor(num_tokens=num_tokens,
                                uniform_decode=uniform_decode))
                                uniform_decode=uniform_decode)) \
                if not is_profile else (CUDAGraphMode.NONE, None)
            if cudagraph_runtime_mode is not None:
                # we allow forcing NONE when the dispatcher disagrees to support
                # warm ups for cudagraph capture
@@ -3453,8 +3462,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                              cudagraph_runtime_mode: CUDAGraphMode,
                              uniform_decode: bool):
        assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \
            cudagraph_runtime_mode in [CUDAGraphMode.FULL,
                                       CUDAGraphMode.PIECEWISE]
            cudagraph_runtime_mode.valid_runtime_modes(), \
            f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"

        # Only rank 0 should print progress bar during capture
        if is_global_first_rank():
@@ -3585,6 +3594,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            self.calculate_reorder_batch_threshold()

    def initialize_cudagraph_capture(self) -> None:
        """
        Resolve the cudagraph_mode when there are multiple attention
        backends with potential conflicting CUDA graph support.
        Then initialize the cudagraph_dispatcher based on the resolved
        cudagraph_mode.
        """
        min_cg_support = AttentionCGSupport.ALWAYS
        min_cg_builder_name = None