diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py
index 7e4713b8aece0..3cb533dccd62c 100644
--- a/tests/distributed/test_context_parallel.py
+++ b/tests/distributed/test_context_parallel.py
@@ -16,16 +16,35 @@ from typing import Literal, NamedTuple
 import pytest
 import torch
 
+from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
+from tests.utils import RemoteOpenAIServer, create_new_process_for_each_test
 from vllm.config.model import RunnerOption
 from vllm.logger import init_logger
 
 from ..models.registry import HF_EXAMPLE_MODELS
-from ..utils import compare_two_settings, create_new_process_for_each_test
 
 logger = init_logger("test_context_parallel")
 
 VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
 
+CP_TEST_MODELS = [
+    # TODO support other models
+    # [LANGUAGE GENERATION]
+    "deepseek-ai/DeepSeek-V2-Lite-Chat",
+    "Qwen/Qwen2.5-1.5B-Instruct",
+]
+
+# GSM8K eval configuration
+NUM_QUESTIONS = 256  # Fast eval for CI
+NUM_SHOTS = 5  # Few-shot examples
+# tp accuracy with 2% buffer
+MIN_ACCURACY = {
+    # .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
+    "deepseek-ai/DeepSeek-V2-Lite-Chat": 0.64,
+    # .buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml
+    "Qwen/Qwen2.5-1.5B-Instruct": 0.52,
+}
+
 
 class ParallelSetup(NamedTuple):
     tp_size: int
@@ -38,7 +57,6 @@ class ParallelSetup(NamedTuple):
 
 class CPTestOptions(NamedTuple):
     multi_node_only: bool
-    load_format: str | None = None
     attn_backend: str | None = None
 
 
@@ -54,17 +72,20 @@ class CPTestSettings:
         *,
         tp_base: int = 4,
         pp_base: int = 1,
-        dcp_base: int = 1,
+        dcp_multipliers: list[float] | None = None,
         cp_kv_cache_interleave_size: int = 1,
         multi_node_only: bool = False,
         runner: RunnerOption = "auto",
-        load_format: str | None = None,
         attn_backend: str | None = None,
     ):
         parallel_setups = []
+        if dcp_multipliers is None:
+            dcp_multipliers = [
+                0.5,
+            ]
         for eager_mode_val in [False]:
             for pp_multiplier in [1]:
-                for dcp_multiplier in [0.5, 1]:
+                for dcp_multiplier in dcp_multipliers:
                     for chunked_prefill_val in [True]:
                         parallel_setups.append(
                             ParallelSetup(
@@ -82,7 +103,6 @@ class CPTestSettings:
             runner=runner,
             test_options=CPTestOptions(
                 multi_node_only=multi_node_only,
-                load_format=load_format,
                 attn_backend=attn_backend,
             ),
         )
@@ -101,7 +121,24 @@ class CPTestSettings:
                 )
 
 
-def _compare_cp_with_tp(
+CP_TEXT_GENERATION_MODELS = {
+    "deepseek-ai/DeepSeek-V2-Lite-Chat": [
+        CPTestSettings.detailed(
+            dcp_multipliers=[0.5, 1], cp_kv_cache_interleave_size=64
+        ),
+    ],
+    "Qwen/Qwen2.5-1.5B-Instruct": [
+        CPTestSettings.detailed(
+            cp_kv_cache_interleave_size=16, attn_backend="FLASH_ATTN"
+        ),
+        CPTestSettings.detailed(
+            cp_kv_cache_interleave_size=16, attn_backend="FLASHINFER"
+        ),
+    ],
+}
+
+
+def _test_cp_gsm8k(
     model_id: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
@@ -121,7 +158,7 @@ def _compare_cp_with_tp(
         chunked_prefill,
     ) = parallel_setup
 
-    multi_node_only, load_format, attn_backend = test_options
+    multi_node_only, attn_backend = test_options
 
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
     model_info.check_transformers_version(on_fail="skip")
@@ -130,22 +167,7 @@ def _compare_cp_with_tp(
     tokenizer_mode = model_info.tokenizer_mode
     hf_overrides = model_info.hf_overrides
 
-    if load_format == "dummy":
-        # Avoid OOM
-        text_overrides = {
-            "num_hidden_layers": 4,
-            "hidden_size": 512,
-            "intermediate_size": 800,
-            "num_attention_heads": 4,
-            "num_key_value_heads": 1,
-        }
-
-        if is_multimodal:
-            hf_overrides.update({"text_config": text_overrides})
-        else:
-            hf_overrides.update(text_overrides)
-    else:
-        model_info.check_available_online(on_fail="skip")
+    model_info.check_available_online(on_fail="skip")
 
     if num_gpus_available < tp_size * pp_size:
         pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
@@ -157,90 +179,70 @@ def _compare_cp_with_tp(
     if multi_node_only and not VLLM_MULTI_NODE:
         pytest.skip("Not in multi-node setting")
 
-    common_args = [
+    server_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
         "--max-model-len",
-        "2048",
+        "4096",
         "--max-num-seqs",
-        "8",
+        "64",
     ]
     if chunked_prefill:
-        common_args.append("--enable-chunked-prefill")
+        server_args.append("--enable-chunked-prefill")
     if eager_mode:
-        common_args.append("--enforce-eager")
+        server_args.append("--enforce-eager")
    if runner != "auto":
-        common_args.extend(["--runner", runner])
+        server_args.extend(["--runner", runner])
     if trust_remote_code:
-        common_args.append("--trust-remote-code")
+        server_args.append("--trust-remote-code")
     if tokenizer_mode:
-        common_args.extend(["--tokenizer-mode", tokenizer_mode])
-    if load_format:
-        common_args.extend(["--load-format", load_format])
+        server_args.extend(["--tokenizer-mode", tokenizer_mode])
     if hf_overrides:
-        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
+        server_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
 
-    if not attn_backend:
-        cp_env = tp_env = {}
-    else:
-        cp_env = tp_env = {
-            "VLLM_ATTENTION_BACKEND": attn_backend,
-        }
-
-    cp_args = [
-        *common_args,
-        "--tensor-parallel-size",
-        str(tp_size),
-        "--pipeline-parallel-size",
-        str(pp_size),
-        "--decode-context-parallel-size",
-        str(dcp_size),
-        "--dcp-kv-cache-interleave-size",
-        str(cp_kv_cache_interleave_size),
-        "--distributed-executor-backend",
-        distributed_backend,
-    ]
-
-    tp_args = [
-        *common_args,
-        "--tensor-parallel-size",
-        str(tp_size),
-        "--pipeline-parallel-size",
-        str(pp_size),
-        "--distributed-executor-backend",
-        distributed_backend,
-    ]
-
-    compare_two_settings(
-        model_id,
-        cp_args,
-        tp_args,
-        cp_env,
-        tp_env,
-        method=method,
-        max_wait_seconds=720,
+    server_args.extend(
+        [
+            "--tensor-parallel-size",
+            str(tp_size),
+            "--pipeline-parallel-size",
+            str(pp_size),
+            "--decode-context-parallel-size",
+            str(dcp_size),
+            "--dcp-kv-cache-interleave-size",
+            str(cp_kv_cache_interleave_size),
+            "--distributed-executor-backend",
+            distributed_backend,
+        ]
     )
 
+    server_env = {}
+    if attn_backend:
+        server_env["VLLM_ATTENTION_BACKEND"] = attn_backend
 
-CP_TEXT_GENERATION_MODELS = {
-    "deepseek-ai/DeepSeek-V2-Lite-Chat": [
-        CPTestSettings.detailed(),
-        CPTestSettings.detailed(tp_base=2),
-        CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
-    ],
-    "bigcode/gpt_bigcode-santacoder": [
-        CPTestSettings.detailed(),
-        CPTestSettings.detailed(tp_base=2),
-    ],
-}
+    with RemoteOpenAIServer(
+        model_id,
+        server_args,
+        env_dict=server_env,
+        max_wait_seconds=720,
+    ) as remote_server:
+        host = f"http://{remote_server.host}"
+        port = remote_server.port
 
-CP_TEST_MODELS = [
-    # TODO support other models
-    # [LANGUAGE GENERATION]
-    "deepseek-ai/DeepSeek-V2-Lite-Chat",
-    "bigcode/gpt_bigcode-santacoder",
-]
+        # Run GSM8K evaluation
+        results = evaluate_gsm8k(
+            num_questions=NUM_QUESTIONS,
+            num_shots=NUM_SHOTS,
+            host=host,
+            port=port,
+        )
+
+        # Validate accuracy is reasonable
+        accuracy = results["accuracy"]
+        min_accuracy = MIN_ACCURACY[model_id]
+        assert accuracy >= min_accuracy, (
+            f"TP+DCP accuracy too low: {accuracy:.3f} < {min_accuracy:.3f}"
+        )
 
 
 @pytest.mark.parametrize(
@@ -274,12 +276,12 @@ def test_cp_generation(
     ):
         pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
     if (
-        model_id == "bigcode/gpt_bigcode-santacoder"
+        model_id == "Qwen/Qwen2.5-1.5B-Instruct"
         and torch.cuda.get_device_capability() != (9, 0)
     ):
         pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")
 
-    _compare_cp_with_tp(
+    _test_cp_gsm8k(
         model_id,
         parallel_setup,
         distributed_backend,
diff --git a/tests/models/registry.py b/tests/models/registry.py
index b9f9945eb5fb8..352abdd2da9a0 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -416,7 +416,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
         trust_remote_code=True,
     ),
     "Qwen2ForCausalLM": _HfExamplesInfo(
-        "Qwen/Qwen2-0.5B-Instruct", extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}
+        "Qwen/Qwen2-0.5B-Instruct",
+        extras={
+            "2.5": "Qwen/Qwen2.5-0.5B-Instruct",
+            "2.5-1.5B": "Qwen/Qwen2.5-1.5B-Instruct",
+        },
     ),
     "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
     "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),