diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index ab325e096692..8eb5ca9461c7 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -14,36 +14,29 @@ from ..utils import compare_two_settings, fork_new_process_for_each_test
 
 VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
 
 
-@pytest.mark.parametrize(
-    ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
-     "MODEL_NAME, DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL"), [
-        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
-        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
-        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
-        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
-        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
-        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
-        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
-        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
-        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
-        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
-        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
-        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
-        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
-        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
-        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
-        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
-        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
-        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
-        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
-        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
-    ])
+@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
+                          "MODEL_NAME, DIST_BACKEND"),
+                         [
+                             (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
+                             (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+                             (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+                             (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
+                             (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+                             (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
+                             (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+                             (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+                             (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
+                             (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+                         ])
 def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
-                    DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL):
+                    DIST_BACKEND):
     if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")
+    USE_RAY_ADAG_NCCL = 0
+    USE_RAY_ADAG = 0
+
     pp_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",