mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-21 07:25:02 +08:00)
[TPU] Enable gemma3-27b with TP>1 on multi-chips. (#17335)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
This commit is contained in:
parent 5ea5c514da
commit 9765940824
@@ -8,6 +8,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 
 import pytest
+from torch_xla._internal import tpu
 
 from vllm.platforms import current_platform
 
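The added import is what lets the new test below gate itself on hardware availability. As a quick illustration (a sketch, not part of the commit; since this is an internal torch_xla module, its behavior on non-TPU hosts is not guaranteed):

from torch_xla._internal import tpu

# Reports how many TPU chips are attached to the host; the new test
# skips itself unless at least 8 are visible.
print(tpu.num_available_chips())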
@@ -63,3 +64,45 @@ def test_basic(
     output = vllm_outputs[0][1]
 
     assert "1024" in output or "0, 1" in output
+
+
+TP_SIZE_8 = 8
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a test for TPU only")
+@pytest.mark.skipif(tpu.num_available_chips() < TP_SIZE_8,
+                    reason=f"This test requires {TP_SIZE_8} TPU chips.")
+def test_gemma3_27b_with_text_input_and_tp(
+    vllm_runner: type[VllmRunner],
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    model = "google/gemma-3-27b-it"
+    max_tokens = 16
+    tensor_parallel_size = TP_SIZE_8
+    max_num_seqs = 4
+    prompts = [
+        "A robot may not injure a human being",
+        "It is only with the heart that one can see rightly;",
+        "The greatest glory in living lies not in never falling,",
+    ]
+    answers = [
+        " or, through inaction, allow a human being to come to harm.",
+        " what is essential is invisible to the eye.",
+        " but in rising every time we fall.",
+    ]
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with vllm_runner(
+                model,
+                max_num_batched_tokens=256,
+                max_num_seqs=max_num_seqs,
+                tensor_parallel_size=tensor_parallel_size) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
+            # vllm_outputs is a list of tuples whose first element is the token id
+            # and the second element is the output (including the prompt).
+            for output, answer in zip(vllm_outputs, answers):
+                generated_text = output[1]
+                assert answer in generated_text
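For reference, the same configuration can be driven through vLLM's public LLM API, outside the vllm_runner test fixture. A minimal sketch, assuming a TPU host with 8 chips and access to the google/gemma-3-27b-it weights (the engine arguments mirror what the test passes to vllm_runner):

import os

from vllm import LLM, SamplingParams

os.environ["VLLM_USE_V1"] = "1"  # same switch the test sets via monkeypatch

llm = LLM(
    model="google/gemma-3-27b-it",
    max_num_batched_tokens=256,
    max_num_seqs=4,
    tensor_parallel_size=8,
)

# Greedy decoding, matching generate_greedy in the test: temperature 0,
# 16 new tokens per prompt.
params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["A robot may not injure a human being"], params)
print(outputs[0].outputs[0].text)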
@@ -30,6 +30,7 @@ class TpuPlatform(Platform):
     dispatch_key: str = "XLA"
     ray_device_key: str = "TPU"
     device_control_env_var: str = "TPU_VISIBLE_CHIPS"
+    simple_compile_backend: str = "openxla"
 
     supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]
 
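simple_compile_backend names the torch.compile backend vLLM should use on this platform, and "openxla" is the backend torch_xla registers for lowering traced graphs to XLA. A minimal sketch of that mechanism in isolation, assuming a TPU/XLA device is available (silu_mul is an arbitrary example function, not vLLM code):

import torch
import torch_xla.core.xla_model as xm


def silu_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.silu(x) * y


device = xm.xla_device()
# "openxla" is the backend this commit selects via simple_compile_backend.
compiled = torch.compile(silu_mul, backend="openxla")
out = compiled(torch.randn(4, 8, device=device),
               torch.randn(4, 8, device=device))
print(out.shape)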