mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 19:55:35 +08:00
[CI][DCP][Perf] reduce DCP CI execution time (#29858)
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
parent
b286a311c2
commit
46cbbca05c
@ -16,16 +16,35 @@ from typing import Literal, NamedTuple
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
|
||||
from tests.utils import RemoteOpenAIServer, create_new_process_for_each_test
|
||||
from vllm.config.model import RunnerOption
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..models.registry import HF_EXAMPLE_MODELS
|
||||
from ..utils import compare_two_settings, create_new_process_for_each_test
|
||||
|
||||
logger = init_logger("test_context_parallel")
|
||||
|
||||
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
|
||||
CP_TEST_MODELS = [
|
||||
# TODO support other models
|
||||
# [LANGUAGE GENERATION]
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
"Qwen/Qwen2.5-1.5B-Instruct",
|
||||
]
|
||||
|
||||
# GSM8K eval configuration
|
||||
NUM_QUESTIONS = 256 # Fast eval for CI
|
||||
NUM_SHOTS = 5 # Few-shot examples
|
||||
# tp accuracy with 2% buffer
|
||||
MIN_ACCURACY = {
|
||||
# .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.64,
|
||||
# .buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml
|
||||
"Qwen/Qwen2.5-1.5B-Instruct": 0.52,
|
||||
}
|
||||
|
||||
|
||||
class ParallelSetup(NamedTuple):
|
||||
tp_size: int
|
||||
@ -38,7 +57,6 @@ class ParallelSetup(NamedTuple):
|
||||
|
||||
class CPTestOptions(NamedTuple):
|
||||
multi_node_only: bool
|
||||
load_format: str | None = None
|
||||
attn_backend: str | None = None
|
||||
|
||||
|
||||
@ -54,17 +72,20 @@ class CPTestSettings:
|
||||
*,
|
||||
tp_base: int = 4,
|
||||
pp_base: int = 1,
|
||||
dcp_base: int = 1,
|
||||
dcp_multipliers: list[float] | None = None,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
multi_node_only: bool = False,
|
||||
runner: RunnerOption = "auto",
|
||||
load_format: str | None = None,
|
||||
attn_backend: str | None = None,
|
||||
):
|
||||
parallel_setups = []
|
||||
if dcp_multipliers is None:
|
||||
dcp_multipliers = [
|
||||
0.5,
|
||||
]
|
||||
for eager_mode_val in [False]:
|
||||
for pp_multiplier in [1]:
|
||||
for dcp_multiplier in [0.5, 1]:
|
||||
for dcp_multiplier in dcp_multipliers:
|
||||
for chunked_prefill_val in [True]:
|
||||
parallel_setups.append(
|
||||
ParallelSetup(
|
||||
@ -82,7 +103,6 @@ class CPTestSettings:
|
||||
runner=runner,
|
||||
test_options=CPTestOptions(
|
||||
multi_node_only=multi_node_only,
|
||||
load_format=load_format,
|
||||
attn_backend=attn_backend,
|
||||
),
|
||||
)
|
||||
@ -101,7 +121,24 @@ class CPTestSettings:
|
||||
)
|
||||
|
||||
|
||||
def _compare_cp_with_tp(
|
||||
CP_TEXT_GENERATION_MODELS = {
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
|
||||
CPTestSettings.detailed(
|
||||
dcp_multipliers=[0.5, 1], cp_kv_cache_interleave_size=64
|
||||
),
|
||||
],
|
||||
"Qwen/Qwen2.5-1.5B-Instruct": [
|
||||
CPTestSettings.detailed(
|
||||
cp_kv_cache_interleave_size=16, attn_backend="FLASH_ATTN"
|
||||
),
|
||||
CPTestSettings.detailed(
|
||||
cp_kv_cache_interleave_size=16, attn_backend="FLASHINFER"
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _test_cp_gsm8k(
|
||||
model_id: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
@ -121,7 +158,7 @@ def _compare_cp_with_tp(
|
||||
chunked_prefill,
|
||||
) = parallel_setup
|
||||
|
||||
multi_node_only, load_format, attn_backend = test_options
|
||||
multi_node_only, attn_backend = test_options
|
||||
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
@ -130,21 +167,6 @@ def _compare_cp_with_tp(
|
||||
tokenizer_mode = model_info.tokenizer_mode
|
||||
hf_overrides = model_info.hf_overrides
|
||||
|
||||
if load_format == "dummy":
|
||||
# Avoid OOM
|
||||
text_overrides = {
|
||||
"num_hidden_layers": 4,
|
||||
"hidden_size": 512,
|
||||
"intermediate_size": 800,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 1,
|
||||
}
|
||||
|
||||
if is_multimodal:
|
||||
hf_overrides.update({"text_config": text_overrides})
|
||||
else:
|
||||
hf_overrides.update(text_overrides)
|
||||
else:
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
|
||||
if num_gpus_available < tp_size * pp_size:
|
||||
@ -157,39 +179,30 @@ def _compare_cp_with_tp(
|
||||
if multi_node_only and not VLLM_MULTI_NODE:
|
||||
pytest.skip("Not in multi-node setting")
|
||||
|
||||
common_args = [
|
||||
server_args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"4096",
|
||||
"--max-num-seqs",
|
||||
"8",
|
||||
"64",
|
||||
]
|
||||
if chunked_prefill:
|
||||
common_args.append("--enable-chunked-prefill")
|
||||
server_args.append("--enable-chunked-prefill")
|
||||
if eager_mode:
|
||||
common_args.append("--enforce-eager")
|
||||
server_args.append("--enforce-eager")
|
||||
if runner != "auto":
|
||||
common_args.extend(["--runner", runner])
|
||||
server_args.extend(["--runner", runner])
|
||||
if trust_remote_code:
|
||||
common_args.append("--trust-remote-code")
|
||||
server_args.append("--trust-remote-code")
|
||||
if tokenizer_mode:
|
||||
common_args.extend(["--tokenizer-mode", tokenizer_mode])
|
||||
if load_format:
|
||||
common_args.extend(["--load-format", load_format])
|
||||
server_args.extend(["--tokenizer-mode", tokenizer_mode])
|
||||
if hf_overrides:
|
||||
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
|
||||
server_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
|
||||
|
||||
if not attn_backend:
|
||||
cp_env = tp_env = {}
|
||||
else:
|
||||
cp_env = tp_env = {
|
||||
"VLLM_ATTENTION_BACKEND": attn_backend,
|
||||
}
|
||||
|
||||
cp_args = [
|
||||
*common_args,
|
||||
server_args.extend(
|
||||
[
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--pipeline-parallel-size",
|
||||
@ -201,46 +214,35 @@ def _compare_cp_with_tp(
|
||||
"--distributed-executor-backend",
|
||||
distributed_backend,
|
||||
]
|
||||
|
||||
tp_args = [
|
||||
*common_args,
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--pipeline-parallel-size",
|
||||
str(pp_size),
|
||||
"--distributed-executor-backend",
|
||||
distributed_backend,
|
||||
]
|
||||
|
||||
compare_two_settings(
|
||||
model_id,
|
||||
cp_args,
|
||||
tp_args,
|
||||
cp_env,
|
||||
tp_env,
|
||||
method=method,
|
||||
max_wait_seconds=720,
|
||||
)
|
||||
|
||||
server_env = {}
|
||||
if attn_backend:
|
||||
server_env["VLLM_ATTENTION_BACKEND"] = attn_backend
|
||||
|
||||
CP_TEXT_GENERATION_MODELS = {
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
|
||||
CPTestSettings.detailed(),
|
||||
CPTestSettings.detailed(tp_base=2),
|
||||
CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
|
||||
],
|
||||
"bigcode/gpt_bigcode-santacoder": [
|
||||
CPTestSettings.detailed(),
|
||||
CPTestSettings.detailed(tp_base=2),
|
||||
],
|
||||
}
|
||||
with RemoteOpenAIServer(
|
||||
model_id,
|
||||
server_args,
|
||||
env_dict=server_env,
|
||||
max_wait_seconds=720,
|
||||
) as remote_server:
|
||||
host = f"http://{remote_server.host}"
|
||||
port = remote_server.port
|
||||
|
||||
CP_TEST_MODELS = [
|
||||
# TODO support other models
|
||||
# [LANGUAGE GENERATION]
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
"bigcode/gpt_bigcode-santacoder",
|
||||
]
|
||||
# Run GSM8K evaluation
|
||||
results = evaluate_gsm8k(
|
||||
num_questions=NUM_QUESTIONS,
|
||||
num_shots=NUM_SHOTS,
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
|
||||
# Validate accuracy is reasonable
|
||||
accuracy = results["accuracy"]
|
||||
min_accuracy = MIN_ACCURACY[model_id]
|
||||
assert accuracy >= min_accuracy, (
|
||||
f"TP+DCP accuracy too low: {accuracy:.3f} < {min_accuracy:.3f}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -274,12 +276,12 @@ def test_cp_generation(
|
||||
):
|
||||
pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
|
||||
if (
|
||||
model_id == "bigcode/gpt_bigcode-santacoder"
|
||||
model_id == "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
and torch.cuda.get_device_capability() != (9, 0)
|
||||
):
|
||||
pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")
|
||||
|
||||
_compare_cp_with_tp(
|
||||
_test_cp_gsm8k(
|
||||
model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
|
||||
@ -416,7 +416,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True,
|
||||
),
|
||||
"Qwen2ForCausalLM": _HfExamplesInfo(
|
||||
"Qwen/Qwen2-0.5B-Instruct", extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}
|
||||
"Qwen/Qwen2-0.5B-Instruct",
|
||||
extras={
|
||||
"2.5": "Qwen/Qwen2.5-0.5B-Instruct",
|
||||
"2.5-1.5B": "Qwen/Qwen2.5-1.5B-Instruct",
|
||||
},
|
||||
),
|
||||
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
|
||||
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user