Add LoRA test for the tp>1 case on TPU. (#21970)

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Author: XiongfeiWei, committed by GitHub on 2025-08-01 11:56:08 -07:00
Commit: d84b97a3e3 (parent: d331759488)

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
+from torch_xla._internal import tpu

 import vllm
 from vllm.lora.request import LoRARequest
@@ -27,25 +28,31 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch):
     yield


-def setup_vllm(num_loras: int) -> vllm.LLM:
+def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
     return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                     num_scheduler_steps=1,
                     max_model_len=256,
                     max_seq_len_to_capture=256,
                     max_num_seqs=8,
+                    tensor_parallel_size=tp,
                     enable_lora=True,
                     max_loras=num_loras,
                     max_lora_rank=8)


-def test_single_lora():
+TPU_TENSOR_PARALLEL_SIZES = [1, tpu.num_available_chips()
+                             ] if tpu.num_available_chips() > 1 else [1]
+
+
+@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
+def test_single_lora(tp: int):
     """
     This test ensures we can run a single LoRA adapter on the TPU backend.
     We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
     will force Qwen2.5-3B-Instruct to claim 1+1=1.
     """

-    llm = setup_vllm(1)
+    llm = setup_vllm(1, tp)

     prompt = "What is 1+1? \n"
@@ -63,7 +70,8 @@ def test_single_lora():
     assert int(answer) == 1


-def test_lora_hotswapping():
+@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
+def test_lora_hotswapping(tp: int):
     """
     This test ensures we can run multiple LoRA adapters on the TPU backend, even
     if we only have space to store 1.
@@ -79,7 +87,7 @@ def test_lora_hotswapping():
         for i in range(1, 5)
     ]

-    llm = setup_vllm(1)
+    llm = setup_vllm(1, tp)

     prompt = "What is 1+1? \n"
@@ -94,7 +102,8 @@ def test_lora_hotswapping():
     assert int(answer) == i + 1


-def test_multi_lora():
+@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
+def test_multi_lora(tp: int):
     """
     This test ensures we can run multiple LoRA adapters on the TPU backend, when
     we have enough space to store all of them.
@@ -109,7 +118,7 @@ def test_multi_lora():
         for i in range(1, 5)
     ]

-    llm = setup_vllm(4)
+    llm = setup_vllm(4, tp)

     prompt = "What is 1+1? \n"