[CI][DCP][Perf] reduce DCP CI execution time (#29858)

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Qiu 2025-12-05 01:28:21 +08:00 committed by GitHub
parent b286a311c2
commit 46cbbca05c
2 changed files with 100 additions and 94 deletions
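In outline, the DCP test now launches a single OpenAI-compatible vLLM server with the TP/DCP flags and checks its GSM8K accuracy against a stored TP baseline, instead of bringing up paired CP and TP servers and comparing their outputs via compare_two_settings. Below is a condensed sketch of that flow using the helpers imported in the diff; check_dcp_accuracy and the concrete argument values are illustrative only, not the actual parametrization.

from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
from tests.utils import RemoteOpenAIServer

# Illustrative threshold; the real values live in MIN_ACCURACY in the diff below.
MIN_ACCURACY = {"Qwen/Qwen2.5-1.5B-Instruct": 0.52}

def check_dcp_accuracy(model_id: str, server_args: list[str]) -> None:
    # Bring up one server with the TP+DCP flags, then score it on GSM8K.
    with RemoteOpenAIServer(model_id, server_args, max_wait_seconds=720) as server:
        results = evaluate_gsm8k(
            num_questions=256,
            num_shots=5,
            host=f"http://{server.host}",
            port=server.port,
        )
    assert results["accuracy"] >= MIN_ACCURACY[model_id]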

View File

@@ -16,16 +16,35 @@ from typing import Literal, NamedTuple
 import pytest
 import torch
 
+from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
+from tests.utils import RemoteOpenAIServer, create_new_process_for_each_test
 from vllm.config.model import RunnerOption
 from vllm.logger import init_logger
 
 from ..models.registry import HF_EXAMPLE_MODELS
-from ..utils import compare_two_settings, create_new_process_for_each_test
 
 logger = init_logger("test_context_parallel")
 
 VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
 
+CP_TEST_MODELS = [
+    # TODO support other models
+    # [LANGUAGE GENERATION]
+    "deepseek-ai/DeepSeek-V2-Lite-Chat",
+    "Qwen/Qwen2.5-1.5B-Instruct",
+]
+
+# GSM8K eval configuration
+NUM_QUESTIONS = 256  # Fast eval for CI
+NUM_SHOTS = 5  # Few-shot examples
+
+# tp accuracy with 2% buffer
+MIN_ACCURACY = {
+    # .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
+    "deepseek-ai/DeepSeek-V2-Lite-Chat": 0.64,
+    # .buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml
+    "Qwen/Qwen2.5-1.5B-Instruct": 0.52,
+}
+
 
 class ParallelSetup(NamedTuple):
     tp_size: int
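The MIN_ACCURACY comment above describes the thresholds as the TP accuracy with a 2% buffer. A minimal sketch of that relationship; min_accuracy_from_baseline is a hypothetical helper and the baseline argument is a placeholder, not a value taken from the referenced lm-eval-harness configs.

def min_accuracy_from_baseline(tp_baseline: float, buffer: float = 0.02) -> float:
    # Subtract a small buffer from the TP baseline so run-to-run GSM8K noise
    # in CI does not cause spurious failures.
    return round(tp_baseline - buffer, 2)

# e.g. a TP baseline of 0.66 would yield the 0.64 threshold used above
assert min_accuracy_from_baseline(0.66) == 0.64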
@@ -38,7 +57,6 @@ class ParallelSetup(NamedTuple):
 
 class CPTestOptions(NamedTuple):
     multi_node_only: bool
-    load_format: str | None = None
     attn_backend: str | None = None
@@ -54,17 +72,20 @@ class CPTestSettings:
         *,
         tp_base: int = 4,
         pp_base: int = 1,
-        dcp_base: int = 1,
+        dcp_multipliers: list[float] | None = None,
         cp_kv_cache_interleave_size: int = 1,
         multi_node_only: bool = False,
         runner: RunnerOption = "auto",
-        load_format: str | None = None,
         attn_backend: str | None = None,
     ):
         parallel_setups = []
+        if dcp_multipliers is None:
+            dcp_multipliers = [
+                0.5,
+            ]
         for eager_mode_val in [False]:
             for pp_multiplier in [1]:
-                for dcp_multiplier in [0.5, 1]:
+                for dcp_multiplier in dcp_multipliers:
                     for chunked_prefill_val in [True]:
                         parallel_setups.append(
                             ParallelSetup(
@@ -82,7 +103,6 @@ class CPTestSettings:
                                 runner=runner,
                                 test_options=CPTestOptions(
                                     multi_node_only=multi_node_only,
-                                    load_format=load_format,
                                     attn_backend=attn_backend,
                                 ),
                             )
@@ -101,7 +121,24 @@ class CPTestSettings:
         )
 
 
-def _compare_cp_with_tp(
+CP_TEXT_GENERATION_MODELS = {
+    "deepseek-ai/DeepSeek-V2-Lite-Chat": [
+        CPTestSettings.detailed(
+            dcp_multipliers=[0.5, 1], cp_kv_cache_interleave_size=64
+        ),
+    ],
+    "Qwen/Qwen2.5-1.5B-Instruct": [
+        CPTestSettings.detailed(
+            cp_kv_cache_interleave_size=16, attn_backend="FLASH_ATTN"
+        ),
+        CPTestSettings.detailed(
+            cp_kv_cache_interleave_size=16, attn_backend="FLASHINFER"
+        ),
+    ],
+}
+
+
+def _test_cp_gsm8k(
     model_id: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
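For a sense of how the dcp_multipliers knob trims the test matrix: each multiplier becomes one ParallelSetup, and therefore one server launch, per settings entry. The standalone sketch below mimics that expansion; the function name and the assumed dcp_size = int(tp_base * multiplier) derivation are illustrative, since the actual computation sits in unchanged code outside this diff.

def expand_dcp_setups(
    tp_base: int = 4, dcp_multipliers: list[float] | None = None
) -> list[tuple[int, int]]:
    # Mirror of the detailed() expansion loop earlier in the diff:
    # one (tp_size, dcp_size) pair per multiplier.
    if dcp_multipliers is None:
        dcp_multipliers = [0.5]
    return [(tp_base, int(tp_base * m)) for m in dcp_multipliers]

print(expand_dcp_setups())                          # [(4, 2)] -> new default, one setup
print(expand_dcp_setups(dcp_multipliers=[0.5, 1]))  # [(4, 2), (4, 4)] -> the DeepSeek entry keeps both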
@@ -121,7 +158,7 @@ def _compare_cp_with_tp(
         chunked_prefill,
     ) = parallel_setup
 
-    multi_node_only, load_format, attn_backend = test_options
+    multi_node_only, attn_backend = test_options
 
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
     model_info.check_transformers_version(on_fail="skip")
@@ -130,22 +167,7 @@ def _compare_cp_with_tp(
     tokenizer_mode = model_info.tokenizer_mode
     hf_overrides = model_info.hf_overrides
 
-    if load_format == "dummy":
-        # Avoid OOM
-        text_overrides = {
-            "num_hidden_layers": 4,
-            "hidden_size": 512,
-            "intermediate_size": 800,
-            "num_attention_heads": 4,
-            "num_key_value_heads": 1,
-        }
-
-        if is_multimodal:
-            hf_overrides.update({"text_config": text_overrides})
-        else:
-            hf_overrides.update(text_overrides)
-    else:
-        model_info.check_available_online(on_fail="skip")
+    model_info.check_available_online(on_fail="skip")
 
     if num_gpus_available < tp_size * pp_size:
         pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
@@ -157,90 +179,70 @@ def _compare_cp_with_tp(
     if multi_node_only and not VLLM_MULTI_NODE:
         pytest.skip("Not in multi-node setting")
 
-    common_args = [
+    server_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
         "--max-model-len",
-        "2048",
+        "4096",
         "--max-num-seqs",
-        "8",
+        "64",
     ]
     if chunked_prefill:
-        common_args.append("--enable-chunked-prefill")
+        server_args.append("--enable-chunked-prefill")
     if eager_mode:
-        common_args.append("--enforce-eager")
+        server_args.append("--enforce-eager")
     if runner != "auto":
-        common_args.extend(["--runner", runner])
+        server_args.extend(["--runner", runner])
     if trust_remote_code:
-        common_args.append("--trust-remote-code")
+        server_args.append("--trust-remote-code")
     if tokenizer_mode:
-        common_args.extend(["--tokenizer-mode", tokenizer_mode])
-    if load_format:
-        common_args.extend(["--load-format", load_format])
+        server_args.extend(["--tokenizer-mode", tokenizer_mode])
     if hf_overrides:
-        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
-
-    if not attn_backend:
-        cp_env = tp_env = {}
-    else:
-        cp_env = tp_env = {
-            "VLLM_ATTENTION_BACKEND": attn_backend,
-        }
-
-    cp_args = [
-        *common_args,
-        "--tensor-parallel-size",
-        str(tp_size),
-        "--pipeline-parallel-size",
-        str(pp_size),
-        "--decode-context-parallel-size",
-        str(dcp_size),
-        "--dcp-kv-cache-interleave-size",
-        str(cp_kv_cache_interleave_size),
-        "--distributed-executor-backend",
-        distributed_backend,
-    ]
-
-    tp_args = [
-        *common_args,
-        "--tensor-parallel-size",
-        str(tp_size),
-        "--pipeline-parallel-size",
-        str(pp_size),
-        "--distributed-executor-backend",
-        distributed_backend,
-    ]
-
-    compare_two_settings(
-        model_id,
-        cp_args,
-        tp_args,
-        cp_env,
-        tp_env,
-        method=method,
-        max_wait_seconds=720,
-    )
-
-
-CP_TEXT_GENERATION_MODELS = {
-    "deepseek-ai/DeepSeek-V2-Lite-Chat": [
-        CPTestSettings.detailed(),
-        CPTestSettings.detailed(tp_base=2),
-        CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
-    ],
-    "bigcode/gpt_bigcode-santacoder": [
-        CPTestSettings.detailed(),
-        CPTestSettings.detailed(tp_base=2),
-    ],
-}
-
-CP_TEST_MODELS = [
-    # TODO support other models
-    # [LANGUAGE GENERATION]
-    "deepseek-ai/DeepSeek-V2-Lite-Chat",
-    "bigcode/gpt_bigcode-santacoder",
-]
+        server_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
+    server_args.extend(
+        [
+            "--tensor-parallel-size",
+            str(tp_size),
+            "--pipeline-parallel-size",
+            str(pp_size),
+            "--decode-context-parallel-size",
+            str(dcp_size),
+            "--dcp-kv-cache-interleave-size",
+            str(cp_kv_cache_interleave_size),
+            "--distributed-executor-backend",
+            distributed_backend,
+        ]
+    )
+
+    server_env = {}
+    if attn_backend:
+        server_env["VLLM_ATTENTION_BACKEND"] = attn_backend
+
+    with RemoteOpenAIServer(
+        model_id,
+        server_args,
+        env_dict=server_env,
+        max_wait_seconds=720,
+    ) as remote_server:
+        host = f"http://{remote_server.host}"
+        port = remote_server.port
+
+        # Run GSM8K evaluation
+        results = evaluate_gsm8k(
+            num_questions=NUM_QUESTIONS,
+            num_shots=NUM_SHOTS,
+            host=host,
+            port=port,
+        )
+
+        # Validate accuracy is reasonable
+        accuracy = results["accuracy"]
+        min_accuracy = MIN_ACCURACY[model_id]
+        assert accuracy >= min_accuracy, (
+            f"TP+DCP accuracy too low: {accuracy:.3f} < {min_accuracy:.3f}"
+        )
 
 
 @pytest.mark.parametrize(
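To make the new single-server setup concrete, here is roughly what server_args from the hunk above ends up containing for one hypothetical parallel setup (tp_size=4, pp_size=1, dcp_size=2, interleave size 64, "mp" executor backend); the values are illustrative, while the flags are the ones assembled above.

tp_size, pp_size, dcp_size, interleave = 4, 1, 2, 64  # hypothetical setup
server_args = [
    "--dtype", "bfloat16",
    "--max-model-len", "4096",
    "--max-num-seqs", "64",
    "--enable-chunked-prefill",
    "--tensor-parallel-size", str(tp_size),
    "--pipeline-parallel-size", str(pp_size),
    "--decode-context-parallel-size", str(dcp_size),
    "--dcp-kv-cache-interleave-size", str(interleave),
    "--distributed-executor-backend", "mp",
]
print(" ".join(server_args))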
@@ -274,12 +276,12 @@ def test_cp_generation(
     ):
         pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
     if (
-        model_id == "bigcode/gpt_bigcode-santacoder"
+        model_id == "Qwen/Qwen2.5-1.5B-Instruct"
         and torch.cuda.get_device_capability() != (9, 0)
     ):
         pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")
 
-    _compare_cp_with_tp(
+    _test_cp_gsm8k(
         model_id,
         parallel_setup,
         distributed_backend,

View File

@@ -416,7 +416,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
         trust_remote_code=True,
     ),
     "Qwen2ForCausalLM": _HfExamplesInfo(
-        "Qwen/Qwen2-0.5B-Instruct", extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}
+        "Qwen/Qwen2-0.5B-Instruct",
+        extras={
+            "2.5": "Qwen/Qwen2.5-0.5B-Instruct",
+            "2.5-1.5B": "Qwen/Qwen2.5-1.5B-Instruct",
+        },
     ),
     "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
     "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),