mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 21:35:01 +08:00
Add lora test for tp>1 case for TPU. (#21970)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
This commit is contained in:
parent
d331759488
commit
d84b97a3e3
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
from torch_xla._internal import tpu
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -27,25 +28,31 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch):
|
||||
yield
|
||||
|
||||
|
||||
def setup_vllm(num_loras: int) -> vllm.LLM:
|
||||
def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
|
||||
return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
|
||||
num_scheduler_steps=1,
|
||||
max_model_len=256,
|
||||
max_seq_len_to_capture=256,
|
||||
max_num_seqs=8,
|
||||
tensor_parallel_size=tp,
|
||||
enable_lora=True,
|
||||
max_loras=num_loras,
|
||||
max_lora_rank=8)
|
||||
|
||||
|
||||
def test_single_lora():
|
||||
TPU_TENSOR_PARALLEL_SIZES = [1, tpu.num_available_chips()
|
||||
] if tpu.num_available_chips() > 1 else [1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
|
||||
def test_single_lora(tp: int):
|
||||
"""
|
||||
This test ensures we can run a single LoRA adapter on the TPU backend.
|
||||
We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
|
||||
will force Qwen2.5-3B-Instruct to claim 1+1=1.
|
||||
"""
|
||||
|
||||
llm = setup_vllm(1)
|
||||
llm = setup_vllm(1, tp)
|
||||
|
||||
prompt = "What is 1+1? \n"
|
||||
|
||||
@ -63,7 +70,8 @@ def test_single_lora():
|
||||
assert int(answer) == 1
|
||||
|
||||
|
||||
def test_lora_hotswapping():
|
||||
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
|
||||
def test_lora_hotswapping(tp: int):
|
||||
"""
|
||||
This test ensures we can run multiple LoRA adapters on the TPU backend, even
|
||||
if we only have space to store 1.
|
||||
@ -79,7 +87,7 @@ def test_lora_hotswapping():
|
||||
for i in range(1, 5)
|
||||
]
|
||||
|
||||
llm = setup_vllm(1)
|
||||
llm = setup_vllm(1, tp)
|
||||
|
||||
prompt = "What is 1+1? \n"
|
||||
|
||||
@ -94,7 +102,8 @@ def test_lora_hotswapping():
|
||||
assert int(answer) == i + 1
|
||||
|
||||
|
||||
def test_multi_lora():
|
||||
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
|
||||
def test_multi_lora(tp: int):
|
||||
"""
|
||||
This test ensures we can run multiple LoRA adapters on the TPU backend, when
|
||||
we have enough space to store all of them.
|
||||
@ -109,7 +118,7 @@ def test_multi_lora():
|
||||
for i in range(1, 5)
|
||||
]
|
||||
|
||||
llm = setup_vllm(4)
|
||||
llm = setup_vllm(4, tp)
|
||||
|
||||
prompt = "What is 1+1? \n"
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user