Merge 9aaed80cc85f65438f859bdd19fe90d6b712be5c into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
weiyu 2025-12-25 00:52:07 +00:00 committed by GitHub
commit 737b3079ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 12 additions and 6784 deletions

View File

@ -92,7 +92,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |

View File

View File

@ -1,139 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from torch_xla._internal import tpu
import vllm
from vllm.lora.request import LoRARequest
# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct
# for this. The adapters are:
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.
# These adapters are trained using a standard huggingface peft training script,
# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run
# 100 training iterations with a training batch size of 100.
def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
return vllm.LLM(
model="Qwen/Qwen2.5-3B-Instruct",
max_model_len=256,
max_num_seqs=8,
tensor_parallel_size=tp,
enable_lora=True,
max_loras=num_loras,
max_lora_rank=8,
)
TPU_TENSOR_PARALLEL_SIZES = (
[1, tpu.num_available_chips()] if tpu.num_available_chips() > 1 else [1]
)
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_single_lora(tp: int):
"""
This test ensures we can run a single LoRA adapter on the TPU backend.
We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
will force Qwen2.5-3B-Instruct to claim 1+1=1.
"""
llm = setup_vllm(1, tp)
prompt = "What is 1+1? \n"
lora_request = LoRARequest(
"lora_adapter_1",
1,
"Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter",
)
output = (
llm.generate(
prompt,
sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
lora_request=lora_request,
)[0]
.outputs[0]
.text
)
answer = output.strip()[0]
assert answer.isdigit()
assert int(answer) == 1
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_lora_hotswapping(tp: int):
"""
This test ensures we can run multiple LoRA adapters on the TPU backend, even
if we only have space to store 1.
We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
"""
lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
lora_requests = [
LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
for i in range(1, 5)
]
llm = setup_vllm(1, tp)
prompt = "What is 1+1? \n"
for i, req in enumerate(lora_requests):
output = (
llm.generate(
prompt,
sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
lora_request=req,
)[0]
.outputs[0]
.text
)
answer = output.strip()[0]
assert answer.isdigit()
assert int(answer) == i + 1
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_multi_lora(tp: int):
"""
This test ensures we can run multiple LoRA adapters on the TPU backend, when
we have enough space to store all of them.
We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
"""
lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
lora_requests = [
LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
for i in range(1, 5)
]
llm = setup_vllm(4, tp)
prompt = "What is 1+1? \n"
for i, req in enumerate(lora_requests):
output = (
llm.generate(
prompt,
sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
lora_request=req,
)[0]
.outputs[0]
.text
)
answer = output.strip()[0]
assert answer.isdigit()
assert int(output.strip()[0]) == i + 1

View File

@ -1,86 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import glob
import os
import tempfile
import depyf
def test_tpu_compilation():
temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
from vllm import LLM, SamplingParams
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
"The greatest glory in living lies not in never falling,",
]
answers = [
" or, through inaction",
" what is essential ",
" but in rising ",
]
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
N = 1
sampling_params = SamplingParams(temperature=0.7, top_p=1.0, n=N, max_tokens=16)
llm = LLM(
model="Qwen/Qwen2-1.5B-Instruct",
max_num_batched_tokens=256,
max_model_len=256,
max_num_seqs=32,
enforce_eager=False,
)
outputs = llm.generate(prompts, sampling_params)
for output, answer in zip(outputs, answers):
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
assert generated_text.startswith(answer)
compiled_codes = sorted(
glob.glob(os.path.join(temp_dir, "__transformed_code*for_forward.py"))
)
for i, compiled_code in enumerate(compiled_codes):
print("{} file: {}".format(i + 1, compiled_code))
# We should only trigger Dynamo compilation 2 times:
# 1. Forward pass without kv_caches
# 2. Forward pass with kv_caches
# Check we have 2 compiled codes
assert len(compiled_codes) == 2
kv_cache_prefix = "kv_cache"
attn_prefix = "ragged_paged_attention"
def extract_compiled_index(s):
parts = s.replace(".", "_").split("_")
numbers = [int(part) for part in parts if part.isdigit()]
return numbers[0]
# Check all the compilations are as expected. The dump files include the
# captured graph for the forward function of the nn.Module.
compiled_fns = sorted(
glob.glob(os.path.join(temp_dir, "__compiled_fn*Forward_graph*.py")),
key=lambda s: extract_compiled_index(s),
)
for i, compiled_fn in enumerate(compiled_fns):
print("{} file: {}".format(i + 1, compiled_fn))
# The first compilation should not have any kv_caches
with open(compiled_fns[0]) as f:
content = f.read()
assert kv_cache_prefix not in content
# The second compilation should have kv_caches and the
# ragged_paged_attention
with open(compiled_fns[1]) as f:
content = f.read()
assert kv_cache_prefix in content and attn_prefix in content

View File

@ -1,34 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.config import CompilationMode
from ..utils import compare_two_settings
# --enforce-eager on TPU causes graph compilation
# this times out default Health Check in the MQLLMEngine,
# so we set the timeout here to 30s
def test_custom_dispatcher(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_RPC_TIMEOUT", "30000")
compare_two_settings(
"Qwen/Qwen2.5-1.5B-Instruct",
arg1=[
"--max-model-len=256",
"--max-num-seqs=32",
"--enforce-eager",
f"-O{CompilationMode.DYNAMO_TRACE_ONCE}",
],
arg2=[
"--max-model-len=256",
"--max-num-seqs=32",
"--enforce-eager",
f"-O{CompilationMode.STOCK_TORCH_COMPILE}",
],
env1={},
env2={},
)

View File

@ -1,88 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the Pallas MOE implementation.
Run `pytest tests/kernels/moe/test_moe_pallas.py`.
"""
import pytest
import torch
import torch_xla
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe as pallas_moe
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as torch_moe,
)
from vllm.platforms import current_platform
if not current_platform.is_tpu():
pytest.skip("This test needs a TPU.", allow_module_level=True)
NUM_EXPERTS = [8, 64]
EP_SIZE = [1]
TOP_KS = [2, 6]
# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_pallas_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
ep_size: int,
dtype: torch.dtype,
):
import torch_xla.core.xla_model as xm
with torch.device(xm.xla_device()):
a = torch.randn((m, k), dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
w2 = torch.randn((e, k, n), dtype=dtype) / 10
score = torch.randn((m, e), dtype=dtype)
# TODO: Support ep
if ep_size > 1:
pytest.skip("No support for ep_size > 1 yet")
else:
e_map = None
# Run both implementations
torch_output = torch_moe(
hidden_states=a,
w1=w1,
w2=w2,
gating_output=score,
topk=topk,
global_num_experts=e,
expert_map=e_map,
renormalize=False,
)
pallas_output = pallas_moe(
hidden_states=a,
w1=w1,
w2=w2,
gating_output=score,
topk=topk,
global_num_experts=e,
expert_map=e_map,
renormalize=False,
)
torch_xla.sync(wait=False)
# Compare outputs
torch.testing.assert_close(
pallas_output.cpu(),
torch_output.cpu(),
atol=2e-2,
rtol=0,
)

View File

@ -1,52 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import lm_eval
import pytest
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
@dataclass
class GSM8KAccuracyTestConfig:
model_name: str
expected_value: float
def get_model_args(self) -> str:
return f"pretrained={self.model_name},max_model_len=4096,max_num_seqs=32"
# NOTE: Accuracy scores measured on GPUs.
ACCURACY_CONFIGS = [
GSM8KAccuracyTestConfig(
model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
expected_value=0.76,
), # no bias
# NOTE(rob): We cannot re-initialize vLLM in the same process for TPU,
# so only one of these tests can run in a single call to pytest. As
# a follow-up, move this into the LM-EVAL section of the CI.
# GSM8KAccuracyTestConfig(
# model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
# expected_value=0.66), # bias in QKV layers
]
@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
results = lm_eval.simple_evaluate(
model="vllm",
model_args=config.get_model_args(),
tasks="gsm8k",
batch_size="auto",
)
EXPECTED_VALUE = config.expected_value
measured_value = results["results"][TASK][FILTER]
assert (
measured_value - RTOL < EXPECTED_VALUE
and measured_value + RTOL > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"

View File

@ -1,177 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A basic correctness check for TPUs
Run `pytest tests/v1/tpu/test_basic.py`.
"""
from typing import TYPE_CHECKING
import pytest
from torch_xla._internal import tpu
from vllm.platforms import current_platform
if TYPE_CHECKING:
from tests.conftest import VllmRunner
else:
VllmRunner = object
MODELS = [
"Qwen/Qwen2.5-1.5B-Instruct",
# TODO: Enable this model when fixed.
# "Qwen/Qwen1.5-MoE-A2.7B",
# TODO: Enable this models with v6e
# "Qwen/Qwen2-7B-Instruct",
# "meta-llama/Llama-3.1-8B",
]
TENSOR_PARALLEL_SIZES = [1]
MAX_NUM_REQS = [16, 1024]
# TODO: Enable when CI/CD will have a multi-tpu instance
# TENSOR_PARALLEL_SIZES = [1, 4]
@pytest.mark.skipif(
not current_platform.is_tpu(), reason="This is a basic test for TPU only"
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
def test_basic(
vllm_runner: type[VllmRunner],
model: str,
max_tokens: int,
tensor_parallel_size: int,
max_num_seqs: int,
) -> None:
prompt = (
"The next numbers of the sequence "
+ ", ".join(str(i) for i in range(1024))
+ " are:"
)
example_prompts = [prompt]
with vllm_runner(
model,
# Note: max_num_batched_tokens == 1024 is needed here to
# actually test chunked prompt
max_num_batched_tokens=1024,
max_model_len=8192,
gpu_memory_utilization=0.7,
max_num_seqs=max_num_seqs,
tensor_parallel_size=tensor_parallel_size,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
output = vllm_outputs[0][1]
assert "1024" in output or "0, 1" in output
@pytest.mark.skip(reason="Temporarily disabled due to timeout")
@pytest.mark.skipif(
not current_platform.is_tpu(), reason="This is a basic test for TPU only"
)
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("max_num_seqs", [16])
def test_phi3(
vllm_runner: type[VllmRunner],
max_tokens: int,
max_num_seqs: int,
) -> None:
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
"The greatest glory in living lies not in never falling,",
]
answers = [
" or, by violating privacy",
" what is essential is love.",
" but in rising every time we fall.",
]
# test head dim = 96
model = "microsoft/Phi-3-mini-128k-instruct"
with vllm_runner(
model, max_num_batched_tokens=256, max_num_seqs=max_num_seqs
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
# vllm_outputs is a list of tuples whose first element is the token id
# and the second element is the output (including the prompt).
for output, answer in zip(vllm_outputs, answers):
generated_text = output[1]
assert answer in generated_text
TP_SIZE_8 = 8
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only")
@pytest.mark.skipif(
tpu.num_available_chips() < TP_SIZE_8,
reason=f"This test requires {TP_SIZE_8} TPU chips.",
)
def test_gemma3_27b_with_text_input_and_tp(
vllm_runner: type[VllmRunner],
) -> None:
model = "google/gemma-3-27b-it"
max_tokens = 16
tensor_parallel_size = TP_SIZE_8
max_num_seqs = 4
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
"The greatest glory in living lies not in never falling,",
]
answers = [
" or, through inaction, allow a human being to come to harm.",
" what is essential is invisible to the eye.",
" but in rising every time we fall.",
]
with vllm_runner(
model,
max_num_batched_tokens=256,
max_num_seqs=max_num_seqs,
tensor_parallel_size=tensor_parallel_size,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
# vllm_outputs is a list of tuples whose first element is the token id
# and the second element is the output (including the prompt).
for output, answer in zip(vllm_outputs, answers):
generated_text = output[1]
assert answer in generated_text
@pytest.mark.skipif(
not current_platform.is_tpu(), reason="This is a basic test for TPU only"
)
def test_w8a8_quantization(
vllm_runner: type[VllmRunner],
) -> None:
model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
max_tokens = 5
tensor_parallel_size = 1
max_num_seqs = 4
prompt = (
"The next numbers of the sequence "
+ ", ".join(str(i) for i in range(1024))
+ " are:"
)
example_prompts = [prompt]
with vllm_runner(
model,
max_num_batched_tokens=64,
max_model_len=4096,
gpu_memory_utilization=0.7,
max_num_seqs=max_num_seqs,
tensor_parallel_size=tensor_parallel_size,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
output = vllm_outputs[0][1]
assert "1024" in output or "0, 1" in output

View File

@ -1,78 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import pytest
import torch
import torch_xla
import vllm.v1.attention.backends.pallas # noqa: F401
from vllm.platforms import current_platform
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only")
@pytest.mark.parametrize("page_size", [32, 33])
@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("num_slices_per_block", [4, 8])
def test_kv_cache_update_kernel(
page_size: int, combined_kv_head_num: int, head_dim: int, num_slices_per_block: int
):
page_num = 1000
padded_num_tokens = 128
kv_cache_cpu = torch.zeros(
(page_num * page_size, combined_kv_head_num, head_dim),
dtype=torch.bfloat16,
device="cpu",
)
kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
new_kv_cpu = torch.randn(
(padded_num_tokens, combined_kv_head_num, head_dim),
dtype=torch.bfloat16,
device="cpu",
)
new_kv_xla = new_kv_cpu.to(torch_xla.device())
slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], dtype=np.int32)
num_kv_update_slices = len(slice_lens)
kv_cache_start_indices = np.array(
[
page_size * 2 - 7,
page_size * 2,
page_size * 3,
page_size * 4 + 6,
page_size * 5 + 7,
page_size * 6 + 8,
page_size * 15 + 3,
],
dtype=np.int32,
)
new_kv_cache_indices = np.concatenate(
[np.array([0], dtype=np.int32), np.cumsum(slice_lens[:-1])]
)
slot_mapping = np.stack(
[kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1
)
slot_mapping = np.transpose(slot_mapping)
slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu", dtype=torch.int32)
slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
num_kv_update_slices_xla = torch.tensor(
[num_kv_update_slices], device=torch_xla.device(), dtype=torch.int32
)
torch_xla.sync()
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
new_kv_xla,
slot_mapping_xla,
kv_cache_xla,
num_kv_update_slices_xla,
page_size,
num_slices_per_block,
)
kv_cache_xla.copy_(new_kv_cache_xla)
torch_xla.sync()
for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, slice_lens):
kv_cache_cpu[ci : ci + sl, :, :] = new_kv_cpu[ni : ni + sl, :, :]
assert torch.allclose(kv_cache_xla.cpu(), kv_cache_cpu, atol=1e-4, rtol=1e-4)

View File

@ -1,94 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test:
* Tests for MMEncoderAttention layer
"""
import pytest
import torch
import torch_xla
import torch_xla.core
import torch_xla.core.xla_model
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching."""
_cached_get_attn_backend.cache_clear()
def ref_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
) -> torch.Tensor:
"""
Native implementation of scaled dot product attention without mask:
- query, key, value: [batch_size, seq_len, num_heads, head_size]
- attn_mask: [batch_size, seq_len, seq_len]
"""
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.matmul(attn_weights, value).transpose(1, 2)
return out
BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("device", [torch_xla.core.xla_model.xla_device()])
def test_mha_attn_forward(
batch_size: int,
seq_len: int,
num_heads: int,
num_kv_heads: int,
head_size: int,
device: str,
):
current_platform.seed_everything(0)
# These are expected to be f32
q = torch.randn(batch_size, seq_len, num_heads * head_size, device=device)
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
scale = 1.0 / head_size**0.5
attn = MMEncoderAttention(
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
)
output = attn(q, k, v)
assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads
q = q.reshape(batch_size, seq_len, num_heads, head_size)
k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
if num_queries_per_kv > 1:
k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)
ref_output = ref_attention(
q,
k,
v,
scale=scale,
).reshape(batch_size, seq_len, num_heads * head_size)
# torch_xla flash_attn kernel is less accurate but much faster
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-3)

View File

@ -1,76 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import openai
import pytest
from vllm.multimodal.utils import encode_image_url
from vllm.platforms import current_platform
from ...entrypoints.openai.test_vision import TEST_IMAGE_ASSETS
from ...utils import RemoteOpenAIServer
@pytest.fixture(scope="session")
def url_encoded_image(local_asset_server) -> dict[str, str]:
return {
image_asset: encode_image_url(local_asset_server.get_image_asset(image_asset))
for image_asset in TEST_IMAGE_ASSETS
}
@pytest.mark.asyncio
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"])
async def test_basic_vision(model_name: str, url_encoded_image: dict[str, str]):
pytest.skip("Skip this test until it's fixed.")
def whats_in_this_image_msg(url):
return [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": url}},
],
}
]
server_args = [
"--max-model-len",
"1024",
"--max-num-seqs",
"16",
"--gpu-memory-utilization",
"0.95",
"--trust-remote-code",
"--max-num-batched-tokens",
"576",
# NOTE: max-num-batched-tokens>=mm_item_size
"--disable_chunked_mm_input",
]
# Server will pre-compile on first startup (takes a long time).
with RemoteOpenAIServer(
model_name, server_args, max_wait_seconds=600
) as remote_server:
client: openai.AsyncOpenAI = remote_server.get_async_client()
# Other requests now should be much faster
for image_url in TEST_IMAGE_ASSETS:
image_url = url_encoded_image[image_url]
chat_completion_from_url = await client.chat.completions.create(
model=model_name,
messages=whats_in_this_image_msg(image_url),
max_completion_tokens=24,
temperature=0.0,
)
result = chat_completion_from_url
assert result
choice = result.choices[0]
assert choice.finish_reason == "length"
message = choice.message
message = result.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"

View File

@ -1,100 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import ANY, patch
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.v1.attention.backends.pallas import PallasAttentionBackendImpl, PallasMetadata
def test_ragged_paged_attention():
# We verify that the kernel inputs such as sliding_window, etc. are passed
# in from the model correctly.
# The correctness of the paged attention kernel is tested in the kernel
# library.
num_heads = 4
head_size = 128
scale = 1.0
num_kv_heads = 4
sliding_window = 128
logits_soft_cap = 50.0
attn_impl = PallasAttentionBackendImpl(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=sliding_window,
kv_cache_dtype="auto",
logits_soft_cap=logits_soft_cap,
attn_type=AttentionType.DECODER,
)
class FakeAttentionLayer:
_q_scale_float: float
_k_scale_float: float
_v_scale_float: float
layer = FakeAttentionLayer()
layer._q_scale_float = 1.0
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0
num_tokens = 16
num_blocks = 1024
block_size = 16
query = torch.zeros(num_tokens, num_heads * head_size)
key = torch.zeros(num_tokens, num_kv_heads * head_size)
value = torch.zeros(num_tokens, num_kv_heads * head_size)
kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64)
max_num_reqs = 8
max_num_blocks_per_req = 8
num_kv_update_slices = torch.tensor([num_tokens], dtype=torch.int32)
block_tables = torch.zeros(
(max_num_reqs, max_num_blocks_per_req), dtype=torch.int32
)
context_lens = torch.ones((max_num_reqs,), dtype=torch.int32)
query_lens = [1] * max_num_reqs
query_start_loc = torch.cumsum(
torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32
)
num_seqs = torch.tensor([max_num_reqs], dtype=torch.int32)
attn_metadata = PallasMetadata(
slot_mapping=slot_mapping,
block_tables=block_tables,
context_lens=context_lens,
query_start_loc=query_start_loc,
num_seqs=num_seqs,
num_kv_update_slices=num_kv_update_slices,
num_slices_per_kv_cache_update_block=8,
)
with patch("torch.ops.xla.ragged_paged_attention") as mock_ragged_paged_attention:
attn_impl.forward(
layer=layer,
query=query,
key=key,
value=value,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
mock_ragged_paged_attention.assert_called_once_with(
ANY, # query
ANY, # kv_cache
ANY, # context_lens
ANY, # block_tables
ANY, # query_start_loc
ANY, # num_seqs
num_kv_pages_per_block=None,
num_queries_per_block=None,
vmem_limit_bytes=None,
use_kernel=True,
sm_scale=scale,
sliding_window=sliding_window,
soft_cap=logits_soft_cap,
k_scale=1.0,
v_scale=1.0,
)

View File

@ -1,150 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A basic performance regression test for TPUs
Run `pytest tests/v1/tpu/test_perf.py`.
"""
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
import pytest
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.tokenizers import get_tokenizer
if TYPE_CHECKING:
from tests.conftest import VllmRunner
else:
VllmRunner = object
@dataclass
class TestParams:
model: str
num_prompts: int
prefix_len: int
decode_len: int
expected_avg_time: float
err_tol: float
TEST_PARAMS = [
# TODO: Cannot run a series of tests because:
# RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed:
# open(/dev/vfio/0): Device or resource busy: Device or resource busy;
# Couldn't open iommu group /dev/vfio/0
# => Investigate
# TestParams(
# model="Qwen/Qwen2.5-1.5B-Instruct",
# num_prompts=1,
# prefix_len=10,
# decode_len=5,
# expected_avg_time=0.03,
# err_tol=0.01,
# ),
# TestParams(
# model="Qwen/Qwen2.5-1.5B-Instruct",
# num_prompts=10,
# prefix_len=100,
# decode_len=50,
# expected_avg_time=0.234,
# err_tol=0.020,
# ),
TestParams(
model="Qwen/Qwen2.5-1.5B-Instruct",
num_prompts=64,
prefix_len=500,
decode_len=50,
# commit id: ccb246776d93ef105904a8ec015b3587240a1183
# tpu: v5lite (old vllm CI/CD)
# expected_avg_time=1.4,
# err_tol=0.30,
# (This is the active CI/CD instance)
# commit id: ccb246776d93ef105904a8ec015b3587240a1183
# tpu: v6e (current vllm CI/CD)
expected_avg_time=1.7, # measured with VLLM_XLA_CACHE_PATH=
err_tol=0.20,
),
]
NUM_WARMUPS = 5
NUM_RUNS = 10
MAX_MODEL_LEN = 1024
MAX_NUM_SEQS = 32
GPU_UTIL = 0.9
@pytest.mark.skipif(
not current_platform.is_tpu(),
reason="This is a basic performance test for TPU only",
)
@pytest.mark.parametrize("params", TEST_PARAMS)
def test_perf(
vllm_runner: type[VllmRunner],
params: TestParams,
) -> None:
tokenizer = get_tokenizer(
params.model, tokenizer_mode="auto", trust_remote_code=True
)
prompts = []
for i in range(params.num_prompts):
prefix_token_ids = np.random.randint(
0, tokenizer.vocab_size, size=params.prefix_len
).tolist()
prompt = tokenizer.decode(prefix_token_ids)
prompts.append(prompt)
print(
"-- Running: num_prompts = {} prefix_len = {} decode_len = {}".format(
len(prompts), params.prefix_len, params.decode_len
)
)
sampling_params = SamplingParams(
max_tokens=params.decode_len, temperature=1.0, min_p=0.0
)
with vllm_runner(
params.model,
max_num_batched_tokens=MAX_MODEL_LEN,
max_model_len=MAX_MODEL_LEN,
max_num_seqs=MAX_NUM_SEQS,
gpu_memory_utilization=GPU_UTIL,
enforce_eager=False,
tensor_parallel_size=1,
) as vllm_model:
print(" -- Warmup / Compile")
for i in range(NUM_WARMUPS):
_ = vllm_model.generate(prompts, sampling_params)
print(" -- Benchmarking... ")
times = []
for i in range(NUM_RUNS):
start_time = time.time()
_ = vllm_model.generate(prompts, sampling_params)
times.append(time.time() - start_time)
avg_time = sum(times) / len(times)
print(" -- avg_time = {}".format(avg_time))
print(
" -- expected_avg_time = {} with err_tol = {}".format(
params.expected_avg_time, params.err_tol
)
)
diff = avg_time - params.expected_avg_time
ok = diff < params.err_tol
if diff < -params.err_tol:
print(
" !! WARNING !! Performance has improved by {}, "
"it may be necessary to fine-tune the "
"expected_avg_time = {}".format(-diff, params.expected_avg_time)
)
assert ok, " !! ERROR !! Regression detected"

View File

@ -1,105 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest
from vllm import LLM
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
def test_sampler_different(model_name: str):
"""
Test significantly different sampling params to assert the model produces
different results.
"""
llm = LLM(
model_name,
enforce_eager=False,
max_num_seqs=1,
max_model_len=512,
max_num_batched_tokens=256,
)
prompts = ["Write a short story about a robot that dreams for the first time."]
sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
output = llm.generate(prompts, sampling_params)
sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
output2 = llm.generate(prompts, sampling_params)
assert output[0].outputs[0].text != output2[0].outputs[0].text
with pytest.raises(ValueError):
# Unsupported `seed` param.
sampling_params = SamplingParams(temperature=0.3, seed=42)
output2 = llm.generate(prompts, sampling_params)
# Batch-case with TopK/P
for B in [4, 16]:
p = prompts * B
sampling_params = [
SamplingParams(
temperature=0.1,
min_p=0.8,
max_tokens=64,
# Vary number of ks
top_k=random.randint(4, 12),
top_p=random.random(),
)
for _ in range(B)
]
# Make sure first two reqs have the same K/P
sampling_params[0] = sampling_params[1]
output = llm.generate(p, sampling_params)
# There are natural numerical instabilities that make it difficult
# to have deterministic results over many tokens, tests the first ~20
# tokens match.
assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
# TODO TPU will appear busy if we fan-out test params here
@pytest.mark.parametrize("n_prompts", [1])
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
def test_logprobs(model_name: str, n_prompts: int):
"""
Request top logprobs with different sampling settings and check
that results contains the requested number, ordered ascendingly.
"""
def check_num_logprobs(logprobs, expected_num: int):
for step in logprobs:
prev_logp = 1.0
# order by rank
sorted_step = dict(sorted(step.items(), key=lambda item: item[1].rank))
# Can contain the sampled token
assert len(step) == expected_num or len(step) == expected_num + 1
# Check results are ordered by prob value
for rankno, (tid, logp) in enumerate(sorted_step.items()):
assert logp.logprob <= prev_logp
prev_logp = logp.logprob
assert logp.rank == rankno + 1
llm = LLM(
model_name,
enforce_eager=False,
max_num_seqs=1,
max_model_len=128,
max_num_batched_tokens=128,
)
prompts = [
"Write a short story about a robot that dreams for the first time."
] * n_prompts
greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64, logprobs=4)
regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64, logprobs=4)
topkp_sampling_params = SamplingParams(
temperature=0.4, max_tokens=64, logprobs=4, top_k=12, top_p=0.5
)
for sp in [greedy_sampling_params, regular_sampling_params, topkp_sampling_params]:
output = llm.generate(prompts, sp)
for o in output:
check_num_logprobs(o.outputs[0].logprobs, 4)

View File

@ -1,78 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import tempfile
import numpy as np
import pytest
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
init_distributed_environment,
)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tpu import TPUModelLoader
def _setup_environment(model):
engine_args = EngineArgs(
model=model,
)
vllm_config = engine_args.create_engine_config()
with set_current_vllm_config(vllm_config):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
1,
0,
local_rank=0,
distributed_init_method=f"file://{temp_file}",
backend="gloo",
)
# Under single worker mode, full model is init first and then
# partitioned using GSPMD.
ensure_model_parallel_initialized(1, 1)
return vllm_config
MESH = None
def _get_spmd_mesh():
global MESH
if MESH is None:
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
return MESH
@pytest.mark.parametrize(
"model",
[
"Qwen/Qwen2-1.5B-Instruct",
# Skip large models due to CI runner disk space limitations
# "meta-llama/Llama-3.1-8B-Instruct",
# "meta-llama/Llama-3.1-70B-Instruct",
],
)
def test_tpu_model_loader(model):
# Skip the 70B test if there are less than 8 chips
# TODO: Query using torch xla API, the query API is not working
# with SPMD now. However, This test is running under SPMD mode.
if "70B" in model and xr.global_runtime_device_count() < 8:
pytest.skip(
"Skipping 70B model if the TPU VM has less than 8 chips to \
avoid OOM."
)
vllm_config = _setup_environment(model)
loader = TPUModelLoader(load_config=vllm_config.load_config)
mesh = _get_spmd_mesh()
model = loader.load_model(vllm_config, vllm_config.model_config, mesh)
del model
gc.collect()

View File

@ -1,149 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import pytest
import torch
import torch_xla
from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.tpu.sampler import apply_top_k_top_p as apply_top_k_top_p_tpu
if not current_platform.is_tpu():
pytest.skip("This test needs a TPU.", allow_module_level=True)
import torch_xla.core.xla_model as xm
BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024
TOLERANCE = 1e-6
def test_topk_equivalence_to_native_impl():
with torch.device(xm.xla_device()):
xm.set_rng_state(seed=33)
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
# Random top-k values between 1 and 10.
k = torch.randint(1, 10, (BATCH_SIZE,))
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), VOCAB_SIZE)
result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
assert torch.allclose(result_native, result_tpu)
def test_topp_result_sums_past_p():
with torch.device(xm.xla_device()):
xm.set_rng_state(seed=33)
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
probs = logits.softmax(dim=-1)
# Random top-p values between 0 and 1.
p = torch.rand((BATCH_SIZE,))
# Set p=1 for ~50% of requests in the batch (top-p disabled).
p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), 1)
no_op_k = torch.tensor([VOCAB_SIZE])
logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(), k=no_op_k, p=p)
# Verify that the masked logit's probability sums to at least p.
probs.masked_fill_(logits_masked.isinf(), 0)
masked_prob_sum = probs.sum(dim=-1)
torch_xla.sync()
# Perform assertion on CPU.
assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
def test_topp_basic():
with torch.device(xm.xla_device()):
logits = torch.tensor(
[
[math.log(0.2), math.log(0.3), math.log(0.5)],
[math.log(0.5), math.log(0.1), math.log(0.4)],
]
)
result = apply_top_k_top_p_tpu(
logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([0.79, 0.79])
)
torch_xla.sync()
# Expect the smallest elements to be dropped.
expected_result = logits.clone().cpu()
expected_result[0, 0] = float("-inf")
expected_result[1, 1] = float("-inf")
assert torch.allclose(expected_result, result.cpu())
def test_topp_select_all():
with torch.device(xm.xla_device()):
logits = torch.tensor(
[
[math.log(0.2), math.log(0.3), math.log(0.5)],
[math.log(0.5), math.log(0.1), math.log(0.4)],
]
)
result = apply_top_k_top_p_tpu(
logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([1.0, 1.0])
)
torch_xla.sync()
assert torch.allclose(logits.cpu(), result.cpu())
def test_topp_with_ties():
with torch.device(xm.xla_device()):
# Input has multiple math.log(0.3).
logits = torch.tensor(
[[math.log(0.3), math.log(0.3), math.log(0.3), math.log(0.1)]]
)
result = apply_top_k_top_p_tpu(
logits=logits.clone(), k=torch.tensor([4]), p=torch.tensor([0.2])
)
torch_xla.sync()
# All tie values are included in the top-p set. Tie breaking is left
# to be done during final sampling (all tie tokens have equal
# probability of being chosen).
expected_result = logits.clone().cpu()
expected_result[0, 3] = float("-inf")
assert torch.allclose(expected_result, result.cpu())
def test_both_topk_topp():
with torch.device(xm.xla_device()):
logits = torch.tensor(
[
[math.log(0.2), math.log(0.3), math.log(0.5)],
[math.log(0.5), math.log(0.1), math.log(0.4)],
]
)
# Set k=1 for the first batch.
result = apply_top_k_top_p_tpu(
logits=logits.clone(), k=torch.tensor([1, 3]), p=torch.tensor([0.79, 0.79])
)
torch_xla.sync()
# Since for the first batch k=1, expect only the largest element gets
# selected.
expected_result = logits.clone().cpu()
expected_result[0, 0] = float("-inf")
expected_result[0, 1] = float("-inf")
expected_result[1, 1] = float("-inf")
assert torch.allclose(expected_result, result.cpu())

View File

@ -1,78 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests whether TPU Int8 computation is enabled correctly.
Run `pytest tests/quantization/test_tpu_int8.py`.
"""
import pytest
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.tpu_int8 import TPUInt8LinearMethod
from vllm.platforms import current_platform
from ...models.registry import HF_EXAMPLE_MODELS
MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
@pytest.mark.skipif(
not current_platform.is_tpu(), reason="TPU Int8 is only enabled for TPUs."
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize(
"hf_overrides",
[
# w8a8 dynamic activation
{
"quantization_config": {
"quant_method": "tpu_int8",
"activation_scheme": "dynamic",
}
}
],
)
def test_model_tpu_int8(
vllm_runner,
model: str,
dtype: str,
max_tokens: int,
hf_overrides: dict,
monkeypatch,
) -> None:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_transformers_version(on_fail="skip")
activation_scheme = hf_overrides.get("quantization_config", {}).get(
"activation_scheme"
)
quantize_activation = activation_scheme == "dynamic"
# Allows using apply_model
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
# Prevent error from re-initializing cache
monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")
prompts = [
"A robot may not injure a human being",
]
answers = [
"or kill a human being",
]
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
def check_model(model):
for name, module in model.named_modules():
if not isinstance(module, LinearBase):
continue
quant_method = module.quant_method
assert isinstance(quant_method, TPUInt8LinearMethod)
assert quant_method.quantize_activation == quantize_activation
vllm.apply_model(check_model)
outputs = vllm.generate_greedy(prompts, max_tokens)
for (_, output), answer in zip(outputs, answers):
assert answer in output

View File

@ -1,93 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
import numpy as np
import pytest
import torch
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
init_distributed_environment,
)
from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.linear import QKVParallelLinear
@pytest.fixture(autouse=True)
def setup_environment():
# This is a fake config used for init dist env.
# QKVParallelLinear needs dist env to be initialized.
engine_args = EngineArgs(
model="Qwen/Qwen2-1.5B-Instruct",
max_model_len=64,
max_num_batched_tokens=64,
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
with set_current_vllm_config(vllm_config):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
1,
0,
local_rank=0,
distributed_init_method=f"file://{temp_file}",
backend="gloo",
)
ensure_model_parallel_initialized(1, 1)
yield
MESH = None
def _get_spmd_mesh():
global MESH
if MESH is None:
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
return MESH
@pytest.mark.parametrize("bias", [False, True])
# `xr.use_spmd()` will set a global state, and this state is not reversible.
# Therefore, non-SPMD tests should be run before SPMD tests.
@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()])
@pytest.mark.parametrize("device", ["cpu", "xla"])
@torch.no_grad()
def test_xla_qkv_linear(bias, mesh, device):
torch.manual_seed(123)
qkv_linear = QKVParallelLinear(
hidden_size=4096,
head_size=128,
total_num_heads=32,
total_num_kv_heads=8,
bias=bias,
params_dtype=torch.bfloat16,
return_bias=False,
)
qkv_linear.weight.data = torch.rand_like(qkv_linear.weight.data) / 10
if bias:
qkv_linear.bias.data = torch.rand_like(qkv_linear.bias.data)
xla_qkv_linear = XlaQKVParallelLinear(qkv_linear, mesh=mesh)
qkv_linear = qkv_linear.to(device)
xla_qkv_linear = xla_qkv_linear.to(device)
input_tensor = torch.rand(10, 4096, dtype=torch.bfloat16) / 10
input_tensor = input_tensor.to(device)
output = qkv_linear(input_tensor)
xla_output = xla_qkv_linear(input_tensor)
assert torch.allclose(output.cpu(), xla_output.cpu())

View File

@ -1,587 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.attention.layer import Attention
from vllm.config import (
CacheConfig,
ModelConfig,
SchedulerConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils.mem_constants import GiB_bytes
from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
from vllm.v1.worker.tpu_model_runner import (
TPUModelRunner,
_get_padded_num_reqs_with_upper_limit,
_get_padded_token_len,
_get_req_paddings,
_get_token_paddings,
)
def get_vllm_config():
model_config = ModelConfig(
model="facebook/opt-125m",
dtype="bfloat16", # TPUs typically use bfloat16
seed=42,
)
scheduler_config = SchedulerConfig(
max_num_seqs=10,
max_num_batched_tokens=512,
max_model_len=512,
is_encoder_decoder=model_config.is_encoder_decoder,
)
cache_config = CacheConfig(
block_size=16,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
)
return vllm_config
def get_model_runner(vllm_config):
device = "xla:0" # Mocking TPU device
return TPUModelRunner(vllm_config, device)
@pytest.fixture
def model_runner():
# Patchers have already been started at module level.
vllm_config = get_vllm_config()
return get_model_runner(vllm_config)
def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
new_reqs = []
num_scheduled_tokens = {}
total_num_scheduled_tokens = 0
for req_id in req_ids:
new_reqs.append(
NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2, 3],
mm_features=[],
sampling_params=SamplingParams(),
pooling_params=PoolingParams(),
block_ids=([0],), # block_ids should be tuple[list[int]]
num_computed_tokens=0,
lora_request=None,
)
)
num_scheduled_tokens[req_id] = 3
total_num_scheduled_tokens += num_scheduled_tokens[req_id]
return SchedulerOutput(
scheduled_new_reqs=new_reqs,
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
def _is_req_scheduled(model_runner, req_id: str) -> bool:
return req_id in model_runner.input_batch.req_id_to_index
def _is_req_added(model_runner, req_id: str) -> bool:
return req_id in model_runner.requests
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
"""Check if the request state block IDs match the block table.
This function handles both legacy BlockTable and new MultiGroupBlockTable
structures for backward compatibility.
"""
req_index = model_runner.input_batch.req_id_to_index[req_id]
multi_group_block_table = model_runner.input_batch.block_table
req_state = model_runner.requests[req_id]
# Access the first block table from MultiGroupBlockTable
# This is safe since we currently only use single KV cache groups
block_table = multi_group_block_table[0]
# req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable
# Extract the first group's block IDs
if isinstance(req_state.block_ids[0], list):
# New format: tuple[list[int], ...] - extract first group
req_block_ids = req_state.block_ids[0]
else:
# Legacy format: list[int] - use directly
req_block_ids = req_state.block_ids
if block_table.num_blocks_per_row[req_index] != len(req_block_ids):
return False
num_blocks = block_table.num_blocks_per_row[req_index]
block_table_values = block_table.block_table.np[req_index, :num_blocks]
return (block_table_values == req_block_ids).all()
def test_update_states_new_request(model_runner):
req_id = "req_0"
# new req
scheduler_output = _schedule_new_request(req_id)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_request_finished(model_runner):
req_id = "req_0"
# new req
scheduler_output = _schedule_new_request(req_id)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
# finish req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids={req_id},
free_encoder_mm_hashes=[],
)
model_runner._update_states(scheduler_output)
assert not _is_req_added(model_runner, req_id)
assert not _is_req_scheduled(model_runner, req_id)
def test_update_states_request_resumed(model_runner):
req_id = "req_0"
# new req
scheduler_output = _schedule_new_request(req_id)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
# unschedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert not _is_req_scheduled(model_runner, req_id)
# resume req
cached_req_data = CachedRequestData(
req_ids=[req_id],
resumed_req_ids={req_id},
new_token_ids=[[]],
all_token_ids={req_id: scheduler_output.scheduled_new_reqs[0].prompt_token_ids},
new_block_ids=[([],)],
num_computed_tokens=[0],
num_output_tokens=[0],
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_no_changes(model_runner):
req_id = "req_0"
# new req
scheduler_output = _schedule_new_request(req_id)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
# schedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_request_unscheduled(model_runner):
req_ids = ("req_0", "req_1")
# new reqs
scheduler_output = _schedule_new_request(*req_ids)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_ids[0])
assert _is_req_scheduled(model_runner, req_ids[0])
assert _is_req_added(model_runner, req_ids[1])
assert _is_req_scheduled(model_runner, req_ids[1])
# unschedule req_1
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_ids[0]: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_ids[0])
assert _is_req_scheduled(model_runner, req_ids[0])
assert _is_req_added(model_runner, req_ids[1])
assert not _is_req_scheduled(model_runner, req_ids[1])
def test_get_paddings():
# Bucketed padding
min_token_size, max_token_size, padding_gap = 16, 512, 64
expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
# Bucketed padding with max_token_size not a power of two.
max_token_size = 317
expected_paddings = [16, 32, 64, 128, 192, 256, 320]
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
assert actual_paddings == expected_paddings
# Exponential padding.
max_token_size, padding_gap = 1024, 0
expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
assert actual_paddings == expected_paddings
# Exponential padding with max_token_size not a power of two.
max_token_size = 317
expected_paddings = [16, 32, 64, 128, 256, 512]
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
assert actual_paddings == expected_paddings
def test_get_padded_token_len():
min_token_size, max_token_size, padding_gap = 16, 512, 64
paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
assert _get_padded_token_len(paddings, 1) == 16
assert _get_padded_token_len(paddings, 16) == 16
assert _get_padded_token_len(paddings, 20) == 32
assert _get_padded_token_len(paddings, 300) == 320
assert _get_padded_token_len(paddings, 512) == 512
def test_get_padded_num_reqs_with_upper_limit():
assert _get_padded_num_reqs_with_upper_limit(3, 32) == 8
assert _get_padded_num_reqs_with_upper_limit(9, 32) == 16
assert _get_padded_num_reqs_with_upper_limit(19, 32) == 32
assert _get_padded_num_reqs_with_upper_limit(17, 28) == 28
def test_get_req_paddings():
assert _get_req_paddings(1, 32) == [8, 16, 32]
assert _get_req_paddings(8, 32) == [8, 16, 32]
assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} must come before the current layer"
vllm_config = model_runner.vllm_config
with (
pytest.raises(ValueError, match=error_msg),
set_current_vllm_config(vllm_config),
):
fwd_context = {
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
layer_0: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_0,
kv_sharing_target_layer_name=layer_1,
),
layer_1: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_1,
),
}
# suppress var not used error
assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
invalid_layer = "model.layers.0.cross_attn.attn"
error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
vllm_config = model_runner.vllm_config
with (
pytest.raises(ValueError, match=error_msg),
set_current_vllm_config(vllm_config),
):
fwd_context = {
layer_0: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_0,
),
layer_1: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_1,
# invalid layer: cross_attn.atn doesn't exist!
kv_sharing_target_layer_name=invalid_layer,
),
}
# suppress var not used error
assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} cannot be the same as the current layer"
vllm_config = model_runner.vllm_config
with (
pytest.raises(ValueError, match=error_msg),
set_current_vllm_config(vllm_config),
):
fwd_context = {
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
layer_0: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_0,
),
layer_1: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_1,
kv_sharing_target_layer_name=layer_1,
),
}
# suppress var not used error
assert fwd_context is not None
def test_init_kv_cache_without_kv_sharing():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config()
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_0,
),
layer_1: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_1,
),
}
# suppress var not used error
assert fwd_context is not None
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 1_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
model_runner = get_model_runner(vllm_config)
kv_cache_spec = model_runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 2
assert len(model_runner.shared_kv_cache_layers) == 0
available_memory = 20 * GiB_bytes
# page size for each layer KV can be calculated as
# 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
# * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers)
kv_cache_config = get_kv_cache_configs(
vllm_config, [kv_cache_spec], [available_memory]
)[0]
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.kv_cache_tensors) == 2
assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2
max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
# max_context_len = available_memory / (page_size / block_size) / num_caches
# max_context_len = 5GB / (512KB / 128) / 2 = 655360
assert max_context_len == 655360
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 2 block worth of memory (2 * 512kb)
kv_cache_config.num_blocks = 1
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
kv_cache_tensor.size = kv_cache_spec[
kv_cache_tensor.shared_by[0]
].page_size_bytes
model_runner.initialize_kv_cache(kv_cache_config)
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
# check layer 1 kv cache does NOT share memory with layer 0
assert id(layer_1_kv) != id(layer_0_kv)
# check layer 1 added to kv cache group's layer names
assert len(kv_cache_config.kv_cache_groups) == 1
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_init_kv_cache_with_kv_sharing_valid():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config()
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_0,
),
layer_1: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_1,
kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
),
}
# suppress var not used error
assert fwd_context is not None
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
model_runner = get_model_runner(vllm_config)
kv_cache_spec = model_runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 1
assert layer_0 in kv_cache_spec
assert model_runner.shared_kv_cache_layers[layer_1] == layer_0
available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 512KB
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks = 2 * 20480 # 20GB / 512KB
kv_cache_config = get_kv_cache_configs(
vllm_config, [kv_cache_spec], [available_memory]
)[0]
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.kv_cache_tensors) == 1
# Each layer now has twice the available memory for KV cache
# compared to no KV sharing
assert kv_cache_config.kv_cache_tensors[0].size == available_memory
max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
assert max_context_len == (2 * 655360)
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 1 block worth of memory (512kb)
kv_cache_config.num_blocks = 1
kv_cache_config.kv_cache_tensors[0].size = kv_cache_spec[layer_0].page_size_bytes
model_runner.initialize_kv_cache(kv_cache_config)
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
# check layer 1 kv cache shares memory with layer 0
assert id(layer_1_kv) == id(layer_0_kv)
# check layer 1 added to kv cache group's layer names
assert len(kv_cache_config.kv_cache_groups) == 1
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_most_model_len(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048")
vllm_config = get_vllm_config()
vllm_config.model_config.max_model_len = 32000
vllm_config.scheduler_config.max_num_seqs = 1200
model_runner = get_model_runner(vllm_config)
# verify model runner will adjust num_reqs to avoid SMEM OOM.
assert model_runner.num_reqs_most_model_len == 1200
# num_page_per_req = 32k // 128
# num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
assert model_runner.num_reqs_max_model_len == 524

View File

@ -66,7 +66,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
)
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"

View File

@ -227,28 +227,3 @@ class MMEncoderAttention(CustomOp):
"XPU only supports FLASH_ATTN for vision attention."
)
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
def forward_tpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.attn_backend == AttentionBackendEnum.PALLAS, (
f"MMEncoderAttention on TPU only supports PALLAS backend, "
f"but got {self.attn_backend}."
)
if cu_seqlens is None:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
return out
logger.warning_once(
"PALLAS backend with cu_seqlens is not supported for ViT yet. ",
"Falling back to SDPA implementation.",
)
return self._forward_sdpa(query, key, value, cu_seqlens)

View File

@ -1,99 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import torch
from torch.distributed import ProcessGroup
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_INFERENCE
from .base_device_communicator import DeviceCommunicatorBase
USE_RAY = parallel_config = (
get_current_vllm_config().parallel_config.distributed_executor_backend == "ray"
)
logger = init_logger(__name__)
if not USE_TPU_INFERENCE:
logger.info("tpu_inference not found, using vLLM's TpuCommunicator")
if current_platform.is_tpu():
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups,
)
if USE_RAY:
from vllm.v1.executor import ray_utils
class TpuCommunicator(DeviceCommunicatorBase):
def __init__(
self,
cpu_group: ProcessGroup,
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)
# NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
# must be used together. Therefore, the local rank and world size can
# be simply calculated as follows.
global_rank = self.global_rank
global_world_size = self.global_world_size
if USE_RAY:
logger.info("TpuCommunicator initialized with RAY")
# Calculate how many TPU nodes are in the current deployment. This
# is the Ray placement group if it is deployed with Ray. Default
# to the number of TPU nodes in the Ray cluster. The number of TPU
# nodes is computed by the total number of TPUs divided by the
# number of TPU accelerators per node, to account for clusters
# with both CPUs and TPUs.
num_nodes = ray_utils.get_num_tpu_nodes()
num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
if num_nodes_in_pg > 0:
num_nodes = num_nodes_in_pg
local_world_size = global_world_size // num_nodes
local_rank = global_rank % local_world_size
else:
logger.info("TpuCommunicator initialized with MP")
# Sanity: Verify we run on a single host
num_hosts = torch_xla.tpu.num_tpu_workers()
assert num_hosts == 1
# Get the current number of TPUs (we have locally)
local_world_size = torch_xla.tpu.num_available_chips()
# Get current rank
local_rank = global_rank % local_world_size
# Ensure environment variables are set for multihost deployments.
# On GKE, this is needed for libtpu and TPU driver to know which TPU
# chip is actually visible. Otherwise the TPU driver will fail to
# initialize because the number of devices would be different from
# the number of visible worker addresses.
os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
pjrt.initialize_multiprocess(local_rank, local_world_size)
xr._init_world_size_ordinal()
self.groups = create_optimized_replica_groups()
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
# TODO: Remove the groups specification after XLA compiler can support
# auto-reordering the ring order for all-reduce.
return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "TPUs only support dim=-1 for all-gather."
return xm.all_gather(input_, dim=dim)

View File

@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Literal
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger
@ -251,9 +250,6 @@ class TpKVTopology:
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
@ -261,7 +257,7 @@ class TpKVTopology:
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first)
return not (self.is_mla or self.is_kv_layout_blocks_first)
@property
def tp_size(self) -> int:

View File

@ -499,7 +499,6 @@ class MooncakeConnectorWorker:
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend,
)
self._use_pallas = self.kv_topo._use_pallas
self.zmq_ctx = zmq.Context()
self.async_zmq_ctx = zmq.asyncio.Context()

View File

@ -983,7 +983,6 @@ class NixlConnectorWorker:
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend,
)
self._use_pallas = self.kv_topo._use_pallas
self._physical_blocks_per_logical_kv_block = 1
def _nixl_handshake(
@ -1641,9 +1640,6 @@ class NixlConnectorWorker:
# Num kv_heads > tp_size and P TP > D TP case, not supported
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
)
kv_cache_layout = (
self.kv_cache_layout
if not self.use_host_buffer
@ -1814,9 +1810,7 @@ class NixlConnectorWorker:
if len(self.device_kv_caches) == 0:
return
split_k_and_v = not (
self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first
)
split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)
sample_cache = list(self.device_kv_caches.values())[0][0]
for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
assert block_size_ratio > 1, "Only nP < nD supported currently."

View File

@ -1,188 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import OrderedDict
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.distributed.spmd as xs
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
logger = init_logger(__name__)
class XlaQKVParallelLinear(nn.Module):
def __init__(self, qkv_linear: nn.Module, mesh: Optional["xs.Mesh"] = None):
super().__init__()
assert isinstance(qkv_linear, QKVParallelLinear)
self.skip_bias_add = qkv_linear.skip_bias_add
self.return_bias = qkv_linear.return_bias
assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD."
self.q_weight: Parameter
self.k_weight: Parameter
self.v_weight: Parameter
self.q_bias: Parameter | None
self.k_bias: Parameter | None
self.v_bias: Parameter | None
self._load_weights_from_qkv_linear(qkv_linear)
if mesh is not None:
self._shard_weight(mesh)
def _shard_weight(self, mesh: "xs.Mesh"):
self.q_weight = Parameter(self.q_weight.to("xla"), requires_grad=False)
self.k_weight = Parameter(self.k_weight.to("xla"), requires_grad=False)
self.v_weight = Parameter(self.v_weight.to("xla"), requires_grad=False)
xs.mark_sharding(self.q_weight, mesh, ("x", None))
xs.mark_sharding(self.k_weight, mesh, ("x", None))
xs.mark_sharding(self.v_weight, mesh, ("x", None))
if self.q_bias is not None:
assert self.k_bias is not None and self.v_bias is not None, (
"QKVParallelLinear should have q, k, and v biases together."
)
self.q_bias = Parameter(self.q_bias.to("xla"), requires_grad=False)
xs.mark_sharding(self.q_bias, mesh, ("x",))
self.k_bias = Parameter(self.k_bias.to("xla"), requires_grad=False)
xs.mark_sharding(self.k_bias, mesh, ("x",))
self.v_bias = Parameter(self.v_bias.to("xla"), requires_grad=False)
xs.mark_sharding(self.v_bias, mesh, ("x",))
def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module):
q_proj_size, k_proj_size, _ = qkv_linear.output_sizes
# The weight of qkv linear is a concatenation of q, k, and v weights
# along the output dimension.
qkv_weight = qkv_linear.weight.data.cpu()
q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False)
k_weight = Parameter(
qkv_weight[q_proj_size : q_proj_size + k_proj_size], requires_grad=False
)
v_weight = Parameter(
qkv_weight[q_proj_size + k_proj_size :], requires_grad=False
)
self.register_parameter("q_weight", q_weight)
self.register_parameter("k_weight", k_weight)
self.register_parameter("v_weight", v_weight)
if qkv_linear.bias is not None:
q_bias = Parameter(qkv_linear.bias[:q_proj_size], requires_grad=False)
k_bias = Parameter(
qkv_linear.bias[q_proj_size : q_proj_size + k_proj_size],
requires_grad=False,
)
v_bias = Parameter(
qkv_linear.bias[q_proj_size + k_proj_size :], requires_grad=False
)
self.register_parameter("q_bias", q_bias)
self.register_parameter("k_bias", k_bias)
self.register_parameter("v_bias", v_bias)
else:
self.register_parameter("q_bias", None)
self.register_parameter("k_bias", None)
self.register_parameter("v_bias", None)
def forward(self, input):
# Same forward functionality as QKVParallelLinear, but doing qkv porj
# separately.
q_bias = self.q_bias if not self.skip_bias_add else None
k_bias = self.k_bias if not self.skip_bias_add else None
v_bias = self.v_bias if not self.skip_bias_add else None
q_proj = F.linear(input, self.q_weight, q_bias)
k_proj = F.linear(input, self.k_weight, k_bias)
v_proj = F.linear(input, self.v_weight, v_bias)
# The q/k/v projections will be split outside of the QKVParallelLinear.
# Because we are replacing XlaQKVParallelLinear with the
# QKVParallelLinear, we need to concatenate q, k, and v projections to
# match the output shape of the QKVParallelLinear implementation even if
# it seems to be redundant.
# The concat and the following split will be noop, and should be
# optimized away by the compiler.
qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1)
output_bias = (
torch.cat([q_bias, k_bias, v_bias], dim=-1) if self.skip_bias_add else None
)
if not self.return_bias:
return qkv_proj
return qkv_proj, output_bias
def partition_column_parallel_linear(
layer: torch.nn.Module, mesh: xs.Mesh
) -> torch.nn.Module:
assert isinstance(layer, ColumnParallelLinear)
xs.mark_sharding(layer.weight, mesh, ("x", None))
logger.debug("Applied column-parallel sharding to %s", layer)
return layer
def partition_row_parallel_linear(
layer: torch.nn.Module, mesh: xs.Mesh
) -> torch.nn.Module:
assert isinstance(layer, RowParallelLinear)
xs.mark_sharding(layer.weight, mesh, (None, "x"))
logger.debug("Applied row-parallel sharding to %s", layer)
return layer
def partition_qkv_parallel_linear(
layer: torch.nn.Module, mesh: xs.Mesh
) -> torch.nn.Module:
assert isinstance(layer, QKVParallelLinear)
xla_layer = XlaQKVParallelLinear(layer, mesh)
logger.debug("Applied qkv parallel sharding to %s", layer)
return xla_layer
MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict(
[
("QKVParallelLinear", partition_qkv_parallel_linear),
("ColumnParallelLinear", partition_column_parallel_linear),
("RowParallelLinear", partition_row_parallel_linear),
]
)
def get_fqn(module):
# Get the fully qualified name of the module
return module.__class__.__qualname__
def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None:
"""
Recursively check a PyTorch model and apply appropriate sharding based on
the MODULE_TYPE_TO_WRAPPING_FUNC mapping.
Args:
model: torch.nn.Module to process
mesh: An XLA SPMD mesh object used for sharding
"""
def _process_module(module, name=None, parent=None):
for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items():
if get_fqn(module) == module_type:
wrapped_module = wrapping_func(module, mesh)
assert parent is not None and name is not None, (
"Top Level module is not expected to be wrapped."
)
if wrapped_module is not module:
# Wrapped module and module are different py object.
# The original module should be replaced by the
# wrapped_module.
logger.debug("replace %s with %s", module, wrapped_module)
setattr(parent, name, wrapped_module)
module = wrapped_module
break
for child_name, child_module in list(module.named_children()):
_process_module(child_module, child_name, module)
_process_module(model)

View File

@ -1,6 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]

View File

@ -1,141 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F
import torch_xla.core.xla_builder as xb
from torch.library import impl
from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard
@jax.jit
def bgmv_jax(inputs, loras, idxs):
return jnp.einsum(
"td,tX,Xld->tl",
inputs,
jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
loras,
)
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")
@impl(XLA_LIB, "bgmv", "XLA")
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
jax_import_guard()
return xb.call_jax(bgmv_jax, (inputs, loras, idxs))
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
T, _ = inputs.shape
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
_, L, _ = loras.shape
return torch.empty((T, L), device=inputs.device)
def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): output tensor of shape
[num_tokens, hidden_size * num_slices].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
add_inputs (bool): Whether or not to add the input tensor to the output
tensor.
"""
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
limit = output_tensor.shape[0]
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
limit = 1
if output_tensor.shape[1] > outputs.shape[1]:
outputs = F.pad(outputs, (0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
if add_inputs:
return output_tensor + outputs[:limit, : output_tensor.shape[1]]
else:
return outputs[:limit, : output_tensor.shape[1]]
def bgmv_shrink(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
scaling (float, optional): Scalar multiplier applied to the output.
"""
return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): output tensor of shape
[num_tokens, hidden_size * num_slices].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
add_inputs (bool): Whether or not to add the input tensor to the output
tensor.
"""
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
outputs = F.pad(
outputs,
(
slice_offset,
output_tensor.shape[1] - (slice_offset + slice_size),
0,
0,
),
)
if add_inputs:
return output_tensor + outputs
else:
return outputs

View File

@ -1,358 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
import torch_xla
from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from vllm.lora.punica_wrapper.utils import convert_mapping
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
from .punica_base import PunicaWrapperBase
class PunicaWrapperTPU(PunicaWrapperBase):
"""
PunicaWrapperTPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the pytorch punica ops.
"""
def __init__(
self,
max_num_batched_tokens: int,
max_batches: int,
device: torch.device | str,
**kwargs,
):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
# PunicaWrapperBase defines some tensors with dtype=torch.int64, which
# isn't supported by the TPU. So convert those tensors to int32.
# Not all of them are used by the TPU so only convert the useful ones.
self._token_lora_indices = self._token_lora_indices.to(dtype=torch.int32)
self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
self._sampler_indices_padded = self._sampler_indices_padded.to(
dtype=torch.int32
)
torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, True)
torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
@property
def embeddings_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA.
"""
return self._embeddings_indices[:]
@property
def sampler_indices_padded(self) -> torch.Tensor:
"""
This property provides access to padded sampler indices.
"""
return self._sampler_indices_padded[:]
def shrink(
self,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale)
def expand(
self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, add_inputs: bool
):
return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), add_inputs)
def expand_slice(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
) -> torch.Tensor:
return bgmv_expand_slice(
x,
w_t_all,
y,
self._get_token_lora_indices(x),
y_offset,
y_slice_size,
add_inputs,
)
def add_shrink(
self,
y: tuple[torch.Tensor, ...] | torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
) -> torch.Tensor | None:
"""
Performs GEMM for multiple slices of lora_a.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
torch.ops.xla.dynamo_set_buffer_donor_(y, True)
x = x.view(-1, x.shape[-1])
for slice_idx in range(len(lora_a_stacked)):
lora_s = lora_a_stacked[slice_idx]
y_s = self.shrink(x, lora_s, scale)
y[slice_idx, :, :] = y_s # type: ignore[index]
return y
def add_expand(
self,
y: torch.Tensor,
x: tuple[torch.Tensor, ...] | torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> torch.Tensor:
"""
Performs GEMM for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = 0
for slice_idx in range(len(lora_b_stacked)):
y = self.expand_slice(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_left += output_slices[slice_idx]
return y.view_as(y_org)
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> torch.Tensor:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
# Embedding layer only needs the expand op
return self.expand(y, x, lora_b_stacked, add_inputs)
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: tuple[torch.Tensor, ...] | None = None,
**kwargs,
) -> torch.Tensor:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)
Args:
y (torch.Tensor): Output tensor. Will not be changed in-place.
x (torch.Tensor): Input tensor (T, E)
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if buffer is None:
r = lora_b_stacked[0].size(-1)
T = x.size(0)
buffer = torch.zeros(
(len(output_slices), T, r),
dtype=x.dtype,
device=x.device,
)
buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
return self.add_expand(
y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
)
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale)
y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
return y.view_as(y_org)
# This performs the same tensor ops as the base method, except it does them
# on the CPU then transfers the results to the TPU
def _update_base_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
):
# Make sure we don't accidentally collect outside operations
torch_xla.sync()
# Pad the prompt mapping to avoid running into recompiles on the TPU
# TODO: Should this happen inside mapping internally? If so how can we
# avoid having backend specific LoRAMapping classes?
mapping.prompt_mapping = self._pad_prompt_mapping(mapping.prompt_mapping)
(
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
indices_len,
) = convert_mapping(
mapping,
lora_index_to_id,
max_loras,
vocab_size,
0, # extra_vocab_size
"cpu",
)
self._token_lora_indices = self._pad_to_shape(
base_indices, self._token_lora_indices.shape, dims=1
).to(self.device)
self._sampler_indices = self._pad_to_shape(
sampler_indices, self._sampler_indices.shape, dims=1
).to(self.device)
self._sampler_indices_padded = self._pad_to_shape(
sampler_indices_padded, self._sampler_indices_padded.shape, dims=1
).to(self.device)
self._embeddings_indices = self._pad_to_shape(
embeddings_indices, self._embeddings_indices.shape, dims=2
).to(self.device)
self.indices_len[:] = indices_len
def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None:
self.batch_size = 1
self._lora_indices_per_batch[: self.batch_size] = token_lora_tensor[
: self.batch_size
]
def _pad_prompt_mapping(self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
num_reqs = len(prompt_mapping)
# From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
# import
MIN_NUM_SEQS = 8
padded_num_reqs = max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
pad_len = padded_num_reqs - num_reqs
padding = [-1] * pad_len
return tuple(list(prompt_mapping) + padding)
def _pad_to_shape(self, src, target_shape, dims=1):
if dims == 1:
pad_len = target_shape[0] - src.shape[0]
return F.pad(src, (0, pad_len), value=0).to(torch.int32)
else:
pad_rows = target_shape[0] - src.shape[0]
pad_cols = target_shape[1] - src.shape[1]
return F.pad(src, (0, pad_cols, 0, pad_rows), value=0).to(torch.int32)

View File

@ -67,21 +67,15 @@ else:
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk,
)
if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)

View File

@ -1,83 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn.functional as F
def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
"""
Compute the histogram of an int32 tensor. The bin edges are defined by the
min and max values, with step = 1.
"""
assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
assert min <= max, "min must be less than or equal to max."
def searchsorted(
sorted_sequence: torch.Tensor, values_to_search: torch.Tensor
) -> torch.Tensor:
return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)
bin_edges = torch.linspace(min, max, max - min + 1, dtype=input.dtype).to(
input.device
)
return searchsorted(bin_edges, input).to(torch.int32)
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
global_num_experts: int,
expert_map: torch.Tensor = None,
renormalize: bool = False,
) -> torch.Tensor:
"""
Args:
hidden_states: [*, hidden_size]
w1: [num_experts, intermediate_size * 2, hidden_size]
w2: [num_experts, hidden_size, intermediate_size]
gating_output: [*, num_experts]
"""
assert expert_map is None, "expert_map is not supported for pallas MoE."
import torch_xla.experimental.custom_kernel # noqa: F401
orig_shape = hidden_states.shape
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
num_experts = w1.shape[0]
intermediate_size = w2.shape[-1]
device = hidden_states.device
dtype = hidden_states.dtype
assert (num_tokens * topk) % 16 == 0, (
"The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
f"16 but got {num_tokens * topk}"
)
hidden_states = hidden_states.view(num_tokens, hidden_size)
gating_output = gating_output.view(num_tokens, num_experts)
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(dtype)
topk_indices = topk_indices.flatten()
topk_argsort_indices = topk_indices.argsort()
topk_argsort_revert_indices = topk_argsort_indices.argsort()
token_indices = torch.arange(num_tokens, device=device).repeat_interleave(topk)
token_indices = token_indices[topk_argsort_indices]
group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
x = hidden_states[token_indices]
x = torch.ops.xla.gmm(x, w1, group_sizes, transpose_rhs=True)
x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
x = torch.ops.xla.gmm(x, w2, group_sizes, transpose_rhs=True)
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
x = x * topk_weights.unsqueeze(dim=-1)
x = x.sum(dim=-2)
x = x.reshape(orig_shape)
return x

View File

@ -38,10 +38,6 @@ if current_platform.is_cuda_alike():
else:
TritonExperts = None # type: ignore
if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore
logger = init_logger(__name__)
@ -403,53 +399,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=layer.custom_routing_function,
)
def forward_tpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not layer.use_grouped_topk
assert layer.num_expert_group is None
assert layer.topk_group is None
assert layer.custom_routing_function is None
assert layer.apply_router_weight_on_input is False
if layer.scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for TPU."
)
if layer.e_score_correction_bias is not None:
raise NotImplementedError(
"Expert score correction bias is not supported for TPU."
)
assert layer.activation == "silu", (
f"{layer.activation} is not supported for TPU."
)
assert layer.routed_scaling_factor == 1.0, (
f"routed_scaling_factor {layer.routed_scaling_factor} is "
"not supported for TPU."
)
if (
layer.enable_eplb is not False
or layer.expert_load_view is not None
or layer.logical_to_physical_map is not None
or layer.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for TPU.")
return fused_moe_pallas(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk=layer.top_k,
gating_output=router_logits,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
renormalize=layer.renormalize,
)
if current_platform.is_tpu():
forward_native = forward_tpu
elif current_platform.is_cpu():
if current_platform.is_cpu():
forward_native = forward_cpu
elif current_platform.is_xpu():
forward_native = forward_xpu

View File

@ -11,7 +11,6 @@ logger = init_logger(__name__)
QuantizationMethods = Literal[
"awq",
"deepspeedfp",
"tpu_int8",
"fp8",
"ptpc_fp8",
"fbgemm_fp8",
@ -130,12 +129,10 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .ptpc_fp8 import PTPCFp8Config
from .rtn import RTNConfig
from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig
method_to_config: dict[str, type[QuantizationConfig]] = {
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
"fp_quant": FPQuantConfig,

View File

@ -19,9 +19,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
XLAScaledMMLinearKernel,
)
from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available)
@ -29,7 +26,6 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
}

View File

@ -1,106 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
import torch
from functorch.experimental.control_flow import cond # noqa: F401
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_tpu():
return False, "Requires TPU."
return True, None
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_tpu():
return False, "ScaledMMXLA requires running on TPU."
if c.is_static_input_scheme:
return False, "ScaledMMXLA requires dynamic activation scales."
if not c.input_symmetric:
return False, "ScaledMMXLA requires symmetric activation scales."
if not c.is_channelwise:
return False, "ScaledMMXLA requires channelwise weight scales"
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT
# [out, in] (different than cutlass_scaled_mm)
weight = getattr(layer, self.w_q_name)
replace_parameter(
layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
)
# WEIGHT SCALE
# XLA kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
# [out_channel,] (different than cutlass_scaled_mm)
weight_scale = weight_scale.squeeze(-1)
replace_parameter(
layer,
self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# Only support symmetric dynamic activation quantization.
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
setattr(layer, self.azp_adj_name, None)
# Filter warning for cond usage in apply_weights. It is okay
# to specialize the graph since bias is not dynamic.
warnings.filterwarnings(
"ignore",
message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.", # noqa: E501
)
def no_add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
return x
def add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
return x + bias
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, _, _, _ = self._get_weight_params(layer)
# Required to register custom ops.
import torch_xla.experimental.custom_kernel # noqa: F401
out = torch.ops.xla.quantized_matmul_int8(
x,
w_q,
w_s,
quantize_activation=True,
)
# Explicitly capture control flow to make dynamo happy.
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])

View File

@ -1,139 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import (
QuantizationConfig,
QuantizationMethods,
)
from vllm.model_executor.parameter import ModelWeightParameter
ACTIVATION_SCHEMES = ["none", "dynamic"]
class Int8TpuConfig(QuantizationConfig):
"""Int8 Quantization Config class for TPU Backend."""
def __init__(
self,
activation_scheme: str = "none",
) -> None:
super().__init__()
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme
def get_name(self) -> QuantizationMethods:
return "tpu_int8"
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError("This function should not be called with TPU Backend")
@staticmethod
def get_config_filenames() -> list[str]:
return []
@classmethod
def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig":
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
return cls(activation_scheme=activation_scheme)
def get_quant_method(
self, layer: Module, prefix: str
) -> Optional["TPUInt8LinearMethod"]:
if isinstance(layer, LinearBase):
return TPUInt8LinearMethod(self)
return None
class TPUInt8LinearMethod(LinearMethodBase):
"""Int8 Linear method for TPU Quant."""
def __init__(self, quant_config: Int8TpuConfig):
self.quant_config = quant_config
self.quantize_activation = False
if self.quant_config.activation_scheme == "dynamic":
self.quantize_activation = True
def create_weights(
self,
layer: Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
weight_loader = extra_weight_attrs.get("weight_loader")
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
def _quantize_weight(
self, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
weight_dtype = weight.dtype
weight = weight.cpu().to(torch.float32)
n_bit = 8
eps = 1e-5
max_int = 2 ** (n_bit - 1) - 1
min_int = -(2 ** (n_bit - 1))
max_val = weight.abs().amax(dim=-1, keepdim=True)
max_val = max_val.clamp(min=eps)
qscale = max_val / max_int
qweight = torch.clamp(
torch.round(weight * (1.0 / qscale)), min_int, max_int
).to(torch.int8)
qscale = qscale.squeeze().to(weight_dtype)
return qweight, qscale
def process_weights_after_loading(self, layer: Module) -> None:
layer.weight = Parameter(layer.weight.data, requires_grad=False)
device = layer.weight.device
qweight, qscale = self._quantize_weight(layer.weight)
qweight = qweight.to(device)
qscale = qscale.to(device)
layer.weight = Parameter(qweight, requires_grad=False)
layer.scale = Parameter(qscale, requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
try:
import torch_xla.experimental.custom_kernel # noqa: F401
except ImportError as err:
raise ImportError(
"Please install torch_xla by following the instructions at "
"https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501
"to run vLLM on TPU."
) from err
weight = layer.weight
scale = layer.scale
out = torch.ops.xla.quantized_matmul_int8(
x, weight, scale, quantize_activation=self.quantize_activation
)
if bias is not None:
out = out + bias
return out

View File

@ -30,7 +30,6 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator,
safetensors_weights_iterator,
)
from vllm.platforms import current_platform
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
logger = init_logger(__name__)
@ -241,22 +240,6 @@ class DefaultModelLoader(BaseModelLoader):
self.load_config.pt_load_map_location,
)
if current_platform.is_tpu():
from vllm.platforms.tpu import USE_TPU_INFERENCE
if not USE_TPU_INFERENCE:
# In PyTorch XLA, we should call `torch_xla.sync`
# frequently so that not too many ops are accumulated
# in the XLA program.
import torch_xla
def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
torch_xla.sync(wait=False)
weights_iterator = _xla_weights_iterator(weights_iterator)
if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()
# Apply the prefix.

View File

@ -1,118 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
from vllm.logger import init_logger
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
)
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
class TPUModelLoader(DefaultModelLoader):
"""
A TPU model loader for model loading under SPMD mode.
"""
def load_model(
self,
vllm_config: VllmConfig,
model_config: ModelConfig,
mesh: xs.Mesh | None = None,
) -> nn.Module:
# Initialize model and load weights on CPU. Then, during SPMD partition,
# weights are sharded and transferred to TPUs.
self.counter_before_loading_weights = time.perf_counter()
model_config = vllm_config.model_config
assert model_config.quantization is None, "Quantization not supported"
target_device = torch.device("cpu")
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config)
load_format = vllm_config.load_config.load_format
if load_format != "dummy":
weights_to_load = {name for name, _ in model.named_parameters()}
all_weights = self.get_all_weights(model_config, model)
loaded_weights = model.load_weights(all_weights)
self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",
self.counter_after_loading_weights
- self.counter_before_loading_weights,
)
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}"
)
else:
logger.info("Use dummy weight during weight loading.")
process_weights_after_loading(model, model_config, target_device)
counter_before_partition = time.perf_counter()
model = model.eval()
model = model.to("xla")
shard_model(model, mesh)
counter_after_partition = time.perf_counter()
logger.info(
"Partition model took %.2f seconds",
counter_after_partition - counter_before_partition,
)
# Ensure the model is properly loaded.
self._check_model_is_loaded(mesh, model)
# Need to torch compile after model sharding are done. Because the
# compiler hints ('xs.mark_sharding') are torch ops.
if not model_config.is_multimodal_model:
model.model = torch.compile(model.model, backend="openxla")
else:
model.language_model.model = torch.compile(
model.language_model.model, backend="openxla"
)
return model
def _check_model_is_loaded(self, mesh: xs.Mesh | None, model: nn.Module) -> None:
"""
Ensure the model is properly loaded.
1. All model parameters and buffers are on XLA device.
2. Non-SPMD friendly layers are replaced as expected.
"""
device = xm.xla_device()
device_type = str(device.type)
# Check parameters
for name, param in model.named_parameters():
assert param.device.type == device_type, (
f"Parameter {name} is on {param.device.type} instead of {device_type}"
)
# Check buffers
for name, buffer in model.named_buffers():
assert buffer.device.type == device_type, (
f"Buffer {name} is on {buffer.device.type} instead of {device_type}"
)
for module in model.modules():
if (mesh is not None) and (get_fqn(module) == "QKVParallelLinear"):
raise AssertionError(
"QKVParallelLinear should be replaced by \
XlaQKVParallelLinear under SPMD mode."
)

View File

@ -1,287 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from typing import TYPE_CHECKING, Optional, cast
import torch
from tpu_info import device
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from typing import TypeAlias
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
ParamsType: TypeAlias = SamplingParams | PoolingParams
else:
BlockSize = None
VllmConfig = None
PoolingParams = None
ParamsType = None
logger = init_logger(__name__)
USE_TPU_INFERENCE = False
class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
device_name: str = "tpu"
device_type: str = "tpu"
dispatch_key: str = "XLA"
ray_device_key: str = "TPU"
dist_backend: str = "gloo"
device_control_env_var: str = "TPU_VISIBLE_CHIPS"
simple_compile_backend: str = "openxla"
supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"]
additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]
@classmethod
def import_kernels(cls) -> None:
# Do not import vllm._C
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
attn_selector_config: "AttentionSelectorConfig",
) -> str:
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != AttentionBackendEnum.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
logger.info("Using Pallas V1 backend.")
return AttentionBackendEnum.PALLAS.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.PALLAS,
]
@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention"
f"Supported backends are: {cls.get_supported_vit_attn_backends()}."
)
logger.info_once(f"Using backend {backend} for vit attention.")
return backend
logger.info_once(
f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention."
)
return AttentionBackendEnum.PALLAS
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
Set the device for the current platform.
"""
torch.tpu.set_device(device)
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
chip_type, _ = device.get_local_chips()
return f"TPU {chip_type.name}"
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
raise NotImplementedError
@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
@classmethod
def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
return torch.finfo(dtype).min, torch.finfo(dtype).max
@classmethod
def can_update_inplace(cls):
return False
@classmethod
def get_lora_vocab_padding_size(cls) -> int:
return 1
@classmethod
def inference_mode(cls):
return torch.no_grad()
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
from vllm.config import CompilationMode, CUDAGraphMode
cache_config = vllm_config.cache_config
# For v0, the default block size is 16.
if cache_config and cache_config.block_size is None:
cache_config.block_size = cast(BlockSize, 16)
compilation_config = vllm_config.compilation_config
# TPU only supports DYNAMO_TRACE_ONCE compilation mode
if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE:
logger.info(
"[TPU] Forcing DYNAMO_TRACE_ONCE compilation mode, and\
disabling cudagraph."
)
compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
if (
compilation_config.cudagraph_mode is None
or compilation_config.cudagraph_mode.max_cudagraph_mode()
!= CUDAGraphMode.NONE
):
logger.info(
"[TPU] CUDA graph is not supported on TPU, disabling cudagraphs."
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
if compilation_config.backend == "":
compilation_config.backend = "openxla"
assert vllm_config.speculative_config is None, (
"TPU does not support speculative decoding"
)
model_config = vllm_config.model_config
if model_config is not None and model_config.dtype in (
torch.float16,
torch.float32,
):
logger.warning(
"The TPU backend currently does not support %s. "
"Using bfloat16 instead.",
model_config.dtype,
)
model_config.dtype = torch.bfloat16
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
cache_config.block_size = PallasAttentionBackend.get_page_size(vllm_config) # type: ignore[assignment]
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"
assert not vllm_config.speculative_config, (
"Speculative decoding is not yet supported for TPU backend"
)
if (
scheduler_config.is_multimodal_model
and not scheduler_config.disable_chunked_mm_input
):
logger.warning(
"TPU does not support running Multimodal models"
" without setting `--disable_chunked_mm_input`. "
"Forcing --disable_chunked_mm_input."
)
scheduler_config.disable_chunked_mm_input = True
if model_config and model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled."
)
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on TPU.")
return False
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa
@classmethod
def validate_request(
cls,
prompt: PromptType,
params: ParamsType,
processed_inputs: ProcessorInputs,
) -> None:
"""Raises if this request is unsupported on this platform"""
from vllm.sampling_params import SamplingParams, SamplingType
if (
isinstance(params, SamplingParams)
and params.sampling_type == SamplingType.RANDOM_SEED
):
raise ValueError("Torch XLA does not support per-request seed.")
@classmethod
@torch.compile(backend="openxla")
def insert_blocks_to_device(
cls,
src_cache: torch.Tensor,
dst_cache: torch.Tensor,
src_block_indices: torch.Tensor,
dst_block_indices: torch.Tensor,
) -> None:
torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device)
@classmethod
@torch.compile(backend="openxla")
def swap_out_blocks_to_host(
cls,
src_cache: torch.Tensor,
dst_cache: torch.Tensor,
src_block_indices: torch.Tensor,
dst_block_indices: torch.Tensor,
) -> None:
"""tpu blocks to cpu blocks"""
torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
@classmethod
def use_sync_weight_loader(cls) -> bool:
return True
@classmethod
def check_max_model_len(cls, max_model_len: int) -> int:
"""
Check max_model_len for the current platform.
"""
logger.warning(
"--max-model-len is not specified, "
"it's currently using model's default length %d, "
"which might be too large."
"Please input with --max-model-len based on your "
"request input length and output length, to avoid "
"unnecessary degradation.",
max_model_len,
)
return max_model_len
try:
from tpu_inference.platforms import (
@ -291,5 +14,7 @@ try:
TpuPlatform = TpuInferencePlatform # type: ignore
USE_TPU_INFERENCE = True
except ImportError:
logger.info("tpu_inference not found, using vLLM's TpuPlatform")
logger.error(
"tpu_inference not found, please install tpu_inference to run vllm on TPU"
)
pass

View File

@ -186,20 +186,6 @@ class UsageMessage:
except Exception:
return False
def _report_torch_xla_usage(self) -> bool:
try:
import torch_xla
self.gpu_count = torch_xla.runtime.world_size()
self.gpu_type = torch_xla.tpu.get_tpu_type()
self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[
"bytes_limit"
]
self.cuda_runtime = "torch_xla"
return True
except Exception:
return False
def _report_usage_once(
self,
model_architecture: str,
@ -217,9 +203,7 @@ class UsageMessage:
if current_platform.is_cuda():
self.cuda_runtime = torch.version.cuda
if current_platform.is_tpu(): # noqa: SIM102
if (not self._report_tpu_inference_usage()) and (
not self._report_torch_xla_usage()
):
if not self._report_tpu_inference_usage():
logger.exception("Failed to collect TPU information")
self.provider = _detect_cloud_provider()
self.architecture = platform.machine()

View File

@ -1,436 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionLayer,
AttentionType,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv, next_power_of_2
logger = init_logger(__name__)
# TPU requires the head size to be a multiple of 128.
TPU_HEAD_SIZE_ALIGNMENT = 128
# Note: TPU can fp8 as storage dtype but doesn't support converting from uint8
# from to fp32 directly. That's why it has a dtype mapping different from GPU
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.float8_e4m3fn,
"fp8_e4m3": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2,
"int8": torch.int8,
"uint8": torch.uint8,
}
try:
import tpu_inference # noqa: F401
except ImportError:
# Lazy import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.experimental.custom_kernel # noqa: F401
from torch.library import impl
from torch_xla._internal.jax_workarounds import requires_jax
from torch_xla.experimental.custom_kernel import XLA_LIB
@requires_jax
def kv_cache_update_op_impl(
kv: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int,
):
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
new_kv_cache = xb.call_jax(
kv_cache_update,
(kv, slot_mapping, kv_cache, num_kv_update_slices),
{"page_size": page_size, "num_slices_per_block": num_slices_per_block},
)
return new_kv_cache
XLA_LIB.define(
"kv_cache_update_op(Tensor kv, Tensor slot_mapping,"
"Tensor kv_cache, Tensor num_kv_update_slices, int page_size,"
"int num_slices_per_block)"
"-> Tensor",
)
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(
kv: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int,
) -> torch.Tensor:
new_kv_cache = kv_cache_update_op_impl(
kv,
slot_mapping,
kv_cache,
num_kv_update_slices,
page_size,
num_slices_per_block,
)
return new_kv_cache
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(
kv: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int,
) -> torch.Tensor:
return kv_cache
class PallasAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "PALLAS"
@staticmethod
def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
return PallasAttentionBackendImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
padded_head_size = (
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
)
return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
raise RuntimeError("swap_blocks is not used for the TPU backend.")
# In recent TPU generations, up to v6e, the SMEM size is 1MB. The
# block_tables within the PallasMetadata constitute almost the entire SMEM
# requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
# we simply make sure that the size is smaller than half of SMEM capacity.
@staticmethod
def get_min_page_size(vllm_config: VllmConfig) -> int:
max_num_page_per_req = (
1024 * 1024 // 2 // vllm_config.scheduler_config.max_num_seqs // 4
)
min_page_size = cdiv(
vllm_config.model_config.max_model_len, max_num_page_per_req
)
min_page_size = 1 << (min_page_size - 1).bit_length()
return min_page_size
@staticmethod
def get_max_num_seqs(model_len: int, page_size: int) -> int:
num_page_per_req = cdiv(model_len, page_size)
return 1024 * 1024 // 2 // num_page_per_req // 4
# TPU has limited SREGs (scalar registers), if page_size is too small, we
# can spill SREGs easily which leads to bad performance. The strategy we
# apply here is trying to split max-model-len to 16 pages which make the
# spill less likely. Meanwhile we make sure the page size is in [16, 256].
@staticmethod
def get_page_size(vllm_config: VllmConfig) -> int:
# TODO: This is a temporary fix for vmem OOM.
# For long model length, we use 16 page-size to avoid too much
# VMEM spill. A more robust solution should be implemented to
# handle VREG spills.
if vllm_config.model_config.max_model_len > 8192:
return 16
page_size = next_power_of_2(vllm_config.model_config.max_model_len) // 16
if page_size <= 16:
return 16
if page_size >= 256:
return 256
return page_size
@dataclass
class PallasMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Used in the PallasAttentionBackendImpl
slot_mapping: torch.Tensor
block_tables: torch.Tensor
context_lens: torch.Tensor
query_start_loc: torch.Tensor
num_seqs: torch.Tensor
num_kv_update_slices: torch.Tensor
num_slices_per_kv_cache_update_block: int
class PallasAttentionBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: int | None = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl"
)
self.kv_cache_quantized_dtype = None
if kv_cache_dtype != "auto":
self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
kv_cache_dtype.lower().strip()
)
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: PallasMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache: shape =
[num_blocks, block_size, num_kv_heads * 2, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for PallasAttentionBackendImpl"
)
# For determine_available_memory case.
if kv_cache.numel() == 0:
if output is None:
output = torch.ones_like(query)
return output
num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
padded_head_size = (
cdiv(self.head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
)
query = torch.nn.functional.pad(
query, (0, padded_head_size - self.head_size), value=0.0
)
key = torch.nn.functional.pad(
key, (0, padded_head_size - self.head_size), value=0.0
)
value = torch.nn.functional.pad(
value, (0, padded_head_size - self.head_size), value=0.0
)
if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
# Write input keys and values to the KV cache.
# Skip this if sharing KV cache with an earlier attention layer.
slot_mapping = attn_metadata.slot_mapping
write_to_kv_cache(
key,
value,
kv_cache,
slot_mapping,
attn_metadata.num_slices_per_kv_cache_update_block,
attn_metadata.num_kv_update_slices,
self.kv_cache_quantized_dtype,
layer._k_scale_float,
layer._v_scale_float,
)
if self.kv_cache_quantized_dtype is not None and (
layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0
):
raise ValueError("k_scale_float and v_scale_float must be non-zero")
output = torch.ops.xla.ragged_paged_attention(
query,
kv_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
attn_metadata.query_start_loc,
attn_metadata.num_seqs,
# By default, the system utilizes optimized block size and
# vmem_limit_bytes parameters from the kernel repository. However,
# these can be manually adjusted for debugging if necessary.
num_kv_pages_per_block=None,
num_queries_per_block=None,
vmem_limit_bytes=None,
use_kernel=True,
sm_scale=self.scale,
sliding_window=self.sliding_window,
soft_cap=self.logits_soft_cap,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
)
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
output = output[:, :, : self.head_size]
return output.reshape(num_tokens, hidden_size)
def write_to_kv_cache(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
num_slices_per_kv_cache_update_block: int,
num_kv_update_slices: torch.Tensor,
kv_cache_quantized_dtype: torch.dtype | None = None,
k_scale: float = 1.0,
v_scale: float = 1.0,
) -> None:
"""Write the key and values to the KV cache.
Args:
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size]
num_slices_per_kv_cache_update_block: int
"""
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape
head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
if kv_cache_quantized_dtype is not None:
dtype_info = torch.finfo(kv_cache_quantized_dtype)
key = key.to(torch.float32) / k_scale
# NOTE: clamp is added here to avoid out of range of quantized dtype
key = torch.clamp(key, dtype_info.min, dtype_info.max)
key = key.to(kv_cache_quantized_dtype)
value = value.to(torch.float32) / v_scale
value = torch.clamp(value, dtype_info.min, dtype_info.max)
value = value.to(kv_cache_quantized_dtype)
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size)
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
kv_cache = kv_cache.flatten(0, 1)
new_kv_cache = torch.ops.xla.kv_cache_update_op(
kv,
slot_mapping,
kv_cache,
num_kv_update_slices,
page_size,
num_slices_per_kv_cache_update_block,
)
# NOTE: the in-place copy will be optimized away by XLA compiler.
kv_cache.copy_(new_kv_cache)
# We can move this function to a common utils file if it's also useful for other
# hardware.
def dtype_bits(dtype: torch.dtype):
if dtype.is_floating_point:
try:
return torch.finfo(dtype).bits
except TypeError:
pass
elif dtype.is_complex:
if dtype is torch.complex32:
return 32
elif dtype is torch.complex64:
return 64
elif dtype is torch.complex128:
return 128
else:
try:
return torch.iinfo(dtype).bits
# torch.iinfo cannot support int4, int2, bits8...
except TypeError:
pass
str_dtype = str(dtype)
# support torch.int4, torch.int5, torch.uint5...
if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"):
return int(str_dtype[-1])
raise TypeError(f"Getting the bit width of {dtype} is not supported")
def get_dtype_packing(dtype):
bits = dtype_bits(dtype)
if 32 % bits != 0:
raise ValueError(
f"The bit width must be divisible by 32, but got bits={bits}, "
"dtype={dtype}"
)
return 32 // bits
def get_page_size_bytes(
block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype
) -> int:
"""Returns the size in bytes of one page of the KV cache."""
padded_head_size = (
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
)
num_combined_kv_heads = num_kv_heads * 2
# NOTE: for the implicit padding in XLA
packing = get_dtype_packing(kv_cache_dtype)
num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing
kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
return (
block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bits // 8
)

File diff suppressed because it is too large Load Diff

View File

@ -2,350 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A TPU worker class."""
import os
from collections.abc import Callable
from typing import Any, TypeVar
from typing import TypeVar
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (
ensure_model_parallel_initialized,
init_distributed_environment,
)
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized,
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_INFERENCE
from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.utils import bind_kv_cache
logger = init_logger(__name__)
_R = TypeVar("_R")
if not USE_TPU_INFERENCE:
logger.info("tpu_inference not found, using vLLM's TPUWorker.")
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
class TPUWorker:
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
self.is_driver_worker = is_driver_worker
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.use_spmd = envs.VLLM_XLA_USE_SPMD
self.original_parallel_config = None
if self.use_spmd:
# Under SPMD mode, distributed env is initialized as if there is
# only one worker/device.
self.original_parallel_config = self.parallel_config
self.parallel_config.tensor_parallel_size = 1
self.parallel_config.pipeline_parallel_size = 1
self.parallel_config.world_size = 1
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
if self.cache_config.cache_dtype == "auto":
self.cache_dtype = self.model_config.dtype
else:
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype]
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules()
# Delay profiler initialization to the start of the profiling.
# This is because in vLLM V1, MP runtime is initialized before the
# TPU Worker is initialized. The profiler server needs to start after
# MP runtime is initialized.
self.profiler = None
self.profile_dir = None
if vllm_config.profiler_config.profiler == "torch" and self.rank < 1:
# For TPU, we can only have 1 active profiler session for 1 profiler
# server. So we only profile on rank0.
self.profile_dir = vllm_config.profiler_config.torch_profiler_dir
logger.info(
"Profiling enabled. Traces will be saved to: %s", self.profile_dir
)
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
# ring, the xla tpu compiler flag
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
# fix this. It will be removed after the bug in XLA compiler is fixed.
os.environ["LIBTPU_INIT_ARGS"] = (
os.environ.get("LIBTPU_INIT_ARGS", "")
+ " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
" --xla_jf_conv_input_fusion=False"
)
# --xla_jf_conv_input_fusion=False is used to improve the perf of
# quantized matmul.
torch.set_grad_enabled(False)
torch.set_default_dtype(self.model_config.dtype)
# Initialize the distributed environment.
self._init_tpu_worker_distributed_environment(
self.vllm_config, self.rank, self.distributed_init_method, self.local_rank
)
# Device initialization should happen after initializing
# the distributed runtime.
self.device = xm.xla_device()
self.device_config.device = self.device
# Set random seed.
set_random_seed(self.model_config.seed)
xm.set_rng_state(self.model_config.seed, self.device)
# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
# TODO (NickLucche) On gsm we compile 80+ graphs.
# Re-evaluate limit, with MM we may get close to this limit.
torch._dynamo.config.cache_size_limit = 128
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs.
world_size = self.parallel_config.world_size
rank = xr.global_ordinal()
# The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
# Consequently, changes in optimization flags, which affect compilation
# results, don't change the cache key. This can result in the wrong
# compilation being used. To prevent this, disabling the XLA compilation
# cache during development is recommended.We can disable it by
# `export VLLM_XLA_CACHE_PATH=`
if envs.VLLM_XLA_CACHE_PATH:
per_rank_path = os.path.join(
envs.VLLM_XLA_CACHE_PATH, f"tp{world_size}_rank{rank}"
)
xr.initialize_cache(per_rank_path, readonly=False)
# Init ModelRunner here, so that we have access to self.device.
self.model_runner = TPUModelRunner(
self.vllm_config, self.device, self.original_parallel_config
)
if rank == 0:
# If usage stat is enabled, collect relevant info.
report_usage_stats(self.vllm_config)
def determine_available_memory(self) -> int:
kv_caches: dict[str, torch.Tensor] = {}
kv_cache_spec = self.model_runner.get_kv_cache_spec()
for layer_name, layer_spec in kv_cache_spec.items():
if isinstance(layer_spec, AttentionSpec):
dtype = layer_spec.dtype
# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather by specializing on the value `None`.
tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
kv_caches[layer_name] = tpu_kv_cache
else:
raise NotImplementedError(
f"Unsupported KV cache spec '{type(layer_spec)}'"
)
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches,
)
# `max_num_tokens >= max_num_batched_tokens` due to padding.
with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
self.model_runner.profile_run(self.model_runner.max_num_tokens)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()
# During the profiling run, the model runs without KV cache. After
# the profiling run, the model always runs with KV cache. Here we clear
# the dynamo cache and cached bytecode to ensure the model always has
# one compiled bytecode. Having one FX graph/cached bytecode per
# compiled model is required for `support_torch_compile` decorator to
# skip dynamo guard.
with set_current_vllm_config(self.vllm_config):
self.model_runner.reset_dynamo_cache()
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
if self.use_spmd:
# This is a workaround for the TPU SPMD mode. The get_memory_info
# API doesn't work with SPMD mode in PyTorch/XLA.
# TODO: use xm.get_memory_info for SPMD once it's supported in
# PyTorch/XLA.
import tpu_info
chip_type, _ = tpu_info.device.get_local_chips()
device_usage = tpu_info.metrics.get_chip_usage(chip_type)
total_memory_size = device_usage[0].total_memory
current_mem = device_usage[0].memory_usage
else:
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
current_mem = m["bytes_used"]
# Ideally we would use profiled = m["peak_bytes_used"] to
# get weights + activations. But there is memory used during
# compilation / weight loading that impacts the peak and
# there is no way to reset peak memory in XLA, So we
# use the heuristic of 2% of weights.
profiled = current_mem * 1.02
# Calculate the TPU KV cache size based on profiling.
usable_memory_size = int(
total_memory_size * self.cache_config.gpu_memory_utilization
)
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
head_size = self.model_config.get_head_size()
if head_size > 0:
padded_head_size = (
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
)
if padded_head_size != head_size:
logger.warning_once("head size is padded to %d", padded_head_size)
# We adjust the usable memory size for the KV cache to prevent OOM
# errors, even after padding the head_size.
tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
return int(tpu_kv_cache_bytes)
def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput:
return self.model_runner.sample_tokens(grammar_output)
def execute_model(
self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | None:
return self.model_runner.execute_model(scheduler_output)
def profile(self, is_start: bool = True):
if self.rank < 1:
if self.profile_dir is None:
raise RuntimeError("Profiler is not enabled.")
if is_start:
if self.profiler is None:
self.profiler = xp.start_server(9012)
xp.start_trace(self.profile_dir)
else:
xp.stop_trace()
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def load_model(self) -> None:
self.model_runner.load_model()
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
def reload_weights(self) -> None:
self.model_runner.reload_weights()
def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_runner.get_supported_tasks()
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
self.model_runner.initialize_kv_cache(kv_cache_config)
def check_health(self) -> None:
# worker will always be healthy as long as it's running.
return
def _init_tpu_worker_distributed_environment(
self,
vllm_config: VllmConfig,
rank: int,
distributed_init_method: str | None = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
if self.use_spmd:
xr.use_spmd()
# NOTE(woosuk): This is just to initialize the TP group and broadcast
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
parallel_config = vllm_config.parallel_config
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
local_rank=local_rank,
distributed_init_method=distributed_init_method or "env://",
backend=current_platform.dist_backend,
)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
)
ensure_kv_transfer_initialized(vllm_config)
def shutdown(self) -> None:
self.model_runner.ensure_kv_transfer_shutdown()
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
"""Apply a function on the model inside this worker."""
return fn(self.get_model())
# TODO(weiyulin) Remove this file after adding an official way to use hardware plugin
if USE_TPU_INFERENCE:
from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker