diff --git a/tests/v1/distributed/test_async_llm_dp.py b/tests/v1/distributed/test_async_llm_dp.py index 98d6ef7dbf440..60f9017184ea0 100644 --- a/tests/v1/distributed/test_async_llm_dp.py +++ b/tests/v1/distributed/test_async_llm_dp.py @@ -20,13 +20,6 @@ from vllm.v1.metrics.stats import IterationStats, MultiModalCacheStats, Schedule DP_SIZE = int(os.getenv("DP_SIZE", 2)) -engine_args = AsyncEngineArgs( - model="ibm-research/PowerMoE-3b", - enforce_eager=True, - tensor_parallel_size=int(os.getenv("TP_SIZE", 1)), - data_parallel_size=DP_SIZE, -) - async def generate( engine: AsyncLLM, @@ -65,6 +58,13 @@ async def generate( return count, request_id +@pytest.mark.parametrize( + "model", + [ + "ibm-research/PowerMoE-3b", + "hmellor/tiny-random-LlamaForCausalLM", + ], +) @pytest.mark.parametrize( "output_kind", [ @@ -76,7 +76,10 @@ async def generate( @pytest.mark.parametrize("async_scheduling", [True, False]) @pytest.mark.asyncio async def test_load( - output_kind: RequestOutputKind, data_parallel_backend: str, async_scheduling: bool + model: str, + output_kind: RequestOutputKind, + data_parallel_backend: str, + async_scheduling: bool, ): if async_scheduling and data_parallel_backend == "ray": # TODO(NickLucche) Re-enable when async scheduling is supported @@ -107,8 +110,14 @@ async def test_load( with ExitStack() as after: prompt = "This is a test of data parallel" - engine_args.data_parallel_backend = data_parallel_backend - engine_args.async_scheduling = async_scheduling + engine_args = AsyncEngineArgs( + model=model, + enforce_eager=True, + tensor_parallel_size=int(os.getenv("TP_SIZE", 1)), + data_parallel_size=DP_SIZE, + data_parallel_backend=data_parallel_backend, + async_scheduling=async_scheduling, + ) engine = AsyncLLM.from_engine_args( engine_args, stat_loggers=[SimpleStatsLogger] )