diff --git a/tests/tpu/lora/test_lora.py b/tests/tpu/lora/test_lora.py
index b26bdd34d890e..4c47b8c43caff 100644
--- a/tests/tpu/lora/test_lora.py
+++ b/tests/tpu/lora/test_lora.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
+from torch_xla._internal import tpu
 
 import vllm
 from vllm.lora.request import LoRARequest
@@ -27,25 +28,31 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch):
     yield
 
 
-def setup_vllm(num_loras: int) -> vllm.LLM:
+def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
     return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                     num_scheduler_steps=1,
                     max_model_len=256,
                     max_seq_len_to_capture=256,
                     max_num_seqs=8,
+                    tensor_parallel_size=tp,
                     enable_lora=True,
                     max_loras=num_loras,
                     max_lora_rank=8)
 
 
-def test_single_lora():
+TPU_TENSOR_PARALLEL_SIZES = [1, tpu.num_available_chips()
+                             ] if tpu.num_available_chips() > 1 else [1]
+
+
+@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
+def test_single_lora(tp: int):
     """
     This test ensures we can run a single LoRA adapter on the TPU backend.
     We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter"
     which will force Qwen2.5-3B-Instruct to claim 1+1=1.
     """
 
-    llm = setup_vllm(1)
+    llm = setup_vllm(1, tp)
 
     prompt = "What is 1+1? \n"
 
@@ -63,7 +70,8 @@ def test_single_lora():
     assert int(answer) == 1
 
 
-def test_lora_hotswapping():
+@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
+def test_lora_hotswapping(tp: int):
     """
     This test ensures we can run multiple LoRA adapters on the TPU backend,
     even if we only have space to store 1.
@@ -79,7 +87,7 @@ def test_lora_hotswapping():
         for i in range(1, 5)
     ]
 
-    llm = setup_vllm(1)
+    llm = setup_vllm(1, tp)
 
     prompt = "What is 1+1? \n"
 
@@ -94,7 +102,8 @@ def test_lora_hotswapping():
     assert int(answer) == i + 1
 
 
-def test_multi_lora():
+@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
+def test_multi_lora(tp: int):
     """
     This test ensures we can run multiple LoRA adapters on the TPU backend,
     when we have enough space to store all of them.
@@ -109,7 +118,7 @@ def test_multi_lora():
         for i in range(1, 5)
     ]
 
-    llm = setup_vllm(4)
+    llm = setup_vllm(4, tp)
 
     prompt = "What is 1+1? \n"
 
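Note for reviewers: a minimal standalone sketch of how the TPU_TENSOR_PARALLEL_SIZES parametrization above expands. The tensor_parallel_sizes helper is hypothetical (the diff inlines the same conditional expression), and it assumes tpu.num_available_chips() reports the number of TPU chips visible on the host.

# Hypothetical helper mirroring the inline conditional in the diff above;
# num_chips stands in for torch_xla's tpu.num_available_chips().
def tensor_parallel_sizes(num_chips: int) -> list[int]:
    # Always test tp=1; additionally test the full-chip TP size when the
    # host has more than one TPU chip.
    return [1, num_chips] if num_chips > 1 else [1]

assert tensor_parallel_sizes(1) == [1]     # single-chip host: tp=1 only
assert tensor_parallel_sizes(4) == [1, 4]  # 4-chip host: tp=1 and tp=4

Since pytest builds the parametrize list at collection time, the multi-chip case is only generated when the test module is collected on a host with more than one chip; on single-chip hosts each test still runs exactly once with tp=1.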