diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index f5ee469fb00a..31aa89828200 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -7,10 +7,10 @@ from vllm import LLM, SamplingParams from vllm.device_allocator.cumem import CuMemAllocator from vllm.utils import GiB_bytes -from ..utils import fork_new_process_for_each_test +from ..utils import create_new_process_for_each_test -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_python_error(): """ Test if Python error occurs when there's low-level @@ -36,7 +36,7 @@ def test_python_error(): allocator.wake_up() -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_basic_cumem(): # some tensors from default memory pool shape = (1024, 1024) @@ -69,7 +69,7 @@ def test_basic_cumem(): assert torch.allclose(output, torch.ones_like(output) * 3) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_cumem_with_cudagraph(): allocator = CuMemAllocator.get_instance() with allocator.use_memory_pool(): @@ -114,7 +114,7 @@ def test_cumem_with_cudagraph(): assert torch.allclose(y, x + 1) -@fork_new_process_for_each_test +@create_new_process_for_each_test() @pytest.mark.parametrize( "model, use_v1", [ diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index cf463f3e7525..3a45c35442ca 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -12,7 +12,7 @@ from vllm import LLM, SamplingParams from vllm.config import CompilationLevel from vllm.platforms import current_platform -from ..utils import fork_new_process_for_each_test +from ..utils import create_new_process_for_each_test @pytest.fixture(params=None, name="model_info") @@ -78,7 +78,7 @@ def models_list_fixture(request): [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE], ) @pytest.mark.parametrize("model_info", "", indirect=True) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_full_graph( monkeypatch: pytest.MonkeyPatch, model_info: tuple[str, dict[str, Any]], diff --git a/tests/distributed/test_expert_parallel.py b/tests/distributed/test_expert_parallel.py index 2e575f95d5f1..db8281617803 100644 --- a/tests/distributed/test_expert_parallel.py +++ b/tests/distributed/test_expert_parallel.py @@ -8,7 +8,7 @@ import pytest from vllm.config import TaskOption from vllm.logger import init_logger -from ..utils import compare_two_settings, fork_new_process_for_each_test +from ..utils import compare_two_settings, create_new_process_for_each_test logger = init_logger("test_expert_parallel") @@ -209,7 +209,7 @@ def _compare_tp( for params in settings.iter_params(model_name) ], ) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_ep( model_name: str, parallel_setup: ParallelSetup, diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 4d3306509c8f..1342f0da29d8 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -17,7 +17,7 @@ from vllm.config import TaskOption from vllm.logger import init_logger from ..models.registry import HF_EXAMPLE_MODELS -from ..utils import compare_two_settings, fork_new_process_for_each_test +from ..utils import compare_two_settings, create_new_process_for_each_test logger = init_logger("test_pipeline_parallel") @@ -402,7 +402,7 @@ def _compare_tp( for params in settings.iter_params(model_id) 
if model_id in TEST_MODELS ], ) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_tp_language_generation( model_id: str, parallel_setup: ParallelSetup, @@ -431,7 +431,7 @@ def test_tp_language_generation( for params in settings.iter_params(model_id) if model_id in TEST_MODELS ], ) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_tp_language_embedding( model_id: str, parallel_setup: ParallelSetup, @@ -460,7 +460,7 @@ def test_tp_language_embedding( for params in settings.iter_params(model_id) if model_id in TEST_MODELS ], ) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_tp_multimodal_generation( model_id: str, parallel_setup: ParallelSetup, diff --git a/tests/distributed/test_pp_cudagraph.py b/tests/distributed/test_pp_cudagraph.py index 19414971f2b4..3ca6e7b33a5e 100644 --- a/tests/distributed/test_pp_cudagraph.py +++ b/tests/distributed/test_pp_cudagraph.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING import pytest -from ..utils import compare_two_settings, fork_new_process_for_each_test +from ..utils import compare_two_settings, create_new_process_for_each_test if TYPE_CHECKING: from typing_extensions import LiteralString @@ -18,7 +18,7 @@ if TYPE_CHECKING: "FLASH_ATTN", "FLASHINFER", ]) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_pp_cudagraph( monkeypatch: pytest.MonkeyPatch, PP_SIZE: int, diff --git a/tests/entrypoints/llm/test_collective_rpc.py b/tests/entrypoints/llm/test_collective_rpc.py index 39d4810de9e7..64c473c4c538 100644 --- a/tests/entrypoints/llm/test_collective_rpc.py +++ b/tests/entrypoints/llm/test_collective_rpc.py @@ -4,12 +4,12 @@ import pytest from vllm import LLM -from ...utils import fork_new_process_for_each_test +from ...utils import create_new_process_for_each_test @pytest.mark.parametrize("tp_size", [1, 2]) @pytest.mark.parametrize("backend", ["mp", "ray"]) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_collective_rpc(tp_size, backend): if tp_size == 1 and backend == "ray": pytest.skip("Skip duplicate test case") diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index 6bc9bf788761..fa8c66d10309 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -3,10 +3,9 @@ import pytest import vllm -from tests.utils import fork_new_process_for_each_test from vllm.lora.request import LoRARequest -from ..utils import multi_gpu_test +from ..utils import create_new_process_for_each_test, multi_gpu_test MODEL_PATH = "THUDM/chatglm3-6b" @@ -55,7 +54,7 @@ def v1(run_with_both_engines_lora): pass -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_chatglm3_lora(chatglm3_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, @@ -75,7 +74,7 @@ def test_chatglm3_lora(chatglm3_lora_files): @multi_gpu_test(num_gpus=4) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_chatglm3_lora_tp4(chatglm3_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, @@ -96,7 +95,7 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files): @multi_gpu_test(num_gpus=4) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index d497ae6b2bc1..0acdaeac6952 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ 
-4,10 +4,9 @@ import pytest import ray import vllm -from tests.utils import fork_new_process_for_each_test from vllm.lora.request import LoRARequest -from ..utils import multi_gpu_test +from ..utils import create_new_process_for_each_test, multi_gpu_test MODEL_PATH = "meta-llama/Llama-2-7b-hf" @@ -82,7 +81,7 @@ def v1(run_with_both_engines_lora): # V1 Test: Failing due to numerics on V1. @pytest.mark.skip_v1 -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_llama_lora(sql_lora_files): llm = vllm.LLM(MODEL_PATH, @@ -97,7 +96,7 @@ def test_llama_lora(sql_lora_files): # Skipping for v1 as v1 doesn't have a good way to expose the num_gpu_blocks # used by the engine yet. @pytest.mark.skip_v1 -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_llama_lora_warmup(sql_lora_files): """Test that the LLM initialization works with a warmup LORA path and is more conservative""" @@ -128,7 +127,7 @@ def test_llama_lora_warmup(sql_lora_files): # V1 Test: Failing due to numerics on V1. @pytest.mark.skip_v1 @multi_gpu_test(num_gpus=4) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_llama_lora_tp4(sql_lora_files): llm = vllm.LLM( @@ -143,7 +142,7 @@ def test_llama_lora_tp4(sql_lora_files): @multi_gpu_test(num_gpus=4) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): llm = vllm.LLM( @@ -159,7 +158,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): @multi_gpu_test(num_gpus=4) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files): llm = vllm.LLM( diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py index f596651be01e..ee0d7b5da3a9 100644 --- a/tests/lora/test_minicpmv_tp.py +++ b/tests/lora/test_minicpmv_tp.py @@ -3,11 +3,12 @@ import pytest import vllm -from tests.utils import fork_new_process_for_each_test from vllm.assets.image import ImageAsset from vllm.lora.request import LoRARequest from vllm.platforms import current_platform +from ..utils import create_new_process_for_each_test + MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" PROMPT_TEMPLATE = ( @@ -57,7 +58,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: @pytest.mark.xfail( current_platform.is_rocm(), reason="MiniCPM-V dependency xformers incompatible with ROCm") -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_minicpmv_lora(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, @@ -80,7 +81,7 @@ def test_minicpmv_lora(minicpmv_lora_files): @pytest.mark.xfail( current_platform.is_rocm(), reason="MiniCPM-V dependency xformers incompatible with ROCm") -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, @@ -101,7 +102,7 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): @pytest.mark.xfail( current_platform.is_rocm(), reason="MiniCPM-V dependency xformers incompatible with ROCm") -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, diff --git a/tests/lora/test_transfomers_model.py b/tests/lora/test_transfomers_model.py index ff3bfcac5053..f65fb1cdbbd5 100644 --- a/tests/lora/test_transfomers_model.py +++ b/tests/lora/test_transfomers_model.py @@ -3,10 +3,9 @@ import pytest import vllm 
-from tests.utils import fork_new_process_for_each_test from vllm.lora.request import LoRARequest -from ..utils import multi_gpu_test +from ..utils import create_new_process_for_each_test, multi_gpu_test MODEL_PATH = "ArthurZ/ilama-3.2-1B" @@ -56,7 +55,7 @@ def v1(run_with_both_engines_lora): @pytest.mark.skip_v1 -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_ilama_lora(ilama_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, @@ -77,7 +76,7 @@ def test_ilama_lora(ilama_lora_files): @pytest.mark.skip_v1 @multi_gpu_test(num_gpus=4) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_ilama_lora_tp4(ilama_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, @@ -99,7 +98,7 @@ def test_ilama_lora_tp4(ilama_lora_files): @pytest.mark.skip_v1 @multi_gpu_test(num_gpus=4) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_ilama_lora_tp4_fully_sharded_loras(ilama_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 7cdd037d49ac..92fb2404d8a2 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -17,7 +17,7 @@ from vllm.utils import identity from ....conftest import (IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets, _VideoAssets) -from ....utils import (fork_new_process_for_each_test, large_gpu_mark, +from ....utils import (create_new_process_for_each_test, large_gpu_mark, multi_gpu_marks) from ...utils import check_outputs_equal from .vlm_utils import custom_inputs, model_utils, runners @@ -592,7 +592,7 @@ VLM_TEST_SETTINGS = _mark_splits(VLM_TEST_SETTINGS, num_groups=2) get_parametrized_options( VLM_TEST_SETTINGS, test_type=VLMTestType.IMAGE, - fork_new_process_for_each_test=False, + create_new_process_for_each_test=False, )) def test_single_image_models(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, @@ -617,7 +617,7 @@ def test_single_image_models(tmp_path: PosixPath, model_type: str, get_parametrized_options( VLM_TEST_SETTINGS, test_type=VLMTestType.MULTI_IMAGE, - fork_new_process_for_each_test=False, + create_new_process_for_each_test=False, )) def test_multi_image_models(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, @@ -642,7 +642,7 @@ def test_multi_image_models(tmp_path: PosixPath, model_type: str, get_parametrized_options( VLM_TEST_SETTINGS, test_type=VLMTestType.EMBEDDING, - fork_new_process_for_each_test=False, + create_new_process_for_each_test=False, )) def test_image_embedding_models(model_type: str, test_case: ExpandableVLMTestArgs, @@ -666,7 +666,7 @@ def test_image_embedding_models(model_type: str, get_parametrized_options( VLM_TEST_SETTINGS, test_type=VLMTestType.VIDEO, - fork_new_process_for_each_test=False, + create_new_process_for_each_test=False, )) def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], @@ -688,7 +688,7 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, get_parametrized_options( VLM_TEST_SETTINGS, test_type=VLMTestType.CUSTOM_INPUTS, - fork_new_process_for_each_test=False, + create_new_process_for_each_test=False, )) def test_custom_inputs_models( model_type: str, @@ -714,9 +714,9 @@ def test_custom_inputs_models( get_parametrized_options( VLM_TEST_SETTINGS, test_type=VLMTestType.IMAGE, - 
fork_new_process_for_each_test=True, + create_new_process_for_each_test=True, )) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], @@ -740,9 +740,9 @@ def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str, get_parametrized_options( VLM_TEST_SETTINGS, test_type=VLMTestType.MULTI_IMAGE, - fork_new_process_for_each_test=True, + create_new_process_for_each_test=True, )) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], @@ -766,9 +766,9 @@ def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str, get_parametrized_options( VLM_TEST_SETTINGS, test_type=VLMTestType.EMBEDDING, - fork_new_process_for_each_test=True, + create_new_process_for_each_test=True, )) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_image_embedding_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], @@ -791,7 +791,7 @@ def test_image_embedding_models_heavy(model_type: str, get_parametrized_options( VLM_TEST_SETTINGS, test_type=VLMTestType.VIDEO, - fork_new_process_for_each_test=True, + create_new_process_for_each_test=True, )) def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], @@ -814,9 +814,9 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, get_parametrized_options( VLM_TEST_SETTINGS, test_type=VLMTestType.CUSTOM_INPUTS, - fork_new_process_for_each_test=True, + create_new_process_for_each_test=True, )) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_custom_inputs_models_heavy( model_type: str, test_case: ExpandableVLMTestArgs, diff --git a/tests/models/decoder_only/vision_language/vlm_utils/case_filtering.py b/tests/models/decoder_only/vision_language/vlm_utils/case_filtering.py index c189e5a761fc..8e825676b8f4 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/case_filtering.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/case_filtering.py @@ -13,9 +13,9 @@ from .types import (EMBEDDING_SIZE_FACTORS, ExpandableVLMTestArgs, ImageSizeWrapper, SizeType, VLMTestInfo, VLMTestType) -def get_filtered_test_settings(test_settings: dict[str, VLMTestInfo], - test_type: VLMTestType, - fork_per_test: bool) -> dict[str, VLMTestInfo]: +def get_filtered_test_settings( + test_settings: dict[str, VLMTestInfo], test_type: VLMTestType, + new_proc_per_test: bool) -> dict[str, VLMTestInfo]: """Given the dict of potential test settings to run, return a subdict of tests who have the current test type enabled with the matching val for fork_per_test. @@ -43,7 +43,7 @@ def get_filtered_test_settings(test_settings: dict[str, VLMTestInfo], # Everything looks okay; keep if this is has correct proc handling if (test_info.distributed_executor_backend - is not None) == fork_per_test: + is not None) == new_proc_per_test: matching_tests[test_name] = test_info return matching_tests @@ -51,14 +51,14 @@ def get_filtered_test_settings(test_settings: dict[str, VLMTestInfo], def get_parametrized_options(test_settings: dict[str, VLMTestInfo], test_type: VLMTestType, - fork_new_process_for_each_test: bool): + create_new_process_for_each_test: bool): """Converts all of our VLMTestInfo into an expanded list of parameters. 
This is similar to nesting pytest parametrize calls, but done directly through an itertools product so that each test can set things like size factors etc, while still running in isolated test cases. """ matching_tests = get_filtered_test_settings( - test_settings, test_type, fork_new_process_for_each_test) + test_settings, test_type, create_new_process_for_each_test) # Ensure that something is wrapped as an iterable it's not already ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e, ) diff --git a/tests/models/encoder_decoder/audio_language/test_whisper.py b/tests/models/encoder_decoder/audio_language/test_whisper.py index 80d6897da7e0..7897bf113d35 100644 --- a/tests/models/encoder_decoder/audio_language/test_whisper.py +++ b/tests/models/encoder_decoder/audio_language/test_whisper.py @@ -10,7 +10,7 @@ import pytest from vllm import LLM, SamplingParams from vllm.assets.audio import AudioAsset -from ....utils import fork_new_process_for_each_test, multi_gpu_test +from ....utils import create_new_process_for_each_test, multi_gpu_test PROMPTS = [ { @@ -119,7 +119,7 @@ def run_test( assert output.outputs[0].text == expected -@fork_new_process_for_each_test +@create_new_process_for_each_test() @pytest.mark.core_model @pytest.mark.parametrize( "model", ["openai/whisper-small", "openai/whisper-large-v3-turbo"]) diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index 465c496f4c0f..e6141b97b10d 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -5,10 +5,10 @@ import pytest from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset -from ..utils import fork_new_process_for_each_test +from ..utils import create_new_process_for_each_test -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_plugin( monkeypatch: pytest.MonkeyPatch, dummy_opt_path: str, @@ -24,7 +24,7 @@ def test_plugin( assert (error_msg in str(excinfo.value)) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_oot_registration_text_generation( monkeypatch: pytest.MonkeyPatch, dummy_opt_path: str, @@ -44,7 +44,7 @@ def test_oot_registration_text_generation( assert rest == "" -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_oot_registration_embedding( monkeypatch: pytest.MonkeyPatch, dummy_gemma2_embedding_path: str, @@ -62,7 +62,7 @@ def test_oot_registration_embedding( image = ImageAsset("cherry_blossom").pil_image.convert("RGB") -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_oot_registration_multimodal( monkeypatch: pytest.MonkeyPatch, dummy_llava_path: str, diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 80d3f78f9f31..3282284b6b27 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -17,7 +17,7 @@ from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS, ModelRegistry) from vllm.platforms import current_platform -from ..utils import fork_new_process_for_each_test +from ..utils import create_new_process_for_each_test from .registry import HF_EXAMPLE_MODELS @@ -45,7 +45,7 @@ def test_registry_imports(model_arch): assert supports_multimodal(model_cls) -@fork_new_process_for_each_test +@create_new_process_for_each_test() @pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [ ("LlamaForCausalLM", False, False, False), ("MllamaForConditionalGeneration", True, False, False), @@ -70,7 +70,7 @@ def 
test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): stacklevel=2) -@fork_new_process_for_each_test +@create_new_process_for_each_test() @pytest.mark.parametrize("model_arch,is_pp,init_cuda", [ ("MLPSpeculatorPreTrainedModel", False, False), ("DeepseekV2ForCausalLM", True, False), diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 4b5210cdf074..d6844b8dc5f2 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -10,7 +10,8 @@ import pytest import torch from tests.quantization.utils import is_quant_method_supported -from tests.utils import compare_two_settings, fork_new_process_for_each_test + +from ..utils import compare_two_settings, create_new_process_for_each_test models_4bit_to_test = [ ("facebook/opt-125m", "quantize opt model inflight"), @@ -32,7 +33,7 @@ models_pre_quant_8bit_to_test = [ @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_4bit_to_test) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: @@ -45,7 +46,7 @@ def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_pre_qaunt_4bit_to_test) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: @@ -57,7 +58,7 @@ def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_pre_quant_8bit_to_test) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: @@ -70,7 +71,7 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_4bit_to_test) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: @@ -88,7 +89,7 @@ def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_4bit_to_test) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_load_pp_4bit_bnb_model(model_name, description) -> None: common_args = [ "--disable-log-stats", diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index d396e52a9ddc..56acf664ab57 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -42,7 +42,7 @@ from transformers import AutoTokenizer from vllm import SamplingParams -from ...utils import fork_new_process_for_each_test +from ...utils import create_new_process_for_each_test from .conftest import 
(get_output_from_llm_generator, run_equality_correctness_test) @@ -82,7 +82,7 @@ from .conftest import (get_output_from_llm_generator, @pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_spec_decode_e2e_with_detokenization(test_llm_generator, batch_size: int): """Run generation with speculative decoding on a batch. Verify the engine @@ -170,7 +170,7 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, ]) @pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize("seed", [1]) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, @@ -244,7 +244,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( ]) @pytest.mark.parametrize("batch_size", [64]) @pytest.mark.parametrize("seed", [1]) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, @@ -300,7 +300,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( ]) @pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize("seed", [1]) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, @@ -356,7 +356,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( 256, ]) @pytest.mark.parametrize("seed", [1]) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_spec_decode_e2e_greedy_correctness_real_model_bs1( vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, @@ -411,7 +411,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1( 64, ]) @pytest.mark.parametrize("seed", [1]) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, @@ -469,7 +469,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( ]) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_spec_decode_e2e_greedy_correctness_with_preemption( vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, @@ -534,7 +534,7 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption( 32, ]) @pytest.mark.parametrize("seed", [1]) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, @@ -594,7 +594,7 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs, 64, ]) @pytest.mark.parametrize("seed", [1]) -@fork_new_process_for_each_test 
+@create_new_process_for_each_test() def test_skip_speculation(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, @@ -644,7 +644,7 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("output_len", [10]) @pytest.mark.parametrize("seed", [1]) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_disable_speculation(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, @@ -697,7 +697,7 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs, 32, ]) @pytest.mark.parametrize("seed", [1]) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, seed: int): @@ -752,7 +752,7 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, 32, ]) @pytest.mark.parametrize("seed", [1]) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_typical_acceptance_sampling(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, diff --git a/tests/test_utils.py b/tests/test_utils.py index ae4fddd046d4..3660cfa0e49e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -16,7 +16,7 @@ from vllm.utils import (FlexibleArgumentParser, MemorySnapshot, deprecate_kwargs, get_open_port, memory_profiling, merge_async_iterators, supports_kw, swap_dict_values) -from .utils import error_on_warning, fork_new_process_for_each_test +from .utils import create_new_process_for_each_test, error_on_warning @pytest.mark.asyncio @@ -276,7 +276,7 @@ def test_supports_kw(callable,kw_name,requires_kw_only, ) == is_supported -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_memory_profiling(): # Fake out some model loading + inference memory usage to test profiling # Memory used by other processes will show up as cuda usage outside of torch diff --git a/tests/utils.py b/tests/utils.py index 06ba8a2421c1..627cf567afcc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,12 +7,14 @@ import os import signal import subprocess import sys +import tempfile import time import warnings -from contextlib import contextmanager +from contextlib import contextmanager, suppress from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Literal, Optional, Union +import cloudpickle import openai import pytest import requests @@ -703,6 +705,78 @@ def fork_new_process_for_each_test( return wrapper +def spawn_new_process_for_each_test( + f: Callable[_P, None]) -> Callable[_P, None]: + """Decorator to spawn a new process for each test function. 
+ """ + + @functools.wraps(f) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: + # Check if we're already in a subprocess + if os.environ.get('RUNNING_IN_SUBPROCESS') == '1': + # If we are, just run the function directly + return f(*args, **kwargs) + + import torch.multiprocessing as mp + with suppress(RuntimeError): + mp.set_start_method('spawn') + + # Get the module + module_name = f.__module__ + + # Create a process with environment variable set + env = os.environ.copy() + env['RUNNING_IN_SUBPROCESS'] = '1' + + with tempfile.TemporaryDirectory() as tempdir: + output_filepath = os.path.join(tempdir, "new_process.tmp") + + # `cloudpickle` allows pickling complex functions directly + input_bytes = cloudpickle.dumps((f, output_filepath)) + + cmd = [sys.executable, "-m", f"{module_name}"] + + returned = subprocess.run(cmd, + input=input_bytes, + capture_output=True, + env=env) + + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError(f"Error raised in subprocess:\n" + f"{returned.stderr.decode()}") from e + + return wrapper + + +def create_new_process_for_each_test( + method: Optional[Literal["spawn", "fork"]] = None +) -> Callable[[Callable[_P, None]], Callable[_P, None]]: + """Creates a decorator that runs each test function in a new process. + + Args: + method: The process creation method. Can be either "spawn" or "fork". + If not specified, + it defaults to "spawn" on ROCm platforms and "fork" otherwise. + + Returns: + A decorator to run test functions in separate processes. + """ + if method is None: + method = "spawn" if current_platform.is_rocm() else "fork" + + assert method in ["spawn", + "fork"], "Method must be either 'spawn' or 'fork'" + + if method == "fork": + return fork_new_process_for_each_test + + return spawn_new_process_for_each_test + + def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator: """ Get a pytest mark, which skips the test if the GPU doesn't meet @@ -762,7 +836,7 @@ def multi_gpu_test(*, num_gpus: int): marks = multi_gpu_marks(num_gpus=num_gpus) def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: - func = fork_new_process_for_each_test(f) + func = create_new_process_for_each_test()(f) for mark in reversed(marks): func = mark(func) diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 2ec4f7e034af..afbe15b9d46e 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -9,7 +9,6 @@ from concurrent.futures import Future import pytest from transformers import AutoTokenizer -from tests.utils import fork_new_process_for_each_test from vllm import SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform @@ -19,6 +18,8 @@ from vllm.v1.executor.abstract import Executor, UniProcExecutor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import ModelRunnerOutput +from ...utils import create_new_process_for_each_test + if not current_platform.is_cuda(): pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) @@ -44,7 +45,7 @@ def make_request() -> EngineCoreRequest: ) -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_engine_core(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: @@ -158,7 +159,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch): assert len(engine_core.scheduler.running) == 0 
-@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch): """ A basic end-to-end test to verify that the engine functions correctly @@ -208,7 +209,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch): _check_engine_state() -@fork_new_process_for_each_test +@create_new_process_for_each_test() def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): """ Test that the engine can handle multiple concurrent batches. diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 004b4dc82f4d..48f451a58968 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -8,7 +8,6 @@ from typing import Optional import pytest from transformers import AutoTokenizer -from tests.utils import fork_new_process_for_each_test from vllm import SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform @@ -19,6 +18,8 @@ from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient, SyncMPClient) from vllm.v1.executor.abstract import Executor +from ...utils import create_new_process_for_each_test + if not current_platform.is_cuda(): pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) @@ -88,7 +89,7 @@ def echo(self, msg: str, err_msg: Optional[str] = None) -> str: return msg -@fork_new_process_for_each_test +@create_new_process_for_each_test() @pytest.mark.parametrize("multiprocessing_mode", [True, False]) def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, multiprocessing_mode: bool):
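# --- Illustrative usage sketch; an editor's addition, not part of the patch ---
# The hunks above swap the bare `@fork_new_process_for_each_test` decorator for
# the `create_new_process_for_each_test()` factory introduced in tests/utils.py,
# which picks "spawn" on ROCm and "fork" on other platforms when no method is
# given. A minimal sketch of how a test opts into either start method, assuming
# the tests/utils.py helpers added above are importable as `tests.utils`; the
# test names below are hypothetical.
from tests.utils import create_new_process_for_each_test


@create_new_process_for_each_test()  # platform default: "spawn" on ROCm, else "fork"
def test_runs_in_fresh_process():
    ...


@create_new_process_for_each_test("fork")  # explicitly request the fork start method
def test_runs_in_forked_process():
    ...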