diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 6c02dcb76bec2..11c6e488f958f 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -92,7 +92,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| marlin | standard,batched | 3 / N/A | 3 / N/A | silu,swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
-| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
diff --git a/tests/tpu/__init__.py b/tests/tpu/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tpu/lora/__init__.py b/tests/tpu/lora/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tpu/lora/test_lora.py b/tests/tpu/lora/test_lora.py
deleted file mode 100644
index 9780092b25e66..0000000000000
--- a/tests/tpu/lora/test_lora.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import pytest
-from torch_xla._internal import tpu
-
-import vllm
-from vllm.lora.request import LoRARequest
-
-# This file contains tests to ensure that LoRA works correctly on the TPU
-# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct
-# for this. The adapters are:
-# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
-# from 1 to 4.
-
-# These adapters are trained using a standard huggingface peft training script,
-# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run
-# 100 training iterations with a training batch size of 100.
-
-
-def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
- return vllm.LLM(
- model="Qwen/Qwen2.5-3B-Instruct",
- max_model_len=256,
- max_num_seqs=8,
- tensor_parallel_size=tp,
- enable_lora=True,
- max_loras=num_loras,
- max_lora_rank=8,
- )
-
-
-TPU_TENSOR_PARALLEL_SIZES = (
- [1, tpu.num_available_chips()] if tpu.num_available_chips() > 1 else [1]
-)
-
-
-@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
-def test_single_lora(tp: int):
- """
- This test ensures we can run a single LoRA adapter on the TPU backend.
- We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
- will force Qwen2.5-3B-Instruct to claim 1+1=1.
- """
-
- llm = setup_vllm(1, tp)
-
- prompt = "What is 1+1? \n"
-
- lora_request = LoRARequest(
- "lora_adapter_1",
- 1,
- "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter",
- )
- output = (
- llm.generate(
- prompt,
- sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
- lora_request=lora_request,
- )[0]
- .outputs[0]
- .text
- )
-
- answer = output.strip()[0]
-
- assert answer.isdigit()
- assert int(answer) == 1
-
-
-@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
-def test_lora_hotswapping(tp: int):
- """
- This test ensures we can run multiple LoRA adapters on the TPU backend, even
- if we only have space to store 1.
-
- We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
- will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
- """
-
- lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
- lora_requests = [
- LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
- for i in range(1, 5)
- ]
-
- llm = setup_vllm(1, tp)
-
- prompt = "What is 1+1? \n"
-
- for i, req in enumerate(lora_requests):
- output = (
- llm.generate(
- prompt,
- sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
- lora_request=req,
- )[0]
- .outputs[0]
- .text
- )
- answer = output.strip()[0]
-
- assert answer.isdigit()
- assert int(answer) == i + 1
-
-
-@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
-def test_multi_lora(tp: int):
- """
- This test ensures we can run multiple LoRA adapters on the TPU backend, when
- we have enough space to store all of them.
-
- We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
- will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
- """
- lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
- lora_requests = [
- LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
- for i in range(1, 5)
- ]
-
- llm = setup_vllm(4, tp)
-
- prompt = "What is 1+1? \n"
-
- for i, req in enumerate(lora_requests):
- output = (
- llm.generate(
- prompt,
- sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
- lora_request=req,
- )[0]
- .outputs[0]
- .text
- )
-
- answer = output.strip()[0]
-
- assert answer.isdigit()
- assert int(output.strip()[0]) == i + 1
diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py
deleted file mode 100644
index 5acfa484f0c13..0000000000000
--- a/tests/tpu/test_compilation.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import glob
-import os
-import tempfile
-
-import depyf
-
-
-def test_tpu_compilation():
- temp_dir = tempfile.mkdtemp()
- with depyf.prepare_debug(temp_dir):
- from vllm import LLM, SamplingParams
-
- prompts = [
- "A robot may not injure a human being",
- "It is only with the heart that one can see rightly;",
- "The greatest glory in living lies not in never falling,",
- ]
- answers = [
- " or, through inaction",
- " what is essential ",
- " but in rising ",
- ]
-
- # Currently, top-p sampling is disabled. `top_p` should be 1.0.
- N = 1
- sampling_params = SamplingParams(temperature=0.7, top_p=1.0, n=N, max_tokens=16)
-
- llm = LLM(
- model="Qwen/Qwen2-1.5B-Instruct",
- max_num_batched_tokens=256,
- max_model_len=256,
- max_num_seqs=32,
- enforce_eager=False,
- )
-
- outputs = llm.generate(prompts, sampling_params)
- for output, answer in zip(outputs, answers):
- prompt = output.prompt
- generated_text = output.outputs[0].text
- print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
- assert generated_text.startswith(answer)
-
- compiled_codes = sorted(
- glob.glob(os.path.join(temp_dir, "__transformed_code*for_forward.py"))
- )
-
- for i, compiled_code in enumerate(compiled_codes):
- print("{} file: {}".format(i + 1, compiled_code))
-
- # We should only trigger Dynamo compilation 2 times:
- # 1. Forward pass without kv_caches
- # 2. Forward pass with kv_caches
- # Check we have 2 compiled codes
- assert len(compiled_codes) == 2
-
- kv_cache_prefix = "kv_cache"
- attn_prefix = "ragged_paged_attention"
-
- def extract_compiled_index(s):
- parts = s.replace(".", "_").split("_")
- numbers = [int(part) for part in parts if part.isdigit()]
- return numbers[0]
-
- # Check all the compilations are as expected. The dump files include the
- # captured graph for the forward function of the nn.Module.
- compiled_fns = sorted(
- glob.glob(os.path.join(temp_dir, "__compiled_fn*Forward_graph*.py")),
- key=lambda s: extract_compiled_index(s),
- )
-
- for i, compiled_fn in enumerate(compiled_fns):
- print("{} file: {}".format(i + 1, compiled_fn))
-
- # The first compilation should not have any kv_caches
- with open(compiled_fns[0]) as f:
- content = f.read()
- assert kv_cache_prefix not in content
-
- # The second compilation should have kv_caches and the
- # ragged_paged_attention
- with open(compiled_fns[1]) as f:
- content = f.read()
- assert kv_cache_prefix in content and attn_prefix in content
diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py
deleted file mode 100644
index cf455ff3edbd3..0000000000000
--- a/tests/tpu/test_custom_dispatcher.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import pytest
-
-from vllm.config import CompilationMode
-
-from ..utils import compare_two_settings
-
-# --enforce-eager on TPU causes graph compilation
-# this times out default Health Check in the MQLLMEngine,
-# so we set the timeout here to 30s
-
-
-def test_custom_dispatcher(monkeypatch: pytest.MonkeyPatch):
- with monkeypatch.context() as m:
- m.setenv("VLLM_RPC_TIMEOUT", "30000")
- compare_two_settings(
- "Qwen/Qwen2.5-1.5B-Instruct",
- arg1=[
- "--max-model-len=256",
- "--max-num-seqs=32",
- "--enforce-eager",
- f"-O{CompilationMode.DYNAMO_TRACE_ONCE}",
- ],
- arg2=[
- "--max-model-len=256",
- "--max-num-seqs=32",
- "--enforce-eager",
- f"-O{CompilationMode.STOCK_TORCH_COMPILE}",
- ],
- env1={},
- env2={},
- )
diff --git a/tests/tpu/test_moe_pallas.py b/tests/tpu/test_moe_pallas.py
deleted file mode 100644
index e3236d20bf673..0000000000000
--- a/tests/tpu/test_moe_pallas.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests for the Pallas MOE implementation.
-
-Run `pytest tests/kernels/moe/test_moe_pallas.py`.
-"""
-
-import pytest
-import torch
-import torch_xla
-
-from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe as pallas_moe
-from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
- fused_moe as torch_moe,
-)
-from vllm.platforms import current_platform
-
-if not current_platform.is_tpu():
- pytest.skip("This test needs a TPU.", allow_module_level=True)
-
-NUM_EXPERTS = [8, 64]
-EP_SIZE = [1]
-TOP_KS = [2, 6]
-
-
-# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
-@pytest.mark.parametrize("m", [8, 16, 64, 2048])
-@pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 511, 1024])
-@pytest.mark.parametrize("e", NUM_EXPERTS)
-@pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("ep_size", EP_SIZE)
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
-def test_pallas_moe(
- m: int,
- n: int,
- k: int,
- e: int,
- topk: int,
- ep_size: int,
- dtype: torch.dtype,
-):
- import torch_xla.core.xla_model as xm
-
- with torch.device(xm.xla_device()):
- a = torch.randn((m, k), dtype=dtype) / 10
- w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
- w2 = torch.randn((e, k, n), dtype=dtype) / 10
-
- score = torch.randn((m, e), dtype=dtype)
-
- # TODO: Support ep
- if ep_size > 1:
- pytest.skip("No support for ep_size > 1 yet")
- else:
- e_map = None
-
- # Run both implementations
- torch_output = torch_moe(
- hidden_states=a,
- w1=w1,
- w2=w2,
- gating_output=score,
- topk=topk,
- global_num_experts=e,
- expert_map=e_map,
- renormalize=False,
- )
-
- pallas_output = pallas_moe(
- hidden_states=a,
- w1=w1,
- w2=w2,
- gating_output=score,
- topk=topk,
- global_num_experts=e,
- expert_map=e_map,
- renormalize=False,
- )
- torch_xla.sync(wait=False)
-
- # Compare outputs
- torch.testing.assert_close(
- pallas_output.cpu(),
- torch_output.cpu(),
- atol=2e-2,
- rtol=0,
- )
diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py
deleted file mode 100644
index 151be5f17fe89..0000000000000
--- a/tests/tpu/test_quantization_accuracy.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from dataclasses import dataclass
-
-import lm_eval
-import pytest
-
-TASK = "gsm8k"
-FILTER = "exact_match,strict-match"
-RTOL = 0.03
-
-
-@dataclass
-class GSM8KAccuracyTestConfig:
- model_name: str
- expected_value: float
-
- def get_model_args(self) -> str:
- return f"pretrained={self.model_name},max_model_len=4096,max_num_seqs=32"
-
-
-# NOTE: Accuracy scores measured on GPUs.
-ACCURACY_CONFIGS = [
- GSM8KAccuracyTestConfig(
- model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
- expected_value=0.76,
- ), # no bias
- # NOTE(rob): We cannot re-initialize vLLM in the same process for TPU,
- # so only one of these tests can run in a single call to pytest. As
- # a follow-up, move this into the LM-EVAL section of the CI.
- # GSM8KAccuracyTestConfig(
- # model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
- # expected_value=0.66), # bias in QKV layers
-]
-
-
-@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
-def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
- results = lm_eval.simple_evaluate(
- model="vllm",
- model_args=config.get_model_args(),
- tasks="gsm8k",
- batch_size="auto",
- )
-
- EXPECTED_VALUE = config.expected_value
- measured_value = results["results"][TASK][FILTER]
- assert (
- measured_value - RTOL < EXPECTED_VALUE
- and measured_value + RTOL > EXPECTED_VALUE
- ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
diff --git a/tests/v1/tpu/__init__.py b/tests/v1/tpu/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py
deleted file mode 100644
index 0d53a02476fab..0000000000000
--- a/tests/v1/tpu/test_basic.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""A basic correctness check for TPUs
-
-Run `pytest tests/v1/tpu/test_basic.py`.
-"""
-
-from typing import TYPE_CHECKING
-
-import pytest
-from torch_xla._internal import tpu
-
-from vllm.platforms import current_platform
-
-if TYPE_CHECKING:
- from tests.conftest import VllmRunner
-else:
- VllmRunner = object
-
-MODELS = [
- "Qwen/Qwen2.5-1.5B-Instruct",
- # TODO: Enable this model when fixed.
- # "Qwen/Qwen1.5-MoE-A2.7B",
- # TODO: Enable this models with v6e
- # "Qwen/Qwen2-7B-Instruct",
- # "meta-llama/Llama-3.1-8B",
-]
-
-TENSOR_PARALLEL_SIZES = [1]
-MAX_NUM_REQS = [16, 1024]
-
-# TODO: Enable when CI/CD will have a multi-tpu instance
-# TENSOR_PARALLEL_SIZES = [1, 4]
-
-
-@pytest.mark.skipif(
- not current_platform.is_tpu(), reason="This is a basic test for TPU only"
-)
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("max_tokens", [5])
-@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
-@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
-def test_basic(
- vllm_runner: type[VllmRunner],
- model: str,
- max_tokens: int,
- tensor_parallel_size: int,
- max_num_seqs: int,
-) -> None:
- prompt = (
- "The next numbers of the sequence "
- + ", ".join(str(i) for i in range(1024))
- + " are:"
- )
- example_prompts = [prompt]
-
- with vllm_runner(
- model,
- # Note: max_num_batched_tokens == 1024 is needed here to
- # actually test chunked prompt
- max_num_batched_tokens=1024,
- max_model_len=8192,
- gpu_memory_utilization=0.7,
- max_num_seqs=max_num_seqs,
- tensor_parallel_size=tensor_parallel_size,
- ) as vllm_model:
- vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
- output = vllm_outputs[0][1]
-
- assert "1024" in output or "0, 1" in output
-
-
-@pytest.mark.skip(reason="Temporarily disabled due to timeout")
-@pytest.mark.skipif(
- not current_platform.is_tpu(), reason="This is a basic test for TPU only"
-)
-@pytest.mark.parametrize("max_tokens", [8])
-@pytest.mark.parametrize("max_num_seqs", [16])
-def test_phi3(
- vllm_runner: type[VllmRunner],
- max_tokens: int,
- max_num_seqs: int,
-) -> None:
- prompts = [
- "A robot may not injure a human being",
- "It is only with the heart that one can see rightly;",
- "The greatest glory in living lies not in never falling,",
- ]
- answers = [
- " or, by violating privacy",
- " what is essential is love.",
- " but in rising every time we fall.",
- ]
- # test head dim = 96
- model = "microsoft/Phi-3-mini-128k-instruct"
-
- with vllm_runner(
- model, max_num_batched_tokens=256, max_num_seqs=max_num_seqs
- ) as vllm_model:
- vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
- # vllm_outputs is a list of tuples whose first element is the token id
- # and the second element is the output (including the prompt).
- for output, answer in zip(vllm_outputs, answers):
- generated_text = output[1]
- assert answer in generated_text
-
-
-TP_SIZE_8 = 8
-
-
-@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only")
-@pytest.mark.skipif(
- tpu.num_available_chips() < TP_SIZE_8,
- reason=f"This test requires {TP_SIZE_8} TPU chips.",
-)
-def test_gemma3_27b_with_text_input_and_tp(
- vllm_runner: type[VllmRunner],
-) -> None:
- model = "google/gemma-3-27b-it"
- max_tokens = 16
- tensor_parallel_size = TP_SIZE_8
- max_num_seqs = 4
- prompts = [
- "A robot may not injure a human being",
- "It is only with the heart that one can see rightly;",
- "The greatest glory in living lies not in never falling,",
- ]
- answers = [
- " or, through inaction, allow a human being to come to harm.",
- " what is essential is invisible to the eye.",
- " but in rising every time we fall.",
- ]
-
- with vllm_runner(
- model,
- max_num_batched_tokens=256,
- max_num_seqs=max_num_seqs,
- tensor_parallel_size=tensor_parallel_size,
- ) as vllm_model:
- vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
- # vllm_outputs is a list of tuples whose first element is the token id
- # and the second element is the output (including the prompt).
- for output, answer in zip(vllm_outputs, answers):
- generated_text = output[1]
- assert answer in generated_text
-
-
-@pytest.mark.skipif(
- not current_platform.is_tpu(), reason="This is a basic test for TPU only"
-)
-def test_w8a8_quantization(
- vllm_runner: type[VllmRunner],
-) -> None:
- model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
- max_tokens = 5
- tensor_parallel_size = 1
- max_num_seqs = 4
-
- prompt = (
- "The next numbers of the sequence "
- + ", ".join(str(i) for i in range(1024))
- + " are:"
- )
- example_prompts = [prompt]
-
- with vllm_runner(
- model,
- max_num_batched_tokens=64,
- max_model_len=4096,
- gpu_memory_utilization=0.7,
- max_num_seqs=max_num_seqs,
- tensor_parallel_size=tensor_parallel_size,
- ) as vllm_model:
- vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
- output = vllm_outputs[0][1]
-
- assert "1024" in output or "0, 1" in output
diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py
deleted file mode 100644
index 99d5f98351ad2..0000000000000
--- a/tests/v1/tpu/test_kv_cache_update_kernel.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import numpy as np
-import pytest
-import torch
-import torch_xla
-
-import vllm.v1.attention.backends.pallas # noqa: F401
-from vllm.platforms import current_platform
-
-
-@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only")
-@pytest.mark.parametrize("page_size", [32, 33])
-@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
-@pytest.mark.parametrize("head_dim", [128, 256])
-@pytest.mark.parametrize("num_slices_per_block", [4, 8])
-def test_kv_cache_update_kernel(
- page_size: int, combined_kv_head_num: int, head_dim: int, num_slices_per_block: int
-):
- page_num = 1000
- padded_num_tokens = 128
- kv_cache_cpu = torch.zeros(
- (page_num * page_size, combined_kv_head_num, head_dim),
- dtype=torch.bfloat16,
- device="cpu",
- )
- kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
- new_kv_cpu = torch.randn(
- (padded_num_tokens, combined_kv_head_num, head_dim),
- dtype=torch.bfloat16,
- device="cpu",
- )
- new_kv_xla = new_kv_cpu.to(torch_xla.device())
- slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], dtype=np.int32)
- num_kv_update_slices = len(slice_lens)
- kv_cache_start_indices = np.array(
- [
- page_size * 2 - 7,
- page_size * 2,
- page_size * 3,
- page_size * 4 + 6,
- page_size * 5 + 7,
- page_size * 6 + 8,
- page_size * 15 + 3,
- ],
- dtype=np.int32,
- )
- new_kv_cache_indices = np.concatenate(
- [np.array([0], dtype=np.int32), np.cumsum(slice_lens[:-1])]
- )
- slot_mapping = np.stack(
- [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1
- )
- slot_mapping = np.transpose(slot_mapping)
- slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu", dtype=torch.int32)
- slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
- num_kv_update_slices_xla = torch.tensor(
- [num_kv_update_slices], device=torch_xla.device(), dtype=torch.int32
- )
- torch_xla.sync()
-
- torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
- new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
- new_kv_xla,
- slot_mapping_xla,
- kv_cache_xla,
- num_kv_update_slices_xla,
- page_size,
- num_slices_per_block,
- )
- kv_cache_xla.copy_(new_kv_cache_xla)
- torch_xla.sync()
-
- for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, slice_lens):
- kv_cache_cpu[ci : ci + sl, :, :] = new_kv_cpu[ni : ni + sl, :, :]
-
- assert torch.allclose(kv_cache_xla.cpu(), kv_cache_cpu, atol=1e-4, rtol=1e-4)
diff --git a/tests/v1/tpu/test_mha_attn.py b/tests/v1/tpu/test_mha_attn.py
deleted file mode 100644
index 84968dee6b60c..0000000000000
--- a/tests/v1/tpu/test_mha_attn.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-Test:
-
-* Tests for MMEncoderAttention layer
-"""
-
-import pytest
-import torch
-import torch_xla
-import torch_xla.core
-import torch_xla.core.xla_model
-
-from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
-from vllm.attention.selector import _cached_get_attn_backend
-from vllm.platforms import current_platform
-
-
-@pytest.fixture(autouse=True)
-def clear_cache():
- """Clear lru cache to ensure each test case runs without caching."""
- _cached_get_attn_backend.cache_clear()
-
-
-def ref_attention(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- scale: float,
-) -> torch.Tensor:
- """
- Native implementation of scaled dot product attention without mask:
- - query, key, value: [batch_size, seq_len, num_heads, head_size]
- - attn_mask: [batch_size, seq_len, seq_len]
- """
- query, key, value = (x.transpose(1, 2) for x in (query, key, value))
- attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
- attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
- out = torch.matmul(attn_weights, value).transpose(1, 2)
- return out
-
-
-BATCH_SIZES = [1, 16]
-SEQ_LENS = [1]
-NUM_HEADS = [1, 16]
-NUM_KV_HEADS = [1]
-HEAD_SIZES = [64, 80]
-
-
-@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
-@pytest.mark.parametrize("batch_size", BATCH_SIZES)
-@pytest.mark.parametrize("seq_len", SEQ_LENS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("device", [torch_xla.core.xla_model.xla_device()])
-def test_mha_attn_forward(
- batch_size: int,
- seq_len: int,
- num_heads: int,
- num_kv_heads: int,
- head_size: int,
- device: str,
-):
- current_platform.seed_everything(0)
- # These are expected to be f32
- q = torch.randn(batch_size, seq_len, num_heads * head_size, device=device)
- k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
- v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
- scale = 1.0 / head_size**0.5
- attn = MMEncoderAttention(
- num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
- )
- output = attn(q, k, v)
-
- assert num_heads % num_kv_heads == 0
- num_queries_per_kv = num_heads // num_kv_heads
-
- q = q.reshape(batch_size, seq_len, num_heads, head_size)
- k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
- v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
- if num_queries_per_kv > 1:
- k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
- v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)
-
- ref_output = ref_attention(
- q,
- k,
- v,
- scale=scale,
- ).reshape(batch_size, seq_len, num_heads * head_size)
- # torch_xla flash_attn kernel is less accurate but much faster
- torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-3)
diff --git a/tests/v1/tpu/test_multimodal.py b/tests/v1/tpu/test_multimodal.py
deleted file mode 100644
index 3caa7c14b393b..0000000000000
--- a/tests/v1/tpu/test_multimodal.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import openai
-import pytest
-
-from vllm.multimodal.utils import encode_image_url
-from vllm.platforms import current_platform
-
-from ...entrypoints.openai.test_vision import TEST_IMAGE_ASSETS
-from ...utils import RemoteOpenAIServer
-
-
-@pytest.fixture(scope="session")
-def url_encoded_image(local_asset_server) -> dict[str, str]:
- return {
- image_asset: encode_image_url(local_asset_server.get_image_asset(image_asset))
- for image_asset in TEST_IMAGE_ASSETS
- }
-
-
-@pytest.mark.asyncio
-@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
-@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"])
-async def test_basic_vision(model_name: str, url_encoded_image: dict[str, str]):
- pytest.skip("Skip this test until it's fixed.")
-
- def whats_in_this_image_msg(url):
- return [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "What's in this image?"},
- {"type": "image_url", "image_url": {"url": url}},
- ],
- }
- ]
-
- server_args = [
- "--max-model-len",
- "1024",
- "--max-num-seqs",
- "16",
- "--gpu-memory-utilization",
- "0.95",
- "--trust-remote-code",
- "--max-num-batched-tokens",
- "576",
- # NOTE: max-num-batched-tokens>=mm_item_size
- "--disable_chunked_mm_input",
- ]
-
- # Server will pre-compile on first startup (takes a long time).
- with RemoteOpenAIServer(
- model_name, server_args, max_wait_seconds=600
- ) as remote_server:
- client: openai.AsyncOpenAI = remote_server.get_async_client()
-
- # Other requests now should be much faster
- for image_url in TEST_IMAGE_ASSETS:
- image_url = url_encoded_image[image_url]
- chat_completion_from_url = await client.chat.completions.create(
- model=model_name,
- messages=whats_in_this_image_msg(image_url),
- max_completion_tokens=24,
- temperature=0.0,
- )
- result = chat_completion_from_url
- assert result
- choice = result.choices[0]
- assert choice.finish_reason == "length"
-
- message = choice.message
- message = result.choices[0].message
- assert message.content is not None and len(message.content) >= 10
- assert message.role == "assistant"
diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py
deleted file mode 100644
index 0a994e99bade1..0000000000000
--- a/tests/v1/tpu/test_pallas.py
+++ /dev/null
@@ -1,100 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from unittest.mock import ANY, patch
-
-import torch
-
-from vllm.attention.backends.abstract import AttentionType
-from vllm.v1.attention.backends.pallas import PallasAttentionBackendImpl, PallasMetadata
-
-
-def test_ragged_paged_attention():
- # We verify that the kernel inputs such as sliding_window, etc. are passed
- # in from the model correctly.
- # The correctness of the paged attention kernel is tested in the kernel
- # library.
- num_heads = 4
- head_size = 128
- scale = 1.0
- num_kv_heads = 4
- sliding_window = 128
- logits_soft_cap = 50.0
- attn_impl = PallasAttentionBackendImpl(
- num_heads=num_heads,
- head_size=head_size,
- scale=scale,
- num_kv_heads=num_kv_heads,
- alibi_slopes=None,
- sliding_window=sliding_window,
- kv_cache_dtype="auto",
- logits_soft_cap=logits_soft_cap,
- attn_type=AttentionType.DECODER,
- )
-
- class FakeAttentionLayer:
- _q_scale_float: float
- _k_scale_float: float
- _v_scale_float: float
-
- layer = FakeAttentionLayer()
- layer._q_scale_float = 1.0
- layer._k_scale_float = 1.0
- layer._v_scale_float = 1.0
-
- num_tokens = 16
- num_blocks = 1024
- block_size = 16
- query = torch.zeros(num_tokens, num_heads * head_size)
- key = torch.zeros(num_tokens, num_kv_heads * head_size)
- value = torch.zeros(num_tokens, num_kv_heads * head_size)
- kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
- slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64)
- max_num_reqs = 8
- max_num_blocks_per_req = 8
- num_kv_update_slices = torch.tensor([num_tokens], dtype=torch.int32)
- block_tables = torch.zeros(
- (max_num_reqs, max_num_blocks_per_req), dtype=torch.int32
- )
- context_lens = torch.ones((max_num_reqs,), dtype=torch.int32)
- query_lens = [1] * max_num_reqs
- query_start_loc = torch.cumsum(
- torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32
- )
- num_seqs = torch.tensor([max_num_reqs], dtype=torch.int32)
- attn_metadata = PallasMetadata(
- slot_mapping=slot_mapping,
- block_tables=block_tables,
- context_lens=context_lens,
- query_start_loc=query_start_loc,
- num_seqs=num_seqs,
- num_kv_update_slices=num_kv_update_slices,
- num_slices_per_kv_cache_update_block=8,
- )
-
- with patch("torch.ops.xla.ragged_paged_attention") as mock_ragged_paged_attention:
- attn_impl.forward(
- layer=layer,
- query=query,
- key=key,
- value=value,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
- )
-
- mock_ragged_paged_attention.assert_called_once_with(
- ANY, # query
- ANY, # kv_cache
- ANY, # context_lens
- ANY, # block_tables
- ANY, # query_start_loc
- ANY, # num_seqs
- num_kv_pages_per_block=None,
- num_queries_per_block=None,
- vmem_limit_bytes=None,
- use_kernel=True,
- sm_scale=scale,
- sliding_window=sliding_window,
- soft_cap=logits_soft_cap,
- k_scale=1.0,
- v_scale=1.0,
- )
diff --git a/tests/v1/tpu/test_perf.py b/tests/v1/tpu/test_perf.py
deleted file mode 100644
index e62b969fe3b95..0000000000000
--- a/tests/v1/tpu/test_perf.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""A basic performance regression test for TPUs
-
-Run `pytest tests/v1/tpu/test_perf.py`.
-"""
-
-import time
-from dataclasses import dataclass
-from typing import TYPE_CHECKING
-
-import numpy as np
-import pytest
-
-from vllm.platforms import current_platform
-from vllm.sampling_params import SamplingParams
-from vllm.tokenizers import get_tokenizer
-
-if TYPE_CHECKING:
- from tests.conftest import VllmRunner
-else:
- VllmRunner = object
-
-
-@dataclass
-class TestParams:
- model: str
- num_prompts: int
- prefix_len: int
- decode_len: int
- expected_avg_time: float
- err_tol: float
-
-
-TEST_PARAMS = [
- # TODO: Cannot run a series of tests because:
- # RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed:
- # open(/dev/vfio/0): Device or resource busy: Device or resource busy;
- # Couldn't open iommu group /dev/vfio/0
- # => Investigate
- # TestParams(
- # model="Qwen/Qwen2.5-1.5B-Instruct",
- # num_prompts=1,
- # prefix_len=10,
- # decode_len=5,
- # expected_avg_time=0.03,
- # err_tol=0.01,
- # ),
- # TestParams(
- # model="Qwen/Qwen2.5-1.5B-Instruct",
- # num_prompts=10,
- # prefix_len=100,
- # decode_len=50,
- # expected_avg_time=0.234,
- # err_tol=0.020,
- # ),
- TestParams(
- model="Qwen/Qwen2.5-1.5B-Instruct",
- num_prompts=64,
- prefix_len=500,
- decode_len=50,
- # commit id: ccb246776d93ef105904a8ec015b3587240a1183
- # tpu: v5lite (old vllm CI/CD)
- # expected_avg_time=1.4,
- # err_tol=0.30,
- # (This is the active CI/CD instance)
- # commit id: ccb246776d93ef105904a8ec015b3587240a1183
- # tpu: v6e (current vllm CI/CD)
- expected_avg_time=1.7, # measured with VLLM_XLA_CACHE_PATH=
- err_tol=0.20,
- ),
-]
-
-NUM_WARMUPS = 5
-NUM_RUNS = 10
-
-MAX_MODEL_LEN = 1024
-MAX_NUM_SEQS = 32
-GPU_UTIL = 0.9
-
-
-@pytest.mark.skipif(
- not current_platform.is_tpu(),
- reason="This is a basic performance test for TPU only",
-)
-@pytest.mark.parametrize("params", TEST_PARAMS)
-def test_perf(
- vllm_runner: type[VllmRunner],
- params: TestParams,
-) -> None:
- tokenizer = get_tokenizer(
- params.model, tokenizer_mode="auto", trust_remote_code=True
- )
-
- prompts = []
- for i in range(params.num_prompts):
- prefix_token_ids = np.random.randint(
- 0, tokenizer.vocab_size, size=params.prefix_len
- ).tolist()
- prompt = tokenizer.decode(prefix_token_ids)
- prompts.append(prompt)
-
- print(
- "-- Running: num_prompts = {} prefix_len = {} decode_len = {}".format(
- len(prompts), params.prefix_len, params.decode_len
- )
- )
-
- sampling_params = SamplingParams(
- max_tokens=params.decode_len, temperature=1.0, min_p=0.0
- )
-
- with vllm_runner(
- params.model,
- max_num_batched_tokens=MAX_MODEL_LEN,
- max_model_len=MAX_MODEL_LEN,
- max_num_seqs=MAX_NUM_SEQS,
- gpu_memory_utilization=GPU_UTIL,
- enforce_eager=False,
- tensor_parallel_size=1,
- ) as vllm_model:
- print(" -- Warmup / Compile")
- for i in range(NUM_WARMUPS):
- _ = vllm_model.generate(prompts, sampling_params)
-
- print(" -- Benchmarking... ")
- times = []
- for i in range(NUM_RUNS):
- start_time = time.time()
- _ = vllm_model.generate(prompts, sampling_params)
- times.append(time.time() - start_time)
-
- avg_time = sum(times) / len(times)
-
- print(" -- avg_time = {}".format(avg_time))
- print(
- " -- expected_avg_time = {} with err_tol = {}".format(
- params.expected_avg_time, params.err_tol
- )
- )
- diff = avg_time - params.expected_avg_time
- ok = diff < params.err_tol
- if diff < -params.err_tol:
- print(
- " !! WARNING !! Performance has improved by {}, "
- "it may be necessary to fine-tune the "
- "expected_avg_time = {}".format(-diff, params.expected_avg_time)
- )
-
- assert ok, " !! ERROR !! Regression detected"
diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py
deleted file mode 100644
index 58f6292b05a72..0000000000000
--- a/tests/v1/tpu/test_sampler.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import random
-
-import pytest
-
-from vllm import LLM
-from vllm.platforms import current_platform
-from vllm.sampling_params import SamplingParams
-
-
-@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
-@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
-def test_sampler_different(model_name: str):
- """
- Test significantly different sampling params to assert the model produces
- different results.
- """
- llm = LLM(
- model_name,
- enforce_eager=False,
- max_num_seqs=1,
- max_model_len=512,
- max_num_batched_tokens=256,
- )
- prompts = ["Write a short story about a robot that dreams for the first time."]
- sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
- output = llm.generate(prompts, sampling_params)
-
- sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
- output2 = llm.generate(prompts, sampling_params)
- assert output[0].outputs[0].text != output2[0].outputs[0].text
-
- with pytest.raises(ValueError):
- # Unsupported `seed` param.
- sampling_params = SamplingParams(temperature=0.3, seed=42)
- output2 = llm.generate(prompts, sampling_params)
-
- # Batch-case with TopK/P
- for B in [4, 16]:
- p = prompts * B
- sampling_params = [
- SamplingParams(
- temperature=0.1,
- min_p=0.8,
- max_tokens=64,
- # Vary number of ks
- top_k=random.randint(4, 12),
- top_p=random.random(),
- )
- for _ in range(B)
- ]
- # Make sure first two reqs have the same K/P
- sampling_params[0] = sampling_params[1]
- output = llm.generate(p, sampling_params)
- # There are natural numerical instabilities that make it difficult
- # to have deterministic results over many tokens, tests the first ~20
- # tokens match.
- assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
-
-
-@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
-# TODO TPU will appear busy if we fan-out test params here
-@pytest.mark.parametrize("n_prompts", [1])
-@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
-def test_logprobs(model_name: str, n_prompts: int):
- """
- Request top logprobs with different sampling settings and check
- that results contains the requested number, ordered ascendingly.
- """
-
- def check_num_logprobs(logprobs, expected_num: int):
- for step in logprobs:
- prev_logp = 1.0
- # order by rank
- sorted_step = dict(sorted(step.items(), key=lambda item: item[1].rank))
-
- # Can contain the sampled token
- assert len(step) == expected_num or len(step) == expected_num + 1
- # Check results are ordered by prob value
- for rankno, (tid, logp) in enumerate(sorted_step.items()):
- assert logp.logprob <= prev_logp
- prev_logp = logp.logprob
- assert logp.rank == rankno + 1
-
- llm = LLM(
- model_name,
- enforce_eager=False,
- max_num_seqs=1,
- max_model_len=128,
- max_num_batched_tokens=128,
- )
- prompts = [
- "Write a short story about a robot that dreams for the first time."
- ] * n_prompts
- greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64, logprobs=4)
- regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64, logprobs=4)
- topkp_sampling_params = SamplingParams(
- temperature=0.4, max_tokens=64, logprobs=4, top_k=12, top_p=0.5
- )
-
- for sp in [greedy_sampling_params, regular_sampling_params, topkp_sampling_params]:
- output = llm.generate(prompts, sp)
- for o in output:
- check_num_logprobs(o.outputs[0].logprobs, 4)
diff --git a/tests/v1/tpu/test_spmd_model_weight_loading.py b/tests/v1/tpu/test_spmd_model_weight_loading.py
deleted file mode 100644
index be866bf90a792..0000000000000
--- a/tests/v1/tpu/test_spmd_model_weight_loading.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import gc
-import tempfile
-
-import numpy as np
-import pytest
-import torch_xla.distributed.spmd as xs
-import torch_xla.runtime as xr
-
-from vllm.config import set_current_vllm_config
-from vllm.distributed.parallel_state import (
- ensure_model_parallel_initialized,
- init_distributed_environment,
-)
-from vllm.engine.arg_utils import EngineArgs
-from vllm.model_executor.model_loader.tpu import TPUModelLoader
-
-
-def _setup_environment(model):
- engine_args = EngineArgs(
- model=model,
- )
- vllm_config = engine_args.create_engine_config()
- with set_current_vllm_config(vllm_config):
- temp_file = tempfile.mkstemp()[1]
- init_distributed_environment(
- 1,
- 0,
- local_rank=0,
- distributed_init_method=f"file://{temp_file}",
- backend="gloo",
- )
- # Under single worker mode, full model is init first and then
- # partitioned using GSPMD.
- ensure_model_parallel_initialized(1, 1)
- return vllm_config
-
-
-MESH = None
-
-
-def _get_spmd_mesh():
- global MESH
- if MESH is None:
- xr.use_spmd()
- num_devices = xr.global_runtime_device_count()
- mesh_shape = (num_devices, 1)
- device_ids = np.array(range(num_devices))
- MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
- return MESH
-
-
-@pytest.mark.parametrize(
- "model",
- [
- "Qwen/Qwen2-1.5B-Instruct",
- # Skip large models due to CI runner disk space limitations
- # "meta-llama/Llama-3.1-8B-Instruct",
- # "meta-llama/Llama-3.1-70B-Instruct",
- ],
-)
-def test_tpu_model_loader(model):
- # Skip the 70B test if there are less than 8 chips
- # TODO: Query using torch xla API, the query API is not working
- # with SPMD now. However, This test is running under SPMD mode.
- if "70B" in model and xr.global_runtime_device_count() < 8:
- pytest.skip(
- "Skipping 70B model if the TPU VM has less than 8 chips to \
- avoid OOM."
- )
-
- vllm_config = _setup_environment(model)
- loader = TPUModelLoader(load_config=vllm_config.load_config)
- mesh = _get_spmd_mesh()
- model = loader.load_model(vllm_config, vllm_config.model_config, mesh)
- del model
- gc.collect()
diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py
deleted file mode 100644
index c6634395bb167..0000000000000
--- a/tests/v1/tpu/test_topk_topp_sampler.py
+++ /dev/null
@@ -1,149 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import math
-
-import pytest
-import torch
-import torch_xla
-
-from vllm.platforms import current_platform
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
-from vllm.v1.sample.tpu.sampler import apply_top_k_top_p as apply_top_k_top_p_tpu
-
-if not current_platform.is_tpu():
- pytest.skip("This test needs a TPU.", allow_module_level=True)
-import torch_xla.core.xla_model as xm
-
-BATCH_SIZE = 1024
-VOCAB_SIZE = 128 * 1024
-TOLERANCE = 1e-6
-
-
-def test_topk_equivalence_to_native_impl():
- with torch.device(xm.xla_device()):
- xm.set_rng_state(seed=33)
-
- logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
-
- # Random top-k values between 1 and 10.
- k = torch.randint(1, 10, (BATCH_SIZE,))
-
- # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
- k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), VOCAB_SIZE)
-
- result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
-
- result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
- assert torch.allclose(result_native, result_tpu)
-
-
-def test_topp_result_sums_past_p():
- with torch.device(xm.xla_device()):
- xm.set_rng_state(seed=33)
-
- logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
- probs = logits.softmax(dim=-1)
-
- # Random top-p values between 0 and 1.
- p = torch.rand((BATCH_SIZE,))
-
- # Set p=1 for ~50% of requests in the batch (top-p disabled).
- p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), 1)
-
- no_op_k = torch.tensor([VOCAB_SIZE])
- logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(), k=no_op_k, p=p)
-
- # Verify that the masked logit's probability sums to at least p.
- probs.masked_fill_(logits_masked.isinf(), 0)
- masked_prob_sum = probs.sum(dim=-1)
-
- torch_xla.sync()
-
- # Perform assertion on CPU.
- assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
-
-
-def test_topp_basic():
- with torch.device(xm.xla_device()):
- logits = torch.tensor(
- [
- [math.log(0.2), math.log(0.3), math.log(0.5)],
- [math.log(0.5), math.log(0.1), math.log(0.4)],
- ]
- )
-
- result = apply_top_k_top_p_tpu(
- logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([0.79, 0.79])
- )
-
- torch_xla.sync()
-
- # Expect the smallest elements to be dropped.
- expected_result = logits.clone().cpu()
- expected_result[0, 0] = float("-inf")
- expected_result[1, 1] = float("-inf")
- assert torch.allclose(expected_result, result.cpu())
-
-
-def test_topp_select_all():
- with torch.device(xm.xla_device()):
- logits = torch.tensor(
- [
- [math.log(0.2), math.log(0.3), math.log(0.5)],
- [math.log(0.5), math.log(0.1), math.log(0.4)],
- ]
- )
-
- result = apply_top_k_top_p_tpu(
- logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([1.0, 1.0])
- )
-
- torch_xla.sync()
-
- assert torch.allclose(logits.cpu(), result.cpu())
-
-
-def test_topp_with_ties():
- with torch.device(xm.xla_device()):
- # Input has multiple math.log(0.3).
- logits = torch.tensor(
- [[math.log(0.3), math.log(0.3), math.log(0.3), math.log(0.1)]]
- )
-
- result = apply_top_k_top_p_tpu(
- logits=logits.clone(), k=torch.tensor([4]), p=torch.tensor([0.2])
- )
-
- torch_xla.sync()
-
- # All tie values are included in the top-p set. Tie breaking is left
- # to be done during final sampling (all tie tokens have equal
- # probability of being chosen).
- expected_result = logits.clone().cpu()
- expected_result[0, 3] = float("-inf")
- assert torch.allclose(expected_result, result.cpu())
-
-
-def test_both_topk_topp():
- with torch.device(xm.xla_device()):
- logits = torch.tensor(
- [
- [math.log(0.2), math.log(0.3), math.log(0.5)],
- [math.log(0.5), math.log(0.1), math.log(0.4)],
- ]
- )
-
- # Set k=1 for the first batch.
- result = apply_top_k_top_p_tpu(
- logits=logits.clone(), k=torch.tensor([1, 3]), p=torch.tensor([0.79, 0.79])
- )
-
- torch_xla.sync()
-
- # Since for the first batch k=1, expect only the largest element gets
- # selected.
- expected_result = logits.clone().cpu()
- expected_result[0, 0] = float("-inf")
- expected_result[0, 1] = float("-inf")
- expected_result[1, 1] = float("-inf")
- assert torch.allclose(expected_result, result.cpu())
diff --git a/tests/v1/tpu/test_tpu_int8.py b/tests/v1/tpu/test_tpu_int8.py
deleted file mode 100644
index 50001567a9588..0000000000000
--- a/tests/v1/tpu/test_tpu_int8.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests whether TPU Int8 computation is enabled correctly.
-
-Run `pytest tests/quantization/test_tpu_int8.py`.
-"""
-
-import pytest
-
-from vllm.model_executor.layers.linear import LinearBase
-from vllm.model_executor.layers.quantization.tpu_int8 import TPUInt8LinearMethod
-from vllm.platforms import current_platform
-
-from ...models.registry import HF_EXAMPLE_MODELS
-
-MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
-
-
-@pytest.mark.skipif(
- not current_platform.is_tpu(), reason="TPU Int8 is only enabled for TPUs."
-)
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [10])
-@pytest.mark.parametrize(
- "hf_overrides",
- [
- # w8a8 dynamic activation
- {
- "quantization_config": {
- "quant_method": "tpu_int8",
- "activation_scheme": "dynamic",
- }
- }
- ],
-)
-def test_model_tpu_int8(
- vllm_runner,
- model: str,
- dtype: str,
- max_tokens: int,
- hf_overrides: dict,
- monkeypatch,
-) -> None:
- model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
- model_info.check_transformers_version(on_fail="skip")
-
- activation_scheme = hf_overrides.get("quantization_config", {}).get(
- "activation_scheme"
- )
- quantize_activation = activation_scheme == "dynamic"
-
- # Allows using apply_model
- monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
- # Prevent error from re-initializing cache
- monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")
-
- prompts = [
- "A robot may not injure a human being",
- ]
- answers = [
- "or kill a human being",
- ]
-
- with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
-
- def check_model(model):
- for name, module in model.named_modules():
- if not isinstance(module, LinearBase):
- continue
- quant_method = module.quant_method
- assert isinstance(quant_method, TPUInt8LinearMethod)
- assert quant_method.quantize_activation == quantize_activation
-
- vllm.apply_model(check_model)
- outputs = vllm.generate_greedy(prompts, max_tokens)
- for (_, output), answer in zip(outputs, answers):
- assert answer in output
diff --git a/tests/v1/tpu/test_tpu_qkv_linear.py b/tests/v1/tpu/test_tpu_qkv_linear.py
deleted file mode 100644
index 098d925505424..0000000000000
--- a/tests/v1/tpu/test_tpu_qkv_linear.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import tempfile
-
-import numpy as np
-import pytest
-import torch
-import torch_xla.distributed.spmd as xs
-import torch_xla.runtime as xr
-
-from vllm.config import set_current_vllm_config
-from vllm.distributed.parallel_state import (
- ensure_model_parallel_initialized,
- init_distributed_environment,
-)
-from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear
-from vllm.engine.arg_utils import EngineArgs
-from vllm.model_executor.layers.linear import QKVParallelLinear
-
-
-@pytest.fixture(autouse=True)
-def setup_environment():
- # This is a fake config used for init dist env.
- # QKVParallelLinear needs dist env to be initialized.
- engine_args = EngineArgs(
- model="Qwen/Qwen2-1.5B-Instruct",
- max_model_len=64,
- max_num_batched_tokens=64,
- max_num_seqs=4,
- )
-
- vllm_config = engine_args.create_engine_config()
-
- with set_current_vllm_config(vllm_config):
- temp_file = tempfile.mkstemp()[1]
- init_distributed_environment(
- 1,
- 0,
- local_rank=0,
- distributed_init_method=f"file://{temp_file}",
- backend="gloo",
- )
- ensure_model_parallel_initialized(1, 1)
- yield
-
-
-MESH = None
-
-
-def _get_spmd_mesh():
- global MESH
- if MESH is None:
- xr.use_spmd()
- num_devices = xr.global_runtime_device_count()
- mesh_shape = (num_devices, 1)
- device_ids = np.array(range(num_devices))
- MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
- return MESH
-
-
-@pytest.mark.parametrize("bias", [False, True])
-# `xr.use_spmd()` will set a global state, and this state is not reversible.
-# Therefore, non-SPMD tests should be run before SPMD tests.
-@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()])
-@pytest.mark.parametrize("device", ["cpu", "xla"])
-@torch.no_grad()
-def test_xla_qkv_linear(bias, mesh, device):
- torch.manual_seed(123)
-
- qkv_linear = QKVParallelLinear(
- hidden_size=4096,
- head_size=128,
- total_num_heads=32,
- total_num_kv_heads=8,
- bias=bias,
- params_dtype=torch.bfloat16,
- return_bias=False,
- )
-
- qkv_linear.weight.data = torch.rand_like(qkv_linear.weight.data) / 10
- if bias:
- qkv_linear.bias.data = torch.rand_like(qkv_linear.bias.data)
-
- xla_qkv_linear = XlaQKVParallelLinear(qkv_linear, mesh=mesh)
-
- qkv_linear = qkv_linear.to(device)
- xla_qkv_linear = xla_qkv_linear.to(device)
- input_tensor = torch.rand(10, 4096, dtype=torch.bfloat16) / 10
- input_tensor = input_tensor.to(device)
-
- output = qkv_linear(input_tensor)
- xla_output = xla_qkv_linear(input_tensor)
- assert torch.allclose(output.cpu(), xla_output.cpu())
diff --git a/tests/v1/tpu/worker/__init__.py b/tests/v1/tpu/worker/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py
deleted file mode 100644
index cfc06666e7984..0000000000000
--- a/tests/v1/tpu/worker/test_tpu_model_runner.py
+++ /dev/null
@@ -1,587 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import pytest
-
-from vllm.attention.layer import Attention
-from vllm.config import (
- CacheConfig,
- ModelConfig,
- SchedulerConfig,
- VllmConfig,
- set_current_vllm_config,
-)
-from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import SamplingParams
-from vllm.utils.mem_constants import GiB_bytes
-from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
-from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
-from vllm.v1.worker.tpu_model_runner import (
- TPUModelRunner,
- _get_padded_num_reqs_with_upper_limit,
- _get_padded_token_len,
- _get_req_paddings,
- _get_token_paddings,
-)
-
-
-def get_vllm_config():
- model_config = ModelConfig(
- model="facebook/opt-125m",
- dtype="bfloat16", # TPUs typically use bfloat16
- seed=42,
- )
- scheduler_config = SchedulerConfig(
- max_num_seqs=10,
- max_num_batched_tokens=512,
- max_model_len=512,
- is_encoder_decoder=model_config.is_encoder_decoder,
- )
- cache_config = CacheConfig(
- block_size=16,
- gpu_memory_utilization=0.9,
- swap_space=0,
- cache_dtype="auto",
- )
- vllm_config = VllmConfig(
- model_config=model_config,
- cache_config=cache_config,
- scheduler_config=scheduler_config,
- )
- return vllm_config
-
-
-def get_model_runner(vllm_config):
- device = "xla:0" # Mocking TPU device
- return TPUModelRunner(vllm_config, device)
-
-
-@pytest.fixture
-def model_runner():
- # Patchers have already been started at module level.
- vllm_config = get_vllm_config()
- return get_model_runner(vllm_config)
-
-
-def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
- new_reqs = []
- num_scheduled_tokens = {}
- total_num_scheduled_tokens = 0
- for req_id in req_ids:
- new_reqs.append(
- NewRequestData(
- req_id=req_id,
- prompt_token_ids=[1, 2, 3],
- mm_features=[],
- sampling_params=SamplingParams(),
- pooling_params=PoolingParams(),
- block_ids=([0],), # block_ids should be tuple[list[int]]
- num_computed_tokens=0,
- lora_request=None,
- )
- )
- num_scheduled_tokens[req_id] = 3
- total_num_scheduled_tokens += num_scheduled_tokens[req_id]
-
- return SchedulerOutput(
- scheduled_new_reqs=new_reqs,
- scheduled_cached_reqs=CachedRequestData.make_empty(),
- num_scheduled_tokens=num_scheduled_tokens,
- total_num_scheduled_tokens=total_num_scheduled_tokens,
- scheduled_spec_decode_tokens={},
- scheduled_encoder_inputs={},
- num_common_prefix_blocks=[],
- finished_req_ids=set(),
- free_encoder_mm_hashes=[],
- )
-
-
-def _is_req_scheduled(model_runner, req_id: str) -> bool:
- return req_id in model_runner.input_batch.req_id_to_index
-
-
-def _is_req_added(model_runner, req_id: str) -> bool:
- return req_id in model_runner.requests
-
-
-def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
- """Check if the request state block IDs match the block table.
-
- This function handles both legacy BlockTable and new MultiGroupBlockTable
- structures for backward compatibility.
- """
-
- req_index = model_runner.input_batch.req_id_to_index[req_id]
- multi_group_block_table = model_runner.input_batch.block_table
- req_state = model_runner.requests[req_id]
-
- # Access the first block table from MultiGroupBlockTable
- # This is safe since we currently only use single KV cache groups
- block_table = multi_group_block_table[0]
-
- # req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable
- # Extract the first group's block IDs
- if isinstance(req_state.block_ids[0], list):
- # New format: tuple[list[int], ...] - extract first group
- req_block_ids = req_state.block_ids[0]
- else:
- # Legacy format: list[int] - use directly
- req_block_ids = req_state.block_ids
-
- if block_table.num_blocks_per_row[req_index] != len(req_block_ids):
- return False
-
- num_blocks = block_table.num_blocks_per_row[req_index]
- block_table_values = block_table.block_table.np[req_index, :num_blocks]
- return (block_table_values == req_block_ids).all()
-
-
-def test_update_states_new_request(model_runner):
- req_id = "req_0"
-
- # new req
- scheduler_output = _schedule_new_request(req_id)
-
- model_runner._update_states(scheduler_output)
-
- assert _is_req_added(model_runner, req_id)
- assert _is_req_scheduled(model_runner, req_id)
- assert _is_req_state_block_table_match(model_runner, req_id)
-
-
-def test_update_states_request_finished(model_runner):
- req_id = "req_0"
-
- # new req
- scheduler_output = _schedule_new_request(req_id)
-
- model_runner._update_states(scheduler_output)
- assert _is_req_added(model_runner, req_id)
- assert _is_req_scheduled(model_runner, req_id)
-
- # finish req
- scheduler_output = SchedulerOutput(
- scheduled_new_reqs=[],
- scheduled_cached_reqs=CachedRequestData.make_empty(),
- num_scheduled_tokens={},
- total_num_scheduled_tokens=0,
- scheduled_spec_decode_tokens={},
- scheduled_encoder_inputs={},
- num_common_prefix_blocks=[],
- finished_req_ids={req_id},
- free_encoder_mm_hashes=[],
- )
-
- model_runner._update_states(scheduler_output)
- assert not _is_req_added(model_runner, req_id)
- assert not _is_req_scheduled(model_runner, req_id)
-
-
-def test_update_states_request_resumed(model_runner):
- req_id = "req_0"
-
- # new req
- scheduler_output = _schedule_new_request(req_id)
-
- model_runner._update_states(scheduler_output)
- assert _is_req_added(model_runner, req_id)
- assert _is_req_scheduled(model_runner, req_id)
-
- # unschedule req
- scheduler_output = SchedulerOutput(
- scheduled_new_reqs=[],
- scheduled_cached_reqs=CachedRequestData.make_empty(),
- num_scheduled_tokens={},
- total_num_scheduled_tokens=0,
- scheduled_spec_decode_tokens={},
- scheduled_encoder_inputs={},
- num_common_prefix_blocks=[],
- finished_req_ids=set(),
- free_encoder_mm_hashes=[],
- )
-
- model_runner._update_states(scheduler_output)
- assert _is_req_added(model_runner, req_id)
- assert not _is_req_scheduled(model_runner, req_id)
-
- # resume req
- cached_req_data = CachedRequestData(
- req_ids=[req_id],
- resumed_req_ids={req_id},
- new_token_ids=[[]],
- all_token_ids={req_id: scheduler_output.scheduled_new_reqs[0].prompt_token_ids},
- new_block_ids=[([],)],
- num_computed_tokens=[0],
- num_output_tokens=[0],
- )
-
- scheduler_output = SchedulerOutput(
- scheduled_new_reqs=[],
- scheduled_cached_reqs=cached_req_data,
- num_scheduled_tokens={req_id: 1},
- total_num_scheduled_tokens=1,
- scheduled_spec_decode_tokens={},
- scheduled_encoder_inputs={},
- num_common_prefix_blocks=[],
- finished_req_ids=set(),
- free_encoder_mm_hashes=[],
- )
-
- model_runner._update_states(scheduler_output)
- assert _is_req_added(model_runner, req_id)
- assert _is_req_scheduled(model_runner, req_id)
- assert _is_req_state_block_table_match(model_runner, req_id)
-
-
-def test_update_states_no_changes(model_runner):
- req_id = "req_0"
-
- # new req
- scheduler_output = _schedule_new_request(req_id)
-
- model_runner._update_states(scheduler_output)
- assert _is_req_added(model_runner, req_id)
- assert _is_req_scheduled(model_runner, req_id)
-
- # schedule req
- scheduler_output = SchedulerOutput(
- scheduled_new_reqs=[],
- scheduled_cached_reqs=CachedRequestData.make_empty(),
- num_scheduled_tokens={req_id: 1},
- total_num_scheduled_tokens=1,
- scheduled_spec_decode_tokens={},
- scheduled_encoder_inputs={},
- num_common_prefix_blocks=[],
- finished_req_ids=set(),
- free_encoder_mm_hashes=[],
- )
-
- model_runner._update_states(scheduler_output)
- assert _is_req_added(model_runner, req_id)
- assert _is_req_scheduled(model_runner, req_id)
- assert _is_req_state_block_table_match(model_runner, req_id)
-
-
-def test_update_states_request_unscheduled(model_runner):
- req_ids = ("req_0", "req_1")
-
- # new reqs
- scheduler_output = _schedule_new_request(*req_ids)
-
- model_runner._update_states(scheduler_output)
-
- assert _is_req_added(model_runner, req_ids[0])
- assert _is_req_scheduled(model_runner, req_ids[0])
-
- assert _is_req_added(model_runner, req_ids[1])
- assert _is_req_scheduled(model_runner, req_ids[1])
-
- # unschedule req_1
- scheduler_output = SchedulerOutput(
- scheduled_new_reqs=[],
- scheduled_cached_reqs=CachedRequestData.make_empty(),
- num_scheduled_tokens={req_ids[0]: 1},
- total_num_scheduled_tokens=1,
- scheduled_spec_decode_tokens={},
- scheduled_encoder_inputs={},
- num_common_prefix_blocks=[],
- finished_req_ids=set(),
- free_encoder_mm_hashes=[],
- )
-
- model_runner._update_states(scheduler_output)
-
- assert _is_req_added(model_runner, req_ids[0])
- assert _is_req_scheduled(model_runner, req_ids[0])
-
- assert _is_req_added(model_runner, req_ids[1])
- assert not _is_req_scheduled(model_runner, req_ids[1])
-
-
-def test_get_paddings():
- # Bucketed padding
- min_token_size, max_token_size, padding_gap = 16, 512, 64
- expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
- actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
-
- # Bucketed padding with max_token_size not a power of two.
- max_token_size = 317
- expected_paddings = [16, 32, 64, 128, 192, 256, 320]
- actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
- assert actual_paddings == expected_paddings
-
- # Exponential padding.
- max_token_size, padding_gap = 1024, 0
- expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
- actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
- assert actual_paddings == expected_paddings
- # Exponential padding with max_token_size not a power of two.
- max_token_size = 317
- expected_paddings = [16, 32, 64, 128, 256, 512]
- actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
- assert actual_paddings == expected_paddings
-
-
-def test_get_padded_token_len():
- min_token_size, max_token_size, padding_gap = 16, 512, 64
- paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
- assert _get_padded_token_len(paddings, 1) == 16
- assert _get_padded_token_len(paddings, 16) == 16
- assert _get_padded_token_len(paddings, 20) == 32
- assert _get_padded_token_len(paddings, 300) == 320
- assert _get_padded_token_len(paddings, 512) == 512
-
-
-def test_get_padded_num_reqs_with_upper_limit():
- assert _get_padded_num_reqs_with_upper_limit(3, 32) == 8
- assert _get_padded_num_reqs_with_upper_limit(9, 32) == 16
- assert _get_padded_num_reqs_with_upper_limit(19, 32) == 32
- assert _get_padded_num_reqs_with_upper_limit(17, 28) == 28
-
-
-def test_get_req_paddings():
- assert _get_req_paddings(1, 32) == [8, 16, 32]
- assert _get_req_paddings(8, 32) == [8, 16, 32]
- assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
-
-
-def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(model_runner):
- layer_0 = "model.layers.0.self_attn.attn"
- layer_1 = "model.layers.1.self_attn.attn"
- error_msg = f"{layer_1} must come before the current layer"
- vllm_config = model_runner.vllm_config
- with (
- pytest.raises(ValueError, match=error_msg),
- set_current_vllm_config(vllm_config),
- ):
- fwd_context = {
- # initialization below will fail because target layer is invalid;
- # the target layer needs to come before layer 1
- layer_0: Attention(
- num_heads=8,
- head_size=128,
- scale=1.0,
- prefix=layer_0,
- kv_sharing_target_layer_name=layer_1,
- ),
- layer_1: Attention(
- num_heads=8,
- head_size=128,
- scale=1.0,
- prefix=layer_1,
- ),
- }
- # suppress var not used error
- assert fwd_context is not None
-
-
-def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner):
- layer_0 = "model.layers.0.self_attn.attn"
- layer_1 = "model.layers.1.self_attn.attn"
- invalid_layer = "model.layers.0.cross_attn.attn"
- error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
- vllm_config = model_runner.vllm_config
- with (
- pytest.raises(ValueError, match=error_msg),
- set_current_vllm_config(vllm_config),
- ):
- fwd_context = {
- layer_0: Attention(
- num_heads=8,
- head_size=128,
- scale=1.0,
- prefix=layer_0,
- ),
- layer_1: Attention(
- num_heads=8,
- head_size=128,
- scale=1.0,
- prefix=layer_1,
- # invalid layer: cross_attn.atn doesn't exist!
- kv_sharing_target_layer_name=invalid_layer,
- ),
- }
- # suppress var not used error
- assert fwd_context is not None
-
-
-def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner):
- layer_0 = "model.layers.0.self_attn.attn"
- layer_1 = "model.layers.1.self_attn.attn"
- error_msg = f"{layer_1} cannot be the same as the current layer"
- vllm_config = model_runner.vllm_config
- with (
- pytest.raises(ValueError, match=error_msg),
- set_current_vllm_config(vllm_config),
- ):
- fwd_context = {
- # initialization below will fail because target layer is invalid;
- # the target layer needs to come before layer 1
- layer_0: Attention(
- num_heads=8,
- head_size=128,
- scale=1.0,
- prefix=layer_0,
- ),
- layer_1: Attention(
- num_heads=8,
- head_size=128,
- scale=1.0,
- prefix=layer_1,
- kv_sharing_target_layer_name=layer_1,
- ),
- }
- # suppress var not used error
- assert fwd_context is not None
-
-
-def test_init_kv_cache_without_kv_sharing():
- layer_0 = "model.layers.0.self_attn.attn"
- layer_1 = "model.layers.1.self_attn.attn"
- vllm_config = get_vllm_config()
- with set_current_vllm_config(vllm_config):
- fwd_context = {
- layer_0: Attention(
- num_heads=8,
- head_size=128,
- scale=1.0,
- prefix=layer_0,
- ),
- layer_1: Attention(
- num_heads=8,
- head_size=128,
- scale=1.0,
- prefix=layer_1,
- ),
- }
- # suppress var not used error
- assert fwd_context is not None
- # Set high context length to test max context length estimation
- vllm_config.model_config.max_model_len = 1_000_000
- vllm_ctx = vllm_config.compilation_config.static_forward_context
- model_runner = get_model_runner(vllm_config)
- kv_cache_spec = model_runner.get_kv_cache_spec()
- assert len(kv_cache_spec) == 2
- assert len(model_runner.shared_kv_cache_layers) == 0
-
- available_memory = 20 * GiB_bytes
- # page size for each layer KV can be calculated as
- # 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
- # * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
- num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers)
- kv_cache_config = get_kv_cache_configs(
- vllm_config, [kv_cache_spec], [available_memory]
- )[0]
- assert kv_cache_config.num_blocks == num_expected_blocks
- assert len(kv_cache_config.kv_cache_tensors) == 2
- assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
- assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2
-
- max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
- # max context len with KV sharing should be 2x as large as without
- # max_context_len = available_memory / (page_size / block_size) / num_caches
- # max_context_len = 5GB / (512KB / 128) / 2 = 655360
- assert max_context_len == 655360
-
- # important: override tensor size to prevent large mem alloc during test
- # this will only allocate 2 block worth of memory (2 * 512kb)
- kv_cache_config.num_blocks = 1
- for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
- kv_cache_tensor.size = kv_cache_spec[
- kv_cache_tensor.shared_by[0]
- ].page_size_bytes
-
- model_runner.initialize_kv_cache(kv_cache_config)
-
- layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
- layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
- # check layer 1 kv cache does NOT share memory with layer 0
- assert id(layer_1_kv) != id(layer_0_kv)
-
- # check layer 1 added to kv cache group's layer names
- assert len(kv_cache_config.kv_cache_groups) == 1
- assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
- assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
- assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
-
-
-def test_init_kv_cache_with_kv_sharing_valid():
- layer_0 = "model.layers.0.self_attn.attn"
- layer_1 = "model.layers.1.self_attn.attn"
- vllm_config = get_vllm_config()
- with set_current_vllm_config(vllm_config):
- fwd_context = {
- layer_0: Attention(
- num_heads=8,
- head_size=128,
- scale=1.0,
- prefix=layer_0,
- ),
- layer_1: Attention(
- num_heads=8,
- head_size=128,
- scale=1.0,
- prefix=layer_1,
- kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
- ),
- }
- # suppress var not used error
- assert fwd_context is not None
- # Set high context length to test max context length estimation
- vllm_config.model_config.max_model_len = 3_000_000
- vllm_ctx = vllm_config.compilation_config.static_forward_context
- model_runner = get_model_runner(vllm_config)
- kv_cache_spec = model_runner.get_kv_cache_spec()
- assert len(kv_cache_spec) == 1
- assert layer_0 in kv_cache_spec
- assert model_runner.shared_kv_cache_layers[layer_1] == layer_0
-
- available_memory = 20 * GiB_bytes
- # page size for layer 0's kv_cache_spec is 512KB
- # with KV sharing, we can allocate (available_mem//page_size//1) blocks
- # which is twice as many as without KV sharing
- num_expected_blocks = 2 * 20480 # 20GB / 512KB
- kv_cache_config = get_kv_cache_configs(
- vllm_config, [kv_cache_spec], [available_memory]
- )[0]
- assert kv_cache_config.num_blocks == num_expected_blocks
- assert len(kv_cache_config.kv_cache_tensors) == 1
- # Each layer now has twice the available memory for KV cache
- # compared to no KV sharing
- assert kv_cache_config.kv_cache_tensors[0].size == available_memory
-
- max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
- # max context len with KV sharing should be 2x as large as without
- assert max_context_len == (2 * 655360)
-
- # important: override tensor size to prevent large mem alloc during test
- # this will only allocate 1 block worth of memory (512kb)
- kv_cache_config.num_blocks = 1
- kv_cache_config.kv_cache_tensors[0].size = kv_cache_spec[layer_0].page_size_bytes
-
- model_runner.initialize_kv_cache(kv_cache_config)
-
- layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
- layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
- # check layer 1 kv cache shares memory with layer 0
- assert id(layer_1_kv) == id(layer_0_kv)
-
- # check layer 1 added to kv cache group's layer names
- assert len(kv_cache_config.kv_cache_groups) == 1
- assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
- assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
- assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
-
-
-def test_most_model_len(monkeypatch: pytest.MonkeyPatch):
- monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048")
- vllm_config = get_vllm_config()
- vllm_config.model_config.max_model_len = 32000
- vllm_config.scheduler_config.max_num_seqs = 1200
- model_runner = get_model_runner(vllm_config)
-
- # verify model runner will adjust num_reqs to avoid SMEM OOM.
- assert model_runner.num_reqs_most_model_len == 1200
- # num_page_per_req = 32k // 128
- # num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
- assert model_runner.num_reqs_max_model_len == 524
diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
index 416b996df9f22..77724a3a1915c 100644
--- a/vllm/attention/backends/registry.py
+++ b/vllm/attention/backends/registry.py
@@ -66,7 +66,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
)
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
- PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
diff --git a/vllm/attention/layers/mm_encoder_attention.py b/vllm/attention/layers/mm_encoder_attention.py
index 1c1623b13f55a..138fc99114127 100644
--- a/vllm/attention/layers/mm_encoder_attention.py
+++ b/vllm/attention/layers/mm_encoder_attention.py
@@ -227,28 +227,3 @@ class MMEncoderAttention(CustomOp):
"XPU only supports FLASH_ATTN for vision attention."
)
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
-
- def forward_tpu(
- self,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- cu_seqlens: torch.Tensor | None = None,
- max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
- ) -> torch.Tensor:
- assert self.attn_backend == AttentionBackendEnum.PALLAS, (
- f"MMEncoderAttention on TPU only supports PALLAS backend, "
- f"but got {self.attn_backend}."
- )
- if cu_seqlens is None:
- query, key, value = (x.transpose(1, 2) for x in (query, key, value))
- from torch_xla.experimental.custom_kernel import flash_attention
-
- out = flash_attention(query, key, value, sm_scale=self.scale)
- out = out.transpose(1, 2)
- return out
- logger.warning_once(
- "PALLAS backend with cu_seqlens is not supported for ViT yet. ",
- "Falling back to SDPA implementation.",
- )
- return self._forward_sdpa(query, key, value, cu_seqlens)
diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py
deleted file mode 100644
index fa99078e9ff0d..0000000000000
--- a/vllm/distributed/device_communicators/tpu_communicator.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import os
-
-import torch
-from torch.distributed import ProcessGroup
-
-from vllm.config import get_current_vllm_config
-from vllm.logger import init_logger
-from vllm.platforms import current_platform
-from vllm.platforms.tpu import USE_TPU_INFERENCE
-
-from .base_device_communicator import DeviceCommunicatorBase
-
-USE_RAY = parallel_config = (
- get_current_vllm_config().parallel_config.distributed_executor_backend == "ray"
-)
-
-logger = init_logger(__name__)
-
-if not USE_TPU_INFERENCE:
- logger.info("tpu_inference not found, using vLLM's TpuCommunicator")
- if current_platform.is_tpu():
- import torch_xla
- import torch_xla.core.xla_model as xm
- import torch_xla.runtime as xr
- from torch_xla._internal import pjrt
- from torch_xla.distributed.xla_multiprocessing import (
- create_optimized_replica_groups,
- )
-
- if USE_RAY:
- from vllm.v1.executor import ray_utils
-
-
-class TpuCommunicator(DeviceCommunicatorBase):
- def __init__(
- self,
- cpu_group: ProcessGroup,
- device: torch.device | None = None,
- device_group: ProcessGroup | None = None,
- unique_name: str = "",
- ):
- super().__init__(cpu_group, device, device_group, unique_name)
-
- # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
- # must be used together. Therefore, the local rank and world size can
- # be simply calculated as follows.
- global_rank = self.global_rank
- global_world_size = self.global_world_size
-
- if USE_RAY:
- logger.info("TpuCommunicator initialized with RAY")
- # Calculate how many TPU nodes are in the current deployment. This
- # is the Ray placement group if it is deployed with Ray. Default
- # to the number of TPU nodes in the Ray cluster. The number of TPU
- # nodes is computed by the total number of TPUs divided by the
- # number of TPU accelerators per node, to account for clusters
- # with both CPUs and TPUs.
- num_nodes = ray_utils.get_num_tpu_nodes()
- num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
- if num_nodes_in_pg > 0:
- num_nodes = num_nodes_in_pg
-
- local_world_size = global_world_size // num_nodes
- local_rank = global_rank % local_world_size
- else:
- logger.info("TpuCommunicator initialized with MP")
- # Sanity: Verify we run on a single host
- num_hosts = torch_xla.tpu.num_tpu_workers()
- assert num_hosts == 1
-
- # Get the current number of TPUs (we have locally)
- local_world_size = torch_xla.tpu.num_available_chips()
-
- # Get current rank
- local_rank = global_rank % local_world_size
-
- # Ensure environment variables are set for multihost deployments.
- # On GKE, this is needed for libtpu and TPU driver to know which TPU
- # chip is actually visible. Otherwise the TPU driver will fail to
- # initialize because the number of devices would be different from
- # the number of visible worker addresses.
- os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
- os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
-
- pjrt.initialize_multiprocess(local_rank, local_world_size)
- xr._init_world_size_ordinal()
- self.groups = create_optimized_replica_groups()
-
- def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
- # TODO: Remove the groups specification after XLA compiler can support
- # auto-reordering the ring order for all-reduce.
- return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)
-
- def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
- assert dim == -1, "TPUs only support dim=-1 for all-gather."
- return xm.all_gather(input_, dim=dim)
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index 4f1ea1a0240c4..914ab91b1563c 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Literal
import torch
from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger
@@ -251,9 +250,6 @@ class TpKVTopology:
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
- attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
- self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
-
@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
@@ -261,7 +257,7 @@ class TpKVTopology:
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
- return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first)
+ return not (self.is_mla or self.is_kv_layout_blocks_first)
@property
def tp_size(self) -> int:
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py
index 9a15d3fa6ed09..38ce02a2fef76 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py
@@ -499,7 +499,6 @@ class MooncakeConnectorWorker:
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend,
)
- self._use_pallas = self.kv_topo._use_pallas
self.zmq_ctx = zmq.Context()
self.async_zmq_ctx = zmq.asyncio.Context()
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 757ca41e9844b..0f33cde7d3221 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -983,7 +983,6 @@ class NixlConnectorWorker:
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend,
)
- self._use_pallas = self.kv_topo._use_pallas
self._physical_blocks_per_logical_kv_block = 1
def _nixl_handshake(
@@ -1641,9 +1640,6 @@ class NixlConnectorWorker:
# Num kv_heads > tp_size and P TP > D TP case, not supported
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
- assert not self._use_pallas or tp_ratio == 1, (
- "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
- )
kv_cache_layout = (
self.kv_cache_layout
if not self.use_host_buffer
@@ -1814,9 +1810,7 @@ class NixlConnectorWorker:
if len(self.device_kv_caches) == 0:
return
- split_k_and_v = not (
- self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first
- )
+ split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)
sample_cache = list(self.device_kv_caches.values())[0][0]
for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
assert block_size_ratio > 1, "Only nP < nD supported currently."
diff --git a/vllm/distributed/tpu_distributed_utils.py b/vllm/distributed/tpu_distributed_utils.py
deleted file mode 100644
index 4ff1f0ce4410a..0000000000000
--- a/vllm/distributed/tpu_distributed_utils.py
+++ /dev/null
@@ -1,188 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections import OrderedDict
-from typing import Optional
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch_xla.distributed.spmd as xs
-from torch.nn.parameter import Parameter
-
-from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import (
- ColumnParallelLinear,
- QKVParallelLinear,
- RowParallelLinear,
-)
-
-logger = init_logger(__name__)
-
-
-class XlaQKVParallelLinear(nn.Module):
- def __init__(self, qkv_linear: nn.Module, mesh: Optional["xs.Mesh"] = None):
- super().__init__()
- assert isinstance(qkv_linear, QKVParallelLinear)
- self.skip_bias_add = qkv_linear.skip_bias_add
- self.return_bias = qkv_linear.return_bias
- assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD."
-
- self.q_weight: Parameter
- self.k_weight: Parameter
- self.v_weight: Parameter
- self.q_bias: Parameter | None
- self.k_bias: Parameter | None
- self.v_bias: Parameter | None
- self._load_weights_from_qkv_linear(qkv_linear)
- if mesh is not None:
- self._shard_weight(mesh)
-
- def _shard_weight(self, mesh: "xs.Mesh"):
- self.q_weight = Parameter(self.q_weight.to("xla"), requires_grad=False)
- self.k_weight = Parameter(self.k_weight.to("xla"), requires_grad=False)
- self.v_weight = Parameter(self.v_weight.to("xla"), requires_grad=False)
- xs.mark_sharding(self.q_weight, mesh, ("x", None))
- xs.mark_sharding(self.k_weight, mesh, ("x", None))
- xs.mark_sharding(self.v_weight, mesh, ("x", None))
- if self.q_bias is not None:
- assert self.k_bias is not None and self.v_bias is not None, (
- "QKVParallelLinear should have q, k, and v biases together."
- )
- self.q_bias = Parameter(self.q_bias.to("xla"), requires_grad=False)
- xs.mark_sharding(self.q_bias, mesh, ("x",))
- self.k_bias = Parameter(self.k_bias.to("xla"), requires_grad=False)
- xs.mark_sharding(self.k_bias, mesh, ("x",))
- self.v_bias = Parameter(self.v_bias.to("xla"), requires_grad=False)
- xs.mark_sharding(self.v_bias, mesh, ("x",))
-
- def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module):
- q_proj_size, k_proj_size, _ = qkv_linear.output_sizes
- # The weight of qkv linear is a concatenation of q, k, and v weights
- # along the output dimension.
- qkv_weight = qkv_linear.weight.data.cpu()
- q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False)
- k_weight = Parameter(
- qkv_weight[q_proj_size : q_proj_size + k_proj_size], requires_grad=False
- )
- v_weight = Parameter(
- qkv_weight[q_proj_size + k_proj_size :], requires_grad=False
- )
- self.register_parameter("q_weight", q_weight)
- self.register_parameter("k_weight", k_weight)
- self.register_parameter("v_weight", v_weight)
-
- if qkv_linear.bias is not None:
- q_bias = Parameter(qkv_linear.bias[:q_proj_size], requires_grad=False)
- k_bias = Parameter(
- qkv_linear.bias[q_proj_size : q_proj_size + k_proj_size],
- requires_grad=False,
- )
- v_bias = Parameter(
- qkv_linear.bias[q_proj_size + k_proj_size :], requires_grad=False
- )
- self.register_parameter("q_bias", q_bias)
- self.register_parameter("k_bias", k_bias)
- self.register_parameter("v_bias", v_bias)
- else:
- self.register_parameter("q_bias", None)
- self.register_parameter("k_bias", None)
- self.register_parameter("v_bias", None)
-
- def forward(self, input):
- # Same forward functionality as QKVParallelLinear, but doing qkv porj
- # separately.
- q_bias = self.q_bias if not self.skip_bias_add else None
- k_bias = self.k_bias if not self.skip_bias_add else None
- v_bias = self.v_bias if not self.skip_bias_add else None
- q_proj = F.linear(input, self.q_weight, q_bias)
- k_proj = F.linear(input, self.k_weight, k_bias)
- v_proj = F.linear(input, self.v_weight, v_bias)
- # The q/k/v projections will be split outside of the QKVParallelLinear.
- # Because we are replacing XlaQKVParallelLinear with the
- # QKVParallelLinear, we need to concatenate q, k, and v projections to
- # match the output shape of the QKVParallelLinear implementation even if
- # it seems to be redundant.
- # The concat and the following split will be noop, and should be
- # optimized away by the compiler.
- qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1)
- output_bias = (
- torch.cat([q_bias, k_bias, v_bias], dim=-1) if self.skip_bias_add else None
- )
- if not self.return_bias:
- return qkv_proj
- return qkv_proj, output_bias
-
-
-def partition_column_parallel_linear(
- layer: torch.nn.Module, mesh: xs.Mesh
-) -> torch.nn.Module:
- assert isinstance(layer, ColumnParallelLinear)
- xs.mark_sharding(layer.weight, mesh, ("x", None))
- logger.debug("Applied column-parallel sharding to %s", layer)
- return layer
-
-
-def partition_row_parallel_linear(
- layer: torch.nn.Module, mesh: xs.Mesh
-) -> torch.nn.Module:
- assert isinstance(layer, RowParallelLinear)
- xs.mark_sharding(layer.weight, mesh, (None, "x"))
- logger.debug("Applied row-parallel sharding to %s", layer)
- return layer
-
-
-def partition_qkv_parallel_linear(
- layer: torch.nn.Module, mesh: xs.Mesh
-) -> torch.nn.Module:
- assert isinstance(layer, QKVParallelLinear)
- xla_layer = XlaQKVParallelLinear(layer, mesh)
- logger.debug("Applied qkv parallel sharding to %s", layer)
- return xla_layer
-
-
-MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict(
- [
- ("QKVParallelLinear", partition_qkv_parallel_linear),
- ("ColumnParallelLinear", partition_column_parallel_linear),
- ("RowParallelLinear", partition_row_parallel_linear),
- ]
-)
-
-
-def get_fqn(module):
- # Get the fully qualified name of the module
- return module.__class__.__qualname__
-
-
-def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None:
- """
- Recursively check a PyTorch model and apply appropriate sharding based on
- the MODULE_TYPE_TO_WRAPPING_FUNC mapping.
-
- Args:
- model: torch.nn.Module to process
- mesh: An XLA SPMD mesh object used for sharding
- """
-
- def _process_module(module, name=None, parent=None):
- for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items():
- if get_fqn(module) == module_type:
- wrapped_module = wrapping_func(module, mesh)
-
- assert parent is not None and name is not None, (
- "Top Level module is not expected to be wrapped."
- )
- if wrapped_module is not module:
- # Wrapped module and module are different py object.
- # The original module should be replaced by the
- # wrapped_module.
- logger.debug("replace %s with %s", module, wrapped_module)
- setattr(parent, name, wrapped_module)
-
- module = wrapped_module
- break
-
- for child_name, child_module in list(module.named_children()):
- _process_module(child_module, child_name, module)
-
- _process_module(model)
diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py
deleted file mode 100644
index b5570ceca68ca..0000000000000
--- a/vllm/lora/ops/xla_ops/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
-
-__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]
diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py
deleted file mode 100644
index 4924890b388cb..0000000000000
--- a/vllm/lora/ops/xla_ops/lora_ops.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import jax
-import jax.numpy as jnp
-import torch
-import torch.nn.functional as F
-import torch_xla.core.xla_builder as xb
-from torch.library import impl
-from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard
-
-
-@jax.jit
-def bgmv_jax(inputs, loras, idxs):
- return jnp.einsum(
- "td,tX,Xld->tl",
- inputs,
- jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
- loras,
- )
-
-
-XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")
-
-
-@impl(XLA_LIB, "bgmv", "XLA")
-def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
- if len(loras.shape) == 4:
- loras = loras.squeeze(axis=1)
-
- jax_import_guard()
- return xb.call_jax(bgmv_jax, (inputs, loras, idxs))
-
-
-@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
-def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
- T, _ = inputs.shape
- if len(loras.shape) == 4:
- loras = loras.squeeze(axis=1)
- _, L, _ = loras.shape
-
- return torch.empty((T, L), device=inputs.device)
-
-
-def bgmv_expand(
- inputs: torch.Tensor,
- lora_b_weights: torch.Tensor,
- output_tensor: torch.Tensor,
- lora_indices_tensor: torch.Tensor,
- add_inputs: bool = True,
-):
- """
- Args:
- inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
-
- lora_b_weights (torch.Tensor): LoRA weights of shape
- [num_loras, lora_rank, hidden_size].
-
- output_tensor (torch.Tensor): output tensor of shape
- [num_tokens, hidden_size * num_slices].
-
- lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
- indicating which LoRA matrix to use for each token.
- add_inputs (bool): Whether or not to add the input tensor to the output
- tensor.
- """
-
- outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
-
- limit = output_tensor.shape[0]
- if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
- limit = 1
-
- if output_tensor.shape[1] > outputs.shape[1]:
- outputs = F.pad(outputs, (0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
-
- if add_inputs:
- return output_tensor + outputs[:limit, : output_tensor.shape[1]]
- else:
- return outputs[:limit, : output_tensor.shape[1]]
-
-
-def bgmv_shrink(
- inputs: torch.Tensor,
- lora_b_weights: torch.Tensor,
- lora_indices_tensor: torch.Tensor,
- scaling: float = 1.0,
-):
- """
- Args:
- inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
- lora_b_weights (torch.Tensor): LoRA weights of shape
- [num_loras, lora_rank, hidden_size].
- lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
- indicating which LoRA matrix to use for each token.
- scaling (float, optional): Scalar multiplier applied to the output.
- """
-
- return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
-
-
-def bgmv_expand_slice(
- inputs: torch.Tensor,
- lora_b_weights: torch.Tensor,
- output_tensor: torch.Tensor,
- lora_indices_tensor: torch.Tensor,
- slice_offset: int,
- slice_size: int,
- add_inputs: bool = True,
-):
- """
- Args:
- inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
-
- lora_b_weights (torch.Tensor): LoRA weights of shape
- [num_loras, lora_rank, hidden_size].
-
- output_tensor (torch.Tensor): output tensor of shape
- [num_tokens, hidden_size * num_slices].
-
- lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
- indicating which LoRA matrix to use for each token.
- add_inputs (bool): Whether or not to add the input tensor to the output
- tensor.
- """
- outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
-
- outputs = F.pad(
- outputs,
- (
- slice_offset,
- output_tensor.shape[1] - (slice_offset + slice_size),
- 0,
- 0,
- ),
- )
-
- if add_inputs:
- return output_tensor + outputs
- else:
- return outputs
diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py
deleted file mode 100644
index 0888772db54e7..0000000000000
--- a/vllm/lora/punica_wrapper/punica_tpu.py
+++ /dev/null
@@ -1,358 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import math
-from typing import TYPE_CHECKING
-
-import torch
-import torch.nn.functional as F
-import torch_xla
-
-from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
-from vllm.lora.punica_wrapper.utils import convert_mapping
-
-if TYPE_CHECKING:
- # avoid circuit import
- from vllm.lora.layers import LoRAMapping
-
-from .punica_base import PunicaWrapperBase
-
-
-class PunicaWrapperTPU(PunicaWrapperBase):
- """
- PunicaWrapperTPU is designed to manage and provide metadata for the punica
- kernel. The main function is to maintain the state information for
- Multi-LoRA, and to provide the interface for the pytorch punica ops.
- """
-
- def __init__(
- self,
- max_num_batched_tokens: int,
- max_batches: int,
- device: torch.device | str,
- **kwargs,
- ):
- PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
-
- # PunicaWrapperBase defines some tensors with dtype=torch.int64, which
- # isn't supported by the TPU. So convert those tensors to int32.
- # Not all of them are used by the TPU so only convert the useful ones.
- self._token_lora_indices = self._token_lora_indices.to(dtype=torch.int32)
- self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
- self._sampler_indices_padded = self._sampler_indices_padded.to(
- dtype=torch.int32
- )
-
- torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True)
- torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True)
- torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, True)
- torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
- torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, True)
-
- torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
- torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
- torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
-
- def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
- return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
-
- @property
- def embeddings_indices(self) -> torch.Tensor:
- """
- This property provides access to the indices used for lora embeddings,
- specifically for VocabParallelEmbeddingWithLoRA.
- """
- return self._embeddings_indices[:]
-
- @property
- def sampler_indices_padded(self) -> torch.Tensor:
- """
- This property provides access to padded sampler indices.
- """
- return self._sampler_indices_padded[:]
-
- def shrink(
- self,
- x: torch.Tensor,
- w_t_all: torch.Tensor,
- scale: float,
- ):
- return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale)
-
- def expand(
- self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, add_inputs: bool
- ):
- return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), add_inputs)
-
- def expand_slice(
- self,
- y: torch.Tensor,
- x: torch.Tensor,
- w_t_all: torch.Tensor,
- y_offset: int,
- y_slice_size: int,
- add_inputs: bool,
- ) -> torch.Tensor:
- return bgmv_expand_slice(
- x,
- w_t_all,
- y,
- self._get_token_lora_indices(x),
- y_offset,
- y_slice_size,
- add_inputs,
- )
-
- def add_shrink(
- self,
- y: tuple[torch.Tensor, ...] | torch.Tensor,
- x: torch.Tensor,
- lora_a_stacked: tuple[torch.Tensor, ...],
- scale: float,
- **kwargs,
- ) -> torch.Tensor | None:
- """
- Performs GEMM for multiple slices of lora_a.
-
- Semantics:
- for i in range(len(lora_a_stacked)):
- y[i] += (x @ lora_a_stacked[i]) * scale
-
- Args:
- y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
- x (torch.Tensor): Input tensor
- lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
- scale (float): Scaling factor for the operation
- """
-
- torch.ops.xla.dynamo_set_buffer_donor_(y, True)
- x = x.view(-1, x.shape[-1])
-
- for slice_idx in range(len(lora_a_stacked)):
- lora_s = lora_a_stacked[slice_idx]
- y_s = self.shrink(x, lora_s, scale)
- y[slice_idx, :, :] = y_s # type: ignore[index]
- return y
-
- def add_expand(
- self,
- y: torch.Tensor,
- x: tuple[torch.Tensor, ...] | torch.Tensor,
- lora_b_stacked: tuple[torch.Tensor, ...],
- output_slices: tuple[int, ...],
- offset_start: int = 0,
- add_inputs=True,
- **kwargs,
- ) -> torch.Tensor:
- """
- Performs GEMM for multiple slices of lora_b.
-
- Semantics:
- for i in range(len(lora_b_stacked)):
- slice = output_slices[i]
- y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
- offset += slice
-
- Args:
- y (torch.Tensor): Output tensor.
- x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
- lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
- output_slices (tuple[int, ...]): Every slice's size
- add_inputs (bool): Defaults to True.
- """
- y_org = y
- y = y.view(-1, y.shape[-1])
- offset_left = 0
-
- for slice_idx in range(len(lora_b_stacked)):
- y = self.expand_slice(
- y,
- x[slice_idx],
- lora_b_stacked[slice_idx],
- offset_left,
- output_slices[slice_idx],
- add_inputs=add_inputs,
- )
- offset_left += output_slices[slice_idx]
- return y.view_as(y_org)
-
- def add_lora_embedding(
- self,
- y: torch.Tensor,
- x: torch.Tensor,
- lora_b_stacked: torch.Tensor,
- add_inputs: bool = True,
- **kwargs,
- ) -> torch.Tensor:
- """
- Applies lora specifically for VocabParallelEmbeddingWithLoRA.
-
- Semantics:
- y += x @ lora_b_stacked
-
- Args:
- y (torch.Tensor): Output tensor.
- x (torch.Tensor): Input tensor.
- lora_b_stacked (torch.Tensor): lora_b's weights.
- add_inputs (bool): Default to True.
- """
-
- # Embedding layer only needs the expand op
- return self.expand(y, x, lora_b_stacked, add_inputs)
-
- def add_lora_linear(
- self,
- y: torch.Tensor,
- x: torch.Tensor,
- lora_a_stacked: tuple[torch.Tensor, ...],
- lora_b_stacked: tuple[torch.Tensor, ...],
- scale: float,
- output_slices: tuple[int, ...],
- *,
- buffer: tuple[torch.Tensor, ...] | None = None,
- **kwargs,
- ) -> torch.Tensor:
- """
- Applicable to linear-related lora.
-
- Semantics:
- for i in range(len(lora_a_stacked)):
- y[i] += (
- x[i].unsqueeze(0)
- @ lora_a_stacked[indices[i], layer_idx, :, :]
- @ lora_b_stacked[indices[i], layer_idx, :, :]
- * scale
- ).squeeze(0)
-
- Args:
- y (torch.Tensor): Output tensor. Will not be changed in-place.
- x (torch.Tensor): Input tensor (T, E)
- lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
- lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
- scale (float): Scaling factor.
- output_slices (tuple[int, ...]): Every slice's size.
- buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
- """
-
- assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
-
- if buffer is None:
- r = lora_b_stacked[0].size(-1)
- T = x.size(0)
- buffer = torch.zeros(
- (len(output_slices), T, r),
- dtype=x.dtype,
- device=x.device,
- )
- buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
- return self.add_expand(
- y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
- )
-
- def add_lora_logits(
- self,
- y: torch.Tensor,
- x: torch.Tensor,
- lora_a_stacked: torch.Tensor,
- lora_b_stacked: torch.Tensor,
- scale,
- *,
- buffer: torch.Tensor | None = None,
- **kwargs,
- ) -> torch.Tensor:
- """
- Applies lora specifically for LogitsProcessorWithLoRA.
-
- Semantics:
- buffer = (x @ lora_a_stacked) * scale
- y += buffer @ lora_b_stacked
-
- Args:
- y (torch.Tensor): Output tensor.
- x (torch.Tensor): Input tensor.
- lora_a_stacked (torch.Tensor): lora_a's weights.
- lora_b_stacked (torch.Tensor):lora_b's weights.
- scale (float): Scaling factor.
- buffer (Optional[torch.Tensor]):Default to None.
- """
- y_org = y
- y = y.view(-1, y.shape[-1])
- x = x.view(-1, x.shape[-1])
-
- sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
- buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale)
- y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
- return y.view_as(y_org)
-
- # This performs the same tensor ops as the base method, except it does them
- # on the CPU then transfers the results to the TPU
- def _update_base_metadata(
- self,
- mapping: "LoRAMapping",
- lora_index_to_id: list[int | None],
- max_loras: int,
- vocab_size: int,
- ):
- # Make sure we don't accidentally collect outside operations
- torch_xla.sync()
-
- # Pad the prompt mapping to avoid running into recompiles on the TPU
- # TODO: Should this happen inside mapping internally? If so how can we
- # avoid having backend specific LoRAMapping classes?
- mapping.prompt_mapping = self._pad_prompt_mapping(mapping.prompt_mapping)
-
- (
- base_indices,
- sampler_indices,
- sampler_indices_padded,
- embeddings_indices,
- indices_len,
- ) = convert_mapping(
- mapping,
- lora_index_to_id,
- max_loras,
- vocab_size,
- 0, # extra_vocab_size
- "cpu",
- )
- self._token_lora_indices = self._pad_to_shape(
- base_indices, self._token_lora_indices.shape, dims=1
- ).to(self.device)
- self._sampler_indices = self._pad_to_shape(
- sampler_indices, self._sampler_indices.shape, dims=1
- ).to(self.device)
- self._sampler_indices_padded = self._pad_to_shape(
- sampler_indices_padded, self._sampler_indices_padded.shape, dims=1
- ).to(self.device)
- self._embeddings_indices = self._pad_to_shape(
- embeddings_indices, self._embeddings_indices.shape, dims=2
- ).to(self.device)
- self.indices_len[:] = indices_len
-
- def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None:
- self.batch_size = 1
- self._lora_indices_per_batch[: self.batch_size] = token_lora_tensor[
- : self.batch_size
- ]
-
- def _pad_prompt_mapping(self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
- num_reqs = len(prompt_mapping)
-
- # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
- # import
- MIN_NUM_SEQS = 8
-
- padded_num_reqs = max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
- pad_len = padded_num_reqs - num_reqs
-
- padding = [-1] * pad_len
- return tuple(list(prompt_mapping) + padding)
-
- def _pad_to_shape(self, src, target_shape, dims=1):
- if dims == 1:
- pad_len = target_shape[0] - src.shape[0]
- return F.pad(src, (0, pad_len), value=0).to(torch.int32)
- else:
- pad_rows = target_shape[0] - src.shape[0]
- pad_cols = target_shape[1] - src.shape[1]
- return F.pad(src, (0, pad_cols, 0, pad_rows), value=0).to(torch.int32)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 2e7267d56d838..559f1a87d9777 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -67,21 +67,15 @@ else:
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
-from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
- rocm_aiter_grouped_topk,
-)
-
-if current_platform.is_tpu():
- from .moe_pallas import fused_moe as fused_moe_pallas
-else:
- fused_moe_pallas = None # type: ignore
-
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
)
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
+ rocm_aiter_grouped_topk,
+)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py
deleted file mode 100644
index 66c00cf89873a..0000000000000
--- a/vllm/model_executor/layers/fused_moe/moe_pallas.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import torch
-import torch.nn.functional as F
-
-
-def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
- """
- Compute the histogram of an int32 tensor. The bin edges are defined by the
- min and max values, with step = 1.
- """
- assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
- assert min <= max, "min must be less than or equal to max."
-
- def searchsorted(
- sorted_sequence: torch.Tensor, values_to_search: torch.Tensor
- ) -> torch.Tensor:
- return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)
-
- bin_edges = torch.linspace(min, max, max - min + 1, dtype=input.dtype).to(
- input.device
- )
- return searchsorted(bin_edges, input).to(torch.int32)
-
-
-def fused_moe(
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- gating_output: torch.Tensor,
- topk: int,
- global_num_experts: int,
- expert_map: torch.Tensor = None,
- renormalize: bool = False,
-) -> torch.Tensor:
- """
- Args:
- hidden_states: [*, hidden_size]
- w1: [num_experts, intermediate_size * 2, hidden_size]
- w2: [num_experts, hidden_size, intermediate_size]
- gating_output: [*, num_experts]
- """
- assert expert_map is None, "expert_map is not supported for pallas MoE."
- import torch_xla.experimental.custom_kernel # noqa: F401
-
- orig_shape = hidden_states.shape
- hidden_size = hidden_states.shape[-1]
- num_tokens = hidden_states.shape[:-1].numel()
- num_experts = w1.shape[0]
- intermediate_size = w2.shape[-1]
- device = hidden_states.device
- dtype = hidden_states.dtype
- assert (num_tokens * topk) % 16 == 0, (
- "The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
- f"16 but got {num_tokens * topk}"
- )
-
- hidden_states = hidden_states.view(num_tokens, hidden_size)
- gating_output = gating_output.view(num_tokens, num_experts)
- topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
- topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
- if renormalize:
- topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
- topk_weights = topk_weights.to(dtype)
-
- topk_indices = topk_indices.flatten()
- topk_argsort_indices = topk_indices.argsort()
- topk_argsort_revert_indices = topk_argsort_indices.argsort()
- token_indices = torch.arange(num_tokens, device=device).repeat_interleave(topk)
- token_indices = token_indices[topk_argsort_indices]
- group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
-
- x = hidden_states[token_indices]
- x = torch.ops.xla.gmm(x, w1, group_sizes, transpose_rhs=True)
- x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
- x = torch.ops.xla.gmm(x, w2, group_sizes, transpose_rhs=True)
- x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
-
- x = x * topk_weights.unsqueeze(dim=-1)
- x = x.sum(dim=-2)
- x = x.reshape(orig_shape)
- return x
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index 82dbccf3fa9da..4c03cff2e8131 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -38,10 +38,6 @@ if current_platform.is_cuda_alike():
else:
TritonExperts = None # type: ignore
-if current_platform.is_tpu():
- from .moe_pallas import fused_moe as fused_moe_pallas
-else:
- fused_moe_pallas = None # type: ignore
logger = init_logger(__name__)
@@ -403,53 +399,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=layer.custom_routing_function,
)
- def forward_tpu(
- self,
- layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
- x: torch.Tensor,
- router_logits: torch.Tensor,
- ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- assert not layer.use_grouped_topk
- assert layer.num_expert_group is None
- assert layer.topk_group is None
- assert layer.custom_routing_function is None
- assert layer.apply_router_weight_on_input is False
- if layer.scoring_func != "softmax":
- raise NotImplementedError(
- "Only softmax scoring function is supported for TPU."
- )
- if layer.e_score_correction_bias is not None:
- raise NotImplementedError(
- "Expert score correction bias is not supported for TPU."
- )
- assert layer.activation == "silu", (
- f"{layer.activation} is not supported for TPU."
- )
- assert layer.routed_scaling_factor == 1.0, (
- f"routed_scaling_factor {layer.routed_scaling_factor} is "
- "not supported for TPU."
- )
- if (
- layer.enable_eplb is not False
- or layer.expert_load_view is not None
- or layer.logical_to_physical_map is not None
- or layer.logical_replica_count is not None
- ):
- raise NotImplementedError("Expert load balancing is not supported for TPU.")
- return fused_moe_pallas(
- hidden_states=x,
- w1=layer.w13_weight,
- w2=layer.w2_weight,
- topk=layer.top_k,
- gating_output=router_logits,
- global_num_experts=layer.global_num_experts,
- expert_map=layer.expert_map,
- renormalize=layer.renormalize,
- )
-
- if current_platform.is_tpu():
- forward_native = forward_tpu
- elif current_platform.is_cpu():
+ if current_platform.is_cpu():
forward_native = forward_cpu
elif current_platform.is_xpu():
forward_native = forward_xpu
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index 18aaae394f935..48db0d1bbbd47 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -11,7 +11,6 @@ logger = init_logger(__name__)
QuantizationMethods = Literal[
"awq",
"deepspeedfp",
- "tpu_int8",
"fp8",
"ptpc_fp8",
"fbgemm_fp8",
@@ -130,12 +129,10 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .ptpc_fp8 import PTPCFp8Config
from .rtn import RTNConfig
from .torchao import TorchAOConfig
- from .tpu_int8 import Int8TpuConfig
method_to_config: dict[str, type[QuantizationConfig]] = {
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
- "tpu_int8": Int8TpuConfig,
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
"fp_quant": FPQuantConfig,
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
index 20d050d387d49..4ccc4182367a6 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
@@ -19,9 +19,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel,
)
-from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
- XLAScaledMMLinearKernel,
-)
from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available)
@@ -29,7 +26,6 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
- PlatformEnum.TPU: [XLAScaledMMLinearKernel],
}
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
deleted file mode 100644
index 0be858c51993d..0000000000000
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import warnings
-
-import torch
-from functorch.experimental.control_flow import cond # noqa: F401
-
-from vllm.model_executor.layers.quantization.utils import replace_parameter
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
- convert_to_channelwise,
-)
-from vllm.platforms import current_platform
-
-from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
-
-
-class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
- @classmethod
- def is_supported(
- cls, compute_capability: int | None = None
- ) -> tuple[bool, str | None]:
- if not current_platform.is_tpu():
- return False, "Requires TPU."
- return True, None
-
- @classmethod
- def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
- if not current_platform.is_tpu():
- return False, "ScaledMMXLA requires running on TPU."
-
- if c.is_static_input_scheme:
- return False, "ScaledMMXLA requires dynamic activation scales."
-
- if not c.input_symmetric:
- return False, "ScaledMMXLA requires symmetric activation scales."
-
- if not c.is_channelwise:
- return False, "ScaledMMXLA requires channelwise weight scales"
-
- return True, None
-
- def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
- # WEIGHT
- # [out, in] (different than cutlass_scaled_mm)
- weight = getattr(layer, self.w_q_name)
- replace_parameter(
- layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
- )
-
- # WEIGHT SCALE
- # XLA kernels support only per-tensor and per-channel.
- # If we have a fused module (QKV, MLP) with per tensor scales (thus N
- # scales being passed to the kernel), convert to the per-channel case.
- is_fused_module = len(layer.logical_widths) > 1
- weight_scale = getattr(layer, self.w_s_name)
- if is_fused_module and not self.config.is_channelwise:
- weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
-
- # [out_channel,] (different than cutlass_scaled_mm)
- weight_scale = weight_scale.squeeze(-1)
- replace_parameter(
- layer,
- self.w_s_name,
- torch.nn.Parameter(weight_scale.data, requires_grad=False),
- )
-
- # Only support symmetric dynamic activation quantization.
- setattr(layer, self.i_s_name, None)
- setattr(layer, self.i_zp_name, None)
- setattr(layer, self.azp_adj_name, None)
-
- # Filter warning for cond usage in apply_weights. It is okay
- # to specialize the graph since bias is not dynamic.
- warnings.filterwarnings(
- "ignore",
- message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.", # noqa: E501
- )
-
- def no_add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
- return x
-
- def add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
- return x + bias
-
- def apply_weights(
- self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- bias: torch.Tensor | None = None,
- ) -> torch.Tensor:
- w_q, w_s, _, _, _ = self._get_weight_params(layer)
-
- # Required to register custom ops.
- import torch_xla.experimental.custom_kernel # noqa: F401
-
- out = torch.ops.xla.quantized_matmul_int8(
- x,
- w_q,
- w_s,
- quantize_activation=True,
- )
-
- # Explicitly capture control flow to make dynamo happy.
- # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
- return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py
deleted file mode 100644
index 64bfa8fb80eb2..0000000000000
--- a/vllm/model_executor/layers/quantization/tpu_int8.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from typing import Any, Optional
-
-import torch
-from torch.nn import Module
-from torch.nn.parameter import Parameter
-
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
-from vllm.model_executor.layers.quantization import (
- QuantizationConfig,
- QuantizationMethods,
-)
-from vllm.model_executor.parameter import ModelWeightParameter
-
-ACTIVATION_SCHEMES = ["none", "dynamic"]
-
-
-class Int8TpuConfig(QuantizationConfig):
- """Int8 Quantization Config class for TPU Backend."""
-
- def __init__(
- self,
- activation_scheme: str = "none",
- ) -> None:
- super().__init__()
- if activation_scheme not in ACTIVATION_SCHEMES:
- raise ValueError(f"Unsupported activation scheme {activation_scheme}")
- self.activation_scheme = activation_scheme
-
- def get_name(self) -> QuantizationMethods:
- return "tpu_int8"
-
- def get_supported_act_dtypes(self) -> list[torch.dtype]:
- return [torch.float16, torch.bfloat16]
-
- @classmethod
- def get_min_capability(cls) -> int:
- raise NotImplementedError("This function should not be called with TPU Backend")
-
- @staticmethod
- def get_config_filenames() -> list[str]:
- return []
-
- @classmethod
- def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig":
- activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
- return cls(activation_scheme=activation_scheme)
-
- def get_quant_method(
- self, layer: Module, prefix: str
- ) -> Optional["TPUInt8LinearMethod"]:
- if isinstance(layer, LinearBase):
- return TPUInt8LinearMethod(self)
- return None
-
-
-class TPUInt8LinearMethod(LinearMethodBase):
- """Int8 Linear method for TPU Quant."""
-
- def __init__(self, quant_config: Int8TpuConfig):
- self.quant_config = quant_config
- self.quantize_activation = False
- if self.quant_config.activation_scheme == "dynamic":
- self.quantize_activation = True
-
- def create_weights(
- self,
- layer: Module,
- input_size_per_partition: int,
- output_partition_sizes: list[int],
- input_size: int,
- output_size: int,
- params_dtype: torch.dtype,
- **extra_weight_attrs,
- ):
- weight_loader = extra_weight_attrs.get("weight_loader")
- weight = ModelWeightParameter(
- data=torch.empty(
- sum(output_partition_sizes),
- input_size_per_partition,
- dtype=params_dtype,
- ),
- input_dim=1,
- output_dim=0,
- weight_loader=weight_loader,
- )
- layer.register_parameter("weight", weight)
-
- def _quantize_weight(
- self, weight: torch.Tensor
- ) -> tuple[torch.Tensor, torch.Tensor]:
- weight_dtype = weight.dtype
- weight = weight.cpu().to(torch.float32)
- n_bit = 8
- eps = 1e-5
- max_int = 2 ** (n_bit - 1) - 1
- min_int = -(2 ** (n_bit - 1))
- max_val = weight.abs().amax(dim=-1, keepdim=True)
- max_val = max_val.clamp(min=eps)
- qscale = max_val / max_int
- qweight = torch.clamp(
- torch.round(weight * (1.0 / qscale)), min_int, max_int
- ).to(torch.int8)
- qscale = qscale.squeeze().to(weight_dtype)
- return qweight, qscale
-
- def process_weights_after_loading(self, layer: Module) -> None:
- layer.weight = Parameter(layer.weight.data, requires_grad=False)
- device = layer.weight.device
- qweight, qscale = self._quantize_weight(layer.weight)
- qweight = qweight.to(device)
- qscale = qscale.to(device)
- layer.weight = Parameter(qweight, requires_grad=False)
- layer.scale = Parameter(qscale, requires_grad=False)
-
- def apply(
- self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- bias: torch.Tensor | None = None,
- ) -> torch.Tensor:
- try:
- import torch_xla.experimental.custom_kernel # noqa: F401
- except ImportError as err:
- raise ImportError(
- "Please install torch_xla by following the instructions at "
- "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501
- "to run vLLM on TPU."
- ) from err
- weight = layer.weight
- scale = layer.scale
- out = torch.ops.xla.quantized_matmul_int8(
- x, weight, scale, quantize_activation=self.quantize_activation
- )
- if bias is not None:
- out = out + bias
- return out
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index 88c6d1e27e39c..c4e961581ef3f 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -30,7 +30,6 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator,
safetensors_weights_iterator,
)
-from vllm.platforms import current_platform
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
logger = init_logger(__name__)
@@ -241,22 +240,6 @@ class DefaultModelLoader(BaseModelLoader):
self.load_config.pt_load_map_location,
)
- if current_platform.is_tpu():
- from vllm.platforms.tpu import USE_TPU_INFERENCE
-
- if not USE_TPU_INFERENCE:
- # In PyTorch XLA, we should call `torch_xla.sync`
- # frequently so that not too many ops are accumulated
- # in the XLA program.
- import torch_xla
-
- def _xla_weights_iterator(iterator: Generator):
- for weights in iterator:
- yield weights
- torch_xla.sync(wait=False)
-
- weights_iterator = _xla_weights_iterator(weights_iterator)
-
if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()
# Apply the prefix.
diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py
deleted file mode 100644
index fc142f1f07fae..0000000000000
--- a/vllm/model_executor/model_loader/tpu.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import time
-
-import torch
-import torch.nn as nn
-import torch_xla.core.xla_model as xm
-import torch_xla.distributed.spmd as xs
-
-from vllm.config import ModelConfig, VllmConfig
-from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
-from vllm.logger import init_logger
-from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
-from vllm.model_executor.model_loader.utils import (
- initialize_model,
- process_weights_after_loading,
-)
-from vllm.utils.torch_utils import set_default_torch_dtype
-
-logger = init_logger(__name__)
-
-
-class TPUModelLoader(DefaultModelLoader):
- """
- A TPU model loader for model loading under SPMD mode.
- """
-
- def load_model(
- self,
- vllm_config: VllmConfig,
- model_config: ModelConfig,
- mesh: xs.Mesh | None = None,
- ) -> nn.Module:
- # Initialize model and load weights on CPU. Then, during SPMD partition,
- # weights are sharded and transferred to TPUs.
- self.counter_before_loading_weights = time.perf_counter()
- model_config = vllm_config.model_config
- assert model_config.quantization is None, "Quantization not supported"
- target_device = torch.device("cpu")
- with set_default_torch_dtype(model_config.dtype):
- with target_device:
- model = initialize_model(vllm_config=vllm_config)
-
- load_format = vllm_config.load_config.load_format
- if load_format != "dummy":
- weights_to_load = {name for name, _ in model.named_parameters()}
- all_weights = self.get_all_weights(model_config, model)
- loaded_weights = model.load_weights(all_weights)
- self.counter_after_loading_weights = time.perf_counter()
- logger.info(
- "Loading weights took %.2f seconds",
- self.counter_after_loading_weights
- - self.counter_before_loading_weights,
- )
- # We only enable strict check for non-quantized models
- # that have loaded weights tracking currently.
- if model_config.quantization is None and loaded_weights is not None:
- weights_not_loaded = weights_to_load - loaded_weights
- if weights_not_loaded:
- raise ValueError(
- "Following weights were not initialized from "
- f"checkpoint: {weights_not_loaded}"
- )
- else:
- logger.info("Use dummy weight during weight loading.")
-
- process_weights_after_loading(model, model_config, target_device)
-
- counter_before_partition = time.perf_counter()
- model = model.eval()
- model = model.to("xla")
- shard_model(model, mesh)
- counter_after_partition = time.perf_counter()
- logger.info(
- "Partition model took %.2f seconds",
- counter_after_partition - counter_before_partition,
- )
-
- # Ensure the model is properly loaded.
- self._check_model_is_loaded(mesh, model)
-
- # Need to torch compile after model sharding are done. Because the
- # compiler hints ('xs.mark_sharding') are torch ops.
- if not model_config.is_multimodal_model:
- model.model = torch.compile(model.model, backend="openxla")
- else:
- model.language_model.model = torch.compile(
- model.language_model.model, backend="openxla"
- )
- return model
-
- def _check_model_is_loaded(self, mesh: xs.Mesh | None, model: nn.Module) -> None:
- """
- Ensure the model is properly loaded.
- 1. All model parameters and buffers are on XLA device.
- 2. Non-SPMD friendly layers are replaced as expected.
- """
- device = xm.xla_device()
- device_type = str(device.type)
-
- # Check parameters
- for name, param in model.named_parameters():
- assert param.device.type == device_type, (
- f"Parameter {name} is on {param.device.type} instead of {device_type}"
- )
-
- # Check buffers
- for name, buffer in model.named_buffers():
- assert buffer.device.type == device_type, (
- f"Buffer {name} is on {buffer.device.type} instead of {device_type}"
- )
-
- for module in model.modules():
- if (mesh is not None) and (get_fqn(module) == "QKVParallelLinear"):
- raise AssertionError(
- "QKVParallelLinear should be replaced by \
- XlaQKVParallelLinear under SPMD mode."
- )
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 7c479bf2b6a0e..455aceb3269eb 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -1,287 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import contextlib
-from typing import TYPE_CHECKING, Optional, cast
-
-import torch
-from tpu_info import device
-
-from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
-from .interface import Platform, PlatformEnum
-
-if TYPE_CHECKING:
- from typing import TypeAlias
-
- from vllm.attention.selector import AttentionSelectorConfig
- from vllm.config import VllmConfig
- from vllm.config.cache import BlockSize
- from vllm.pooling_params import PoolingParams
- from vllm.sampling_params import SamplingParams
-
- ParamsType: TypeAlias = SamplingParams | PoolingParams
-else:
- BlockSize = None
- VllmConfig = None
- PoolingParams = None
- ParamsType = None
-
logger = init_logger(__name__)
-USE_TPU_INFERENCE = False
-
-
-class TpuPlatform(Platform):
- _enum = PlatformEnum.TPU
- device_name: str = "tpu"
- device_type: str = "tpu"
- dispatch_key: str = "XLA"
- ray_device_key: str = "TPU"
- dist_backend: str = "gloo"
- device_control_env_var: str = "TPU_VISIBLE_CHIPS"
- simple_compile_backend: str = "openxla"
-
- supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"]
-
- additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]
-
- @classmethod
- def import_kernels(cls) -> None:
- # Do not import vllm._C
- with contextlib.suppress(ImportError):
- import vllm._moe_C # noqa: F401
-
- @classmethod
- def get_attn_backend_cls(
- cls,
- selected_backend: "AttentionBackendEnum",
- attn_selector_config: "AttentionSelectorConfig",
- ) -> str:
- if attn_selector_config.use_sparse:
- raise NotImplementedError("Sparse Attention is not supported on TPU.")
- if selected_backend != AttentionBackendEnum.PALLAS:
- logger.info("Cannot use %s backend on TPU.", selected_backend)
-
- logger.info("Using Pallas V1 backend.")
- return AttentionBackendEnum.PALLAS.get_path()
-
- @classmethod
- def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
- return [
- AttentionBackendEnum.PALLAS,
- ]
-
- @classmethod
- def get_vit_attn_backend(
- cls,
- head_size: int,
- dtype: torch.dtype,
- backend: Optional["AttentionBackendEnum"] = None,
- ) -> "AttentionBackendEnum":
- if backend is not None:
- assert backend in cls.get_supported_vit_attn_backends(), (
- f"Backend {backend} is not supported for vit attention"
- f"Supported backends are: {cls.get_supported_vit_attn_backends()}."
- )
- logger.info_once(f"Using backend {backend} for vit attention.")
- return backend
-
- logger.info_once(
- f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention."
- )
- return AttentionBackendEnum.PALLAS
-
- @classmethod
- def set_device(cls, device: torch.device) -> None:
- """
- Set the device for the current platform.
- """
- torch.tpu.set_device(device)
-
- @classmethod
- def get_device_name(cls, device_id: int = 0) -> str:
- chip_type, _ = device.get_local_chips()
- return f"TPU {chip_type.name}"
-
- @classmethod
- def get_device_total_memory(cls, device_id: int = 0) -> int:
- raise NotImplementedError
-
- @classmethod
- def get_punica_wrapper(cls) -> str:
- return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
-
- @classmethod
- def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
- return torch.finfo(dtype).min, torch.finfo(dtype).max
-
- @classmethod
- def can_update_inplace(cls):
- return False
-
- @classmethod
- def get_lora_vocab_padding_size(cls) -> int:
- return 1
-
- @classmethod
- def inference_mode(cls):
- return torch.no_grad()
-
- @classmethod
- def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
- from vllm.config import CompilationMode, CUDAGraphMode
-
- cache_config = vllm_config.cache_config
- # For v0, the default block size is 16.
- if cache_config and cache_config.block_size is None:
- cache_config.block_size = cast(BlockSize, 16)
- compilation_config = vllm_config.compilation_config
-
- # TPU only supports DYNAMO_TRACE_ONCE compilation mode
- if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE:
- logger.info(
- "[TPU] Forcing DYNAMO_TRACE_ONCE compilation mode, and\
- disabling cudagraph."
- )
- compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
-
- if (
- compilation_config.cudagraph_mode is None
- or compilation_config.cudagraph_mode.max_cudagraph_mode()
- != CUDAGraphMode.NONE
- ):
- logger.info(
- "[TPU] CUDA graph is not supported on TPU, disabling cudagraphs."
- )
- compilation_config.cudagraph_mode = CUDAGraphMode.NONE
-
- if compilation_config.backend == "":
- compilation_config.backend = "openxla"
-
- assert vllm_config.speculative_config is None, (
- "TPU does not support speculative decoding"
- )
-
- model_config = vllm_config.model_config
- if model_config is not None and model_config.dtype in (
- torch.float16,
- torch.float32,
- ):
- logger.warning(
- "The TPU backend currently does not support %s. "
- "Using bfloat16 instead.",
- model_config.dtype,
- )
- model_config.dtype = torch.bfloat16
-
- from vllm.v1.attention.backends.pallas import PallasAttentionBackend
-
- cache_config.block_size = PallasAttentionBackend.get_page_size(vllm_config) # type: ignore[assignment]
-
- parallel_config = vllm_config.parallel_config
- scheduler_config = vllm_config.scheduler_config
- if parallel_config.worker_cls == "auto":
- parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"
-
- assert not vllm_config.speculative_config, (
- "Speculative decoding is not yet supported for TPU backend"
- )
-
- if (
- scheduler_config.is_multimodal_model
- and not scheduler_config.disable_chunked_mm_input
- ):
- logger.warning(
- "TPU does not support running Multimodal models"
- " without setting `--disable_chunked_mm_input`. "
- "Forcing --disable_chunked_mm_input."
- )
- scheduler_config.disable_chunked_mm_input = True
-
- if model_config and model_config.use_mla:
- logger.info(
- "MLA is enabled on a non-GPU platform; forcing chunked "
- "prefill and prefix caching to be disabled."
- )
- vllm_config.scheduler_config.enable_chunked_prefill = False
- vllm_config.scheduler_config.max_num_batched_tokens = max(
- vllm_config.model_config.max_model_len,
- vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
- )
-
- @classmethod
- def is_pin_memory_available(cls):
- logger.warning("Pin memory is not supported on TPU.")
- return False
-
- @classmethod
- def get_device_communicator_cls(cls) -> str:
- return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa
-
- @classmethod
- def validate_request(
- cls,
- prompt: PromptType,
- params: ParamsType,
- processed_inputs: ProcessorInputs,
- ) -> None:
- """Raises if this request is unsupported on this platform"""
- from vllm.sampling_params import SamplingParams, SamplingType
-
- if (
- isinstance(params, SamplingParams)
- and params.sampling_type == SamplingType.RANDOM_SEED
- ):
- raise ValueError("Torch XLA does not support per-request seed.")
-
- @classmethod
- @torch.compile(backend="openxla")
- def insert_blocks_to_device(
- cls,
- src_cache: torch.Tensor,
- dst_cache: torch.Tensor,
- src_block_indices: torch.Tensor,
- dst_block_indices: torch.Tensor,
- ) -> None:
- torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
- dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device)
-
- @classmethod
- @torch.compile(backend="openxla")
- def swap_out_blocks_to_host(
- cls,
- src_cache: torch.Tensor,
- dst_cache: torch.Tensor,
- src_block_indices: torch.Tensor,
- dst_block_indices: torch.Tensor,
- ) -> None:
- """tpu blocks to cpu blocks"""
- torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
- dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
-
- @classmethod
- def use_sync_weight_loader(cls) -> bool:
- return True
-
- @classmethod
- def check_max_model_len(cls, max_model_len: int) -> int:
- """
- Check max_model_len for the current platform.
- """
- logger.warning(
- "--max-model-len is not specified, "
- "it's currently using model's default length %d, "
- "which might be too large."
- "Please input with --max-model-len based on your "
- "request input length and output length, to avoid "
- "unnecessary degradation.",
- max_model_len,
- )
- return max_model_len
-
try:
from tpu_inference.platforms import (
@@ -291,5 +14,7 @@ try:
TpuPlatform = TpuInferencePlatform # type: ignore
USE_TPU_INFERENCE = True
except ImportError:
- logger.info("tpu_inference not found, using vLLM's TpuPlatform")
+ logger.error(
+ "tpu_inference not found, please install tpu_inference to run vllm on TPU"
+ )
pass
diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py
index 69226763aafe6..b0886bba8a22a 100644
--- a/vllm/usage/usage_lib.py
+++ b/vllm/usage/usage_lib.py
@@ -186,20 +186,6 @@ class UsageMessage:
except Exception:
return False
- def _report_torch_xla_usage(self) -> bool:
- try:
- import torch_xla
-
- self.gpu_count = torch_xla.runtime.world_size()
- self.gpu_type = torch_xla.tpu.get_tpu_type()
- self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[
- "bytes_limit"
- ]
- self.cuda_runtime = "torch_xla"
- return True
- except Exception:
- return False
-
def _report_usage_once(
self,
model_architecture: str,
@@ -217,9 +203,7 @@ class UsageMessage:
if current_platform.is_cuda():
self.cuda_runtime = torch.version.cuda
if current_platform.is_tpu(): # noqa: SIM102
- if (not self._report_tpu_inference_usage()) and (
- not self._report_torch_xla_usage()
- ):
+ if not self._report_tpu_inference_usage():
logger.exception("Failed to collect TPU information")
self.provider = _detect_cloud_provider()
self.architecture = platform.machine()
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
deleted file mode 100644
index 525026bac5a7e..0000000000000
--- a/vllm/v1/attention/backends/pallas.py
+++ /dev/null
@@ -1,436 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from dataclasses import dataclass
-
-import torch
-
-from vllm.attention.backends.abstract import (
- AttentionBackend,
- AttentionImpl,
- AttentionLayer,
- AttentionType,
-)
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.utils.math_utils import cdiv, next_power_of_2
-
-logger = init_logger(__name__)
-
-# TPU requires the head size to be a multiple of 128.
-TPU_HEAD_SIZE_ALIGNMENT = 128
-
-# Note: TPU can fp8 as storage dtype but doesn't support converting from uint8
-# from to fp32 directly. That's why it has a dtype mapping different from GPU
-TPU_STR_DTYPE_TO_TORCH_DTYPE = {
- "half": torch.half,
- "bfloat16": torch.bfloat16,
- "float": torch.float,
- "fp8": torch.float8_e4m3fn,
- "fp8_e4m3": torch.float8_e4m3fn,
- "fp8_e5m2": torch.float8_e5m2,
- "int8": torch.int8,
- "uint8": torch.uint8,
-}
-
-try:
- import tpu_inference # noqa: F401
-except ImportError:
- # Lazy import torch_xla
- import torch_xla.core.xla_builder as xb
- import torch_xla.experimental.custom_kernel # noqa: F401
- from torch.library import impl
- from torch_xla._internal.jax_workarounds import requires_jax
- from torch_xla.experimental.custom_kernel import XLA_LIB
-
- @requires_jax
- def kv_cache_update_op_impl(
- kv: torch.Tensor,
- slot_mapping: torch.Tensor,
- kv_cache: torch.Tensor,
- num_kv_update_slices: torch.Tensor,
- page_size: int,
- num_slices_per_block: int,
- ):
- from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
-
- new_kv_cache = xb.call_jax(
- kv_cache_update,
- (kv, slot_mapping, kv_cache, num_kv_update_slices),
- {"page_size": page_size, "num_slices_per_block": num_slices_per_block},
- )
- return new_kv_cache
-
- XLA_LIB.define(
- "kv_cache_update_op(Tensor kv, Tensor slot_mapping,"
- "Tensor kv_cache, Tensor num_kv_update_slices, int page_size,"
- "int num_slices_per_block)"
- "-> Tensor",
- )
-
- @impl(XLA_LIB, "kv_cache_update_op", "XLA")
- def kv_cache_update_op_xla(
- kv: torch.Tensor,
- slot_mapping: torch.Tensor,
- kv_cache: torch.Tensor,
- num_kv_update_slices: torch.Tensor,
- page_size: int,
- num_slices_per_block: int,
- ) -> torch.Tensor:
- new_kv_cache = kv_cache_update_op_impl(
- kv,
- slot_mapping,
- kv_cache,
- num_kv_update_slices,
- page_size,
- num_slices_per_block,
- )
- return new_kv_cache
-
- @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
- def kv_cache_update_op_non_xla(
- kv: torch.Tensor,
- slot_mapping: torch.Tensor,
- kv_cache: torch.Tensor,
- num_kv_update_slices: torch.Tensor,
- page_size: int,
- num_slices_per_block: int,
- ) -> torch.Tensor:
- return kv_cache
-
-
-class PallasAttentionBackend(AttentionBackend):
- @staticmethod
- def get_name() -> str:
- return "PALLAS"
-
- @staticmethod
- def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
- return PallasAttentionBackendImpl
-
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- padded_head_size = (
- cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
- )
- return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
-
- @staticmethod
- def swap_blocks(
- src_kv_cache: torch.Tensor,
- dst_kv_cache: torch.Tensor,
- src_to_dst: torch.Tensor,
- ) -> None:
- raise RuntimeError("swap_blocks is not used for the TPU backend.")
-
- # In recent TPU generations, up to v6e, the SMEM size is 1MB. The
- # block_tables within the PallasMetadata constitute almost the entire SMEM
- # requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
- # we simply make sure that the size is smaller than half of SMEM capacity.
- @staticmethod
- def get_min_page_size(vllm_config: VllmConfig) -> int:
- max_num_page_per_req = (
- 1024 * 1024 // 2 // vllm_config.scheduler_config.max_num_seqs // 4
- )
- min_page_size = cdiv(
- vllm_config.model_config.max_model_len, max_num_page_per_req
- )
- min_page_size = 1 << (min_page_size - 1).bit_length()
- return min_page_size
-
- @staticmethod
- def get_max_num_seqs(model_len: int, page_size: int) -> int:
- num_page_per_req = cdiv(model_len, page_size)
- return 1024 * 1024 // 2 // num_page_per_req // 4
-
- # TPU has limited SREGs (scalar registers), if page_size is too small, we
- # can spill SREGs easily which leads to bad performance. The strategy we
- # apply here is trying to split max-model-len to 16 pages which make the
- # spill less likely. Meanwhile we make sure the page size is in [16, 256].
- @staticmethod
- def get_page_size(vllm_config: VllmConfig) -> int:
- # TODO: This is a temporary fix for vmem OOM.
- # For long model length, we use 16 page-size to avoid too much
- # VMEM spill. A more robust solution should be implemented to
- # handle VREG spills.
- if vllm_config.model_config.max_model_len > 8192:
- return 16
- page_size = next_power_of_2(vllm_config.model_config.max_model_len) // 16
- if page_size <= 16:
- return 16
- if page_size >= 256:
- return 256
- return page_size
-
-
-@dataclass
-class PallasMetadata:
- # NOTE(sang): Definition of context_len, query_len, and seq_len.
- # |---------- N-1 iteration --------|
- # |---------------- N iteration ---------------------|
- # |- tokenA -|......................|-- newTokens ---|
- # |---------- context_len ----------|
- # |-------------------- seq_len ---------------------|
- # |-- query_len ---|
-
- # Used in the PallasAttentionBackendImpl
- slot_mapping: torch.Tensor
- block_tables: torch.Tensor
- context_lens: torch.Tensor
- query_start_loc: torch.Tensor
- num_seqs: torch.Tensor
- num_kv_update_slices: torch.Tensor
- num_slices_per_kv_cache_update_block: int
-
-
-class PallasAttentionBackendImpl(AttentionImpl):
- def __init__(
- self,
- num_heads: int,
- head_size: int,
- scale: float,
- num_kv_heads: int,
- alibi_slopes: list[float] | None,
- sliding_window: int | None,
- kv_cache_dtype: str,
- logits_soft_cap: float | None = None,
- attn_type: str = AttentionType.DECODER,
- kv_sharing_target_layer_name: int | None = None,
- ) -> None:
- self.num_heads = num_heads
- self.head_size = head_size
- self.scale = float(scale)
- self.num_kv_heads = num_kv_heads
- self.sliding_window = sliding_window
- self.logits_soft_cap = logits_soft_cap
- self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
-
- self.num_queries_per_kv = self.num_heads // self.num_kv_heads
- if alibi_slopes is not None:
- raise NotImplementedError("Alibi slopes is not supported.")
-
- if attn_type != AttentionType.DECODER:
- raise NotImplementedError(
- "Encoder self-attention and "
- "encoder/decoder cross-attention "
- "are not implemented for "
- "PallasAttentionBackendImpl"
- )
-
- self.kv_cache_quantized_dtype = None
- if kv_cache_dtype != "auto":
- self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
- kv_cache_dtype.lower().strip()
- )
-
- def forward(
- self,
- layer: AttentionLayer,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: PallasMetadata,
- output: torch.Tensor | None = None,
- output_scale: torch.Tensor | None = None,
- output_block_scale: torch.Tensor | None = None,
- ) -> torch.Tensor:
- """Forward pass with Pallas attention.
-
- Args:
- query: shape = [num_tokens, num_heads * head_size]
- key: shape = [num_tokens, num_kv_heads * head_size]
- value: shape = [num_tokens, num_kv_heads * head_size]
- kv_cache: shape =
- [num_blocks, block_size, num_kv_heads * 2, head_size]
- attn_metadata: Metadata for attention.
- Returns:
- shape = [num_tokens, num_heads * head_size]
- """
- if output_scale is not None or output_block_scale is not None:
- raise NotImplementedError(
- "fused output quantization is not yet supported"
- " for PallasAttentionBackendImpl"
- )
-
- # For determine_available_memory case.
- if kv_cache.numel() == 0:
- if output is None:
- output = torch.ones_like(query)
- return output
-
- num_tokens, hidden_size = query.shape
- query = query.view(num_tokens, self.num_heads, self.head_size)
- key = key.view(-1, self.num_kv_heads, self.head_size)
- value = value.view(-1, self.num_kv_heads, self.head_size)
- if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
- padded_head_size = (
- cdiv(self.head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
- )
- query = torch.nn.functional.pad(
- query, (0, padded_head_size - self.head_size), value=0.0
- )
- key = torch.nn.functional.pad(
- key, (0, padded_head_size - self.head_size), value=0.0
- )
- value = torch.nn.functional.pad(
- value, (0, padded_head_size - self.head_size), value=0.0
- )
-
- if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
- # Write input keys and values to the KV cache.
- # Skip this if sharing KV cache with an earlier attention layer.
- slot_mapping = attn_metadata.slot_mapping
- write_to_kv_cache(
- key,
- value,
- kv_cache,
- slot_mapping,
- attn_metadata.num_slices_per_kv_cache_update_block,
- attn_metadata.num_kv_update_slices,
- self.kv_cache_quantized_dtype,
- layer._k_scale_float,
- layer._v_scale_float,
- )
-
- if self.kv_cache_quantized_dtype is not None and (
- layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0
- ):
- raise ValueError("k_scale_float and v_scale_float must be non-zero")
- output = torch.ops.xla.ragged_paged_attention(
- query,
- kv_cache,
- attn_metadata.context_lens,
- attn_metadata.block_tables,
- attn_metadata.query_start_loc,
- attn_metadata.num_seqs,
- # By default, the system utilizes optimized block size and
- # vmem_limit_bytes parameters from the kernel repository. However,
- # these can be manually adjusted for debugging if necessary.
- num_kv_pages_per_block=None,
- num_queries_per_block=None,
- vmem_limit_bytes=None,
- use_kernel=True,
- sm_scale=self.scale,
- sliding_window=self.sliding_window,
- soft_cap=self.logits_soft_cap,
- k_scale=layer._k_scale_float,
- v_scale=layer._v_scale_float,
- )
-
- if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
- output = output[:, :, : self.head_size]
-
- return output.reshape(num_tokens, hidden_size)
-
-
-def write_to_kv_cache(
- key: torch.Tensor,
- value: torch.Tensor,
- kv_cache: torch.Tensor,
- slot_mapping: torch.Tensor,
- num_slices_per_kv_cache_update_block: int,
- num_kv_update_slices: torch.Tensor,
- kv_cache_quantized_dtype: torch.dtype | None = None,
- k_scale: float = 1.0,
- v_scale: float = 1.0,
-) -> None:
- """Write the key and values to the KV cache.
-
- Args:
- key: shape = [num_tokens, num_kv_heads, head_size]
- value: shape = [num_tokens, num_kv_heads, head_size]
- kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size]
- num_slices_per_kv_cache_update_block: int
- """
- _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
- head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
-
- if kv_cache_quantized_dtype is not None:
- dtype_info = torch.finfo(kv_cache_quantized_dtype)
- key = key.to(torch.float32) / k_scale
- # NOTE: clamp is added here to avoid out of range of quantized dtype
- key = torch.clamp(key, dtype_info.min, dtype_info.max)
- key = key.to(kv_cache_quantized_dtype)
- value = value.to(torch.float32) / v_scale
- value = torch.clamp(value, dtype_info.min, dtype_info.max)
- value = value.to(kv_cache_quantized_dtype)
-
- kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size)
-
- torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
-
- kv_cache = kv_cache.flatten(0, 1)
- new_kv_cache = torch.ops.xla.kv_cache_update_op(
- kv,
- slot_mapping,
- kv_cache,
- num_kv_update_slices,
- page_size,
- num_slices_per_kv_cache_update_block,
- )
- # NOTE: the in-place copy will be optimized away by XLA compiler.
- kv_cache.copy_(new_kv_cache)
-
-
-# We can move this function to a common utils file if it's also useful for other
-# hardware.
-def dtype_bits(dtype: torch.dtype):
- if dtype.is_floating_point:
- try:
- return torch.finfo(dtype).bits
- except TypeError:
- pass
- elif dtype.is_complex:
- if dtype is torch.complex32:
- return 32
- elif dtype is torch.complex64:
- return 64
- elif dtype is torch.complex128:
- return 128
- else:
- try:
- return torch.iinfo(dtype).bits
- # torch.iinfo cannot support int4, int2, bits8...
- except TypeError:
- pass
- str_dtype = str(dtype)
- # support torch.int4, torch.int5, torch.uint5...
- if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"):
- return int(str_dtype[-1])
- raise TypeError(f"Getting the bit width of {dtype} is not supported")
-
-
-def get_dtype_packing(dtype):
- bits = dtype_bits(dtype)
- if 32 % bits != 0:
- raise ValueError(
- f"The bit width must be divisible by 32, but got bits={bits}, "
- "dtype={dtype}"
- )
- return 32 // bits
-
-
-def get_page_size_bytes(
- block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype
-) -> int:
- """Returns the size in bytes of one page of the KV cache."""
- padded_head_size = (
- cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
- )
- num_combined_kv_heads = num_kv_heads * 2
-
- # NOTE: for the implicit padding in XLA
- packing = get_dtype_packing(kv_cache_dtype)
- num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing
-
- kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
- return (
- block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bits // 8
- )
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
deleted file mode 100644
index c7404c4642d7e..0000000000000
--- a/vllm/v1/worker/tpu_model_runner.py
+++ /dev/null
@@ -1,2191 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import bisect
-import gc
-import time
-from typing import TYPE_CHECKING, Any, cast
-from unittest.mock import patch
-
-import numpy as np
-import torch
-import torch.nn as nn
-
-# TPU XLA related
-import torch_xla
-import torch_xla.core.xla_model as xm
-import torch_xla.distributed.spmd as xs
-import torch_xla.runtime as xr
-
-import vllm.envs as envs
-from vllm.attention.backends.abstract import AttentionType
-from vllm.attention.layer import Attention, MLAAttention
-from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
-from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
-from vllm.config import (
- ParallelConfig,
- VllmConfig,
- get_layers_from_vllm_config,
- update_config,
-)
-from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
-from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
-from vllm.forward_context import set_forward_context
-from vllm.logger import init_logger
-from vllm.lora.layers import BaseLayerWithLoRA
-from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.tpu import TPUModelLoader
-from vllm.model_executor.models.interfaces import (
- SupportsMultiModal,
- supports_transcription,
-)
-from vllm.model_executor.models.interfaces_base import (
- is_pooling_model,
- is_text_generation_model,
-)
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (
- BatchedTensorInputs,
- MultiModalKwargsItem,
- PlaceholderRange,
-)
-from vllm.multimodal.utils import group_mm_kwargs_by_modality
-from vllm.sequence import IntermediateTensors
-from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
-from vllm.utils.math_utils import cdiv, prev_power_of_2
-from vllm.utils.platform_utils import is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (
- TPU_STR_DTYPE_TO_TORCH_DTYPE,
- PallasAttentionBackend,
- PallasMetadata,
- get_page_size_bytes,
-)
-from vllm.v1.kv_cache_interface import (
- AttentionSpec,
- FullAttentionSpec,
- KVCacheConfig,
- KVCacheSpec,
- MLAAttentionSpec,
- SlidingWindowSpec,
-)
-from vllm.v1.outputs import (
- EMPTY_MODEL_RUNNER_OUTPUT,
- LogprobsLists,
- LogprobsTensors,
- ModelRunnerOutput,
-)
-from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
-from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
-from vllm.v1.worker.kv_connector_model_runner_mixin import (
- KVConnectorModelRunnerMixin,
- KVConnectorOutput,
-)
-from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
-from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
-
-from .utils import (
- MultiModalBudget,
- add_kv_sharing_layers_to_kv_cache_groups,
- bind_kv_cache,
- sanity_check_mm_encoder_outputs,
-)
-
-if TYPE_CHECKING:
- from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
-
-logger = init_logger(__name__)
-
-INVALID_TOKEN_ID = -1
-# Smallest output size
-MIN_NUM_SEQS = 8
-
-
-#########################################################
-# Ways to avoid recompilation
-#########################################################
-#
-# The model executor has two primary components:
-# 1. preparing the model and sampler inputs
-# 2. executing the model and sampler.
-# The core idea is to avoid any TPU computation during input preparation. For
-# better compilation tracking and increased flexibility, the model execution and
-# sampler are divided into several distinct components.
-#
-# Below are the detailed steps:
-#
-# Step 1
-# It is recommended to avoid TPU operations when preparing the model and sampler
-# inputs. CPU tensors can be prepared and transferred to the XLA device using
-# cpu_tensor.to(xla_device), which only triggers CPU to TPU transfers and avoids
-# compilation.
-#
-# Step 2
-# The TPU execution should be decomposed into subgraphs (4 at the moment):
-# 1. the main model
-# 2. selecting hidden states for each request
-# 3. sampler
-# 4. encoder.
-# Each subgraph should be decorated in a torch.compile. This is used to make
-# sure that we have the same subgraph topology in both dummy_run and
-# xecute_model. The results from these subgraphs should either be passed to
-# other subgraphs, or transferred from TPU to CPU using xla_tensor.cpu() for
-# subsequent processing on the CPU.
-#
-# Step 3
-# The dummy_run should be comprehensive, ensuring all potential input shapes and
-# branch predictions are included as subgraph inputs to facilitate
-# pre-compilation.
-class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
- def __init__(
- self,
- vllm_config: VllmConfig,
- device: torch.device,
- original_parallel_config: ParallelConfig | None = None,
- ):
- self.vllm_config = vllm_config
- self.model_config = vllm_config.model_config
- self.cache_config = vllm_config.cache_config
- self.lora_config = vllm_config.lora_config
- self.load_config = vllm_config.load_config
- self.parallel_config = vllm_config.parallel_config
- self.original_parallel_config = original_parallel_config
- self.scheduler_config = vllm_config.scheduler_config
- self.speculative_config = vllm_config.speculative_config
- self.observability_config = vllm_config.observability_config
- self.device_config = vllm_config.device_config
-
- model_config = self.model_config
- cache_config = self.cache_config
- scheduler_config = self.scheduler_config
- parallel_config = self.parallel_config
- self.device = device
- self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
-
- # SPMD Related
- self.use_spmd = envs.VLLM_XLA_USE_SPMD
- if self.use_spmd:
- num_devices = xr.global_runtime_device_count()
- mesh_shape = (num_devices, 1)
- device_ids = np.array(range(num_devices))
- self.mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
-
- self.enforce_eager = model_config.enforce_eager
-
- self.num_xla_graphs = 0
- self._update_num_xla_graphs("init")
-
- self.pin_memory = is_pin_memory_available()
- self.dtype = self.model_config.dtype
- if cache_config.cache_dtype == "auto":
- model_dtype = self.dtype
- if isinstance(model_dtype, str):
- self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
- else:
- self.kv_cache_dtype = model_dtype
- else:
- self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
- self._hidden_states_dtype = self.dtype
-
- self.sliding_window = model_config.get_sliding_window()
- self.block_size = cache_config.block_size
- self.max_model_len = model_config.max_model_len
- self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN
- self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
- self.num_blocks_per_most_len_req = (
- cdiv(self.most_model_len, self.block_size)
- if self.most_model_len is not None
- else None
- )
- # InputBatch needs to work with sampling tensors greater than padding
- # to avoid dynamic shapes. Also, avoid suboptimal alignment.
- self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
- self.num_tokens_paddings = _get_token_paddings(
- min_token_size=16,
- max_token_size=scheduler_config.max_num_batched_tokens,
- padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP,
- )
- # In case `max_num_tokens < max(num_tokens_paddings)` use the actual
- # padded max value to pre-allocate data structures and pre-compile.
- self.max_num_tokens = self.num_tokens_paddings[-1]
-
- # Model-related.
- self.num_attn_layers = model_config.get_num_layers_by_block_type(
- parallel_config, "attention"
- )
- self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
- self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
- self.head_size = model_config.get_head_size()
- self.inputs_embeds_size = model_config.get_inputs_embeds_size()
- self.vocab_size = model_config.get_vocab_size()
-
- # Multi-modal data support
- self.mm_registry = MULTIMODAL_REGISTRY
- self.uses_mrope = model_config.uses_mrope
- self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
- model_config
- )
- # TODO: Support M-RoPE (e.g, Qwen2-VL)
- assert not self.uses_mrope, "TPU does not support M-RoPE yet."
-
- self._num_slices_per_kv_cache_update_block = (
- _get_num_slices_per_kv_cache_update_block(
- get_page_size_bytes(
- block_size=self.block_size,
- num_kv_heads=self.num_kv_heads,
- head_size=self.head_size,
- kv_cache_dtype=self.kv_cache_dtype,
- )
- )
- )
-
- # Lazy initialization
- self.model: nn.Module # Set after load_model
- self.kv_caches: list[torch.Tensor] = []
- # mm_hash -> encoder_output
- self.encoder_cache: dict[str, torch.Tensor] = {}
-
- # Request states.
- self.requests: dict[str, CachedRequestState] = {}
- # NOTE(rob): num_prompt_logprobs only includes reqs
- # that are currently in the prefill phase.
- self.num_prompt_logprobs: dict[str, int] = {}
-
- # Initialize input batch early to avoid AttributeError in _update_states
- self.input_batch = InputBatch(
- max_num_reqs=self.max_num_reqs,
- max_model_len=self.max_model_len,
- max_num_batched_tokens=self.max_num_tokens,
- device=self.device,
- pin_memory=self.pin_memory,
- vocab_size=self.model_config.get_vocab_size(),
- block_sizes=[self.block_size],
- kernel_block_sizes=[self.cache_config.block_size],
- )
-
- # Cached torch/numpy tensor
- # The pytorch tensor and numpy array share the same buffer.
- # Sometimes the numpy op is faster so we create both.
- self.input_ids_cpu = torch.zeros(
- self.max_num_tokens, dtype=torch.int32, device="cpu"
- )
-
- self.positions_cpu = torch.zeros(
- self.max_num_tokens, dtype=torch.int32, device="cpu"
- )
- self.positions_np = self.positions_cpu.numpy()
- self.block_table_cpu = torch.zeros(
- (self.max_num_reqs, self.max_num_blocks_per_req),
- dtype=torch.int32,
- device="cpu",
- )
- # adjust num_reqs to avoid SMEM OOM.
- self.num_reqs_most_model_len = (
- min(
- PallasAttentionBackend.get_max_num_seqs(
- self.most_model_len, self.block_size
- ),
- self.max_num_reqs,
- )
- if self.most_model_len is not None
- else None
- )
- self.num_reqs_max_model_len = min(
- PallasAttentionBackend.get_max_num_seqs(
- self.max_model_len, self.block_size
- ),
- self.max_num_reqs,
- )
- self.query_start_loc_cpu = torch.zeros(
- self.max_num_tokens + 1,
- dtype=torch.int32,
- device="cpu",
- pin_memory=self.pin_memory,
- )
- self.query_start_loc_np = self.query_start_loc_cpu.numpy()
-
- self.seq_lens_cpu = torch.zeros(
- self.max_num_tokens,
- dtype=torch.int32,
- device="cpu",
- pin_memory=self.pin_memory,
- )
- self.seq_lens_np = self.seq_lens_cpu.numpy()
-
- # Only relevant for multimodal models
- if self.supports_mm_inputs:
- self.is_mm_embed_cpu = torch.zeros(
- self.max_num_tokens,
- dtype=torch.bool,
- device="cpu",
- pin_memory=self.pin_memory,
- )
-
- # Range tensor with values [0 .. self.max_num_tokens - 1].
- # Used to initialize positions / context_lens / seq_lens
- # Keep in int64 to avoid overflow with long context
- self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64)
- self.num_reqs_paddings = _get_req_paddings(
- min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs
- )
-
- # Layer pairings for cross-layer KV sharing.
- # If an Attention layer `layer_name` is in the keys of this dict, it
- # means this layer will perform attention using the keys and values
- # from the KV cache of `shared_kv_cache_layers[layer_name]`.
- self.shared_kv_cache_layers: dict[str, str] = {}
-
- # tensors for structured decoding
- self.grammar_bitmask_cpu = torch.zeros(
- (self.max_num_reqs, cdiv(self.vocab_size, 32)),
- dtype=torch.int32,
- device="cpu",
- pin_memory=self.pin_memory,
- )
- self.require_structured_out_cpu = torch.zeros(
- (self.max_num_reqs, 1),
- dtype=torch.bool,
- device="cpu",
- pin_memory=self.pin_memory,
- )
- self.structured_decode_arange = torch.arange(
- 0, 32, device="cpu", pin_memory=self.pin_memory
- )
-
- self.mm_budget = (
- MultiModalBudget(
- self.model_config,
- self.scheduler_config,
- self.mm_registry,
- )
- if self.supports_mm_inputs
- else None
- )
-
- if not self.use_spmd:
- self.sample_from_logits_func = torch.compile(
- self.sample_from_logits,
- backend="openxla",
- fullgraph=True,
- dynamic=False,
- )
- else:
- self.sample_from_logits_func = self.sample_from_logits
-
- # For passing scheduler_output between successive
- # execute_model() and sample_tokens() calls.
- self.scheduler_output: SchedulerOutput | None = None
- self.mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
-
- def reset_mm_cache(self) -> None:
- if self.mm_budget:
- self.mm_budget.reset_cache()
-
- def _update_num_xla_graphs(self, case_str):
- check_comp = self.check_recompilation and not self.enforce_eager
- if not check_comp:
- return
-
- total_cached_graphs = xr.get_num_cached_compilation_graph()
- new_compiled_graphs = total_cached_graphs - self.num_xla_graphs
- if new_compiled_graphs == 0:
- return
-
- logger.info(
- "Add new %d compiled XLA graphs due to %s", new_compiled_graphs, case_str
- )
- self.num_xla_graphs += new_compiled_graphs
-
- def _verify_num_xla_graphs(self, case_str):
- check_comp = self.check_recompilation and not self.enforce_eager
- if not check_comp:
- return
-
- curr_cached_graph = xr.get_num_cached_compilation_graph()
- assert self.num_xla_graphs == curr_cached_graph, (
- "Recompilation after warm up is detected during {}."
- " num_xla_graphs = {} curr_cached_graph = {}".format(
- case_str, self.num_xla_graphs, curr_cached_graph
- )
- )
-
- def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
- """Update the cached states and the persistent batch with the scheduler
- output.
-
- The updated states are used by the `_prepare_inputs` function to create
- the input GPU tensors for the model.
-
- Returns:
- True if there is a new/resumed/paused/finished request.
- If False, we can skip copying SamplingMetadata to the GPU.
- """
- # Remove finished requests from the cached states.
- for req_id in scheduler_output.finished_req_ids:
- self.requests.pop(req_id, None)
- self.num_prompt_logprobs.pop(req_id, None)
-
- # Remove the finished requests from the persistent batch.
- # NOTE(woosuk): There could be an edge case where finished_req_ids and
- # scheduled_req_ids overlap. This happens when a request is aborted and
- # then resubmitted with the same ID. In this case, we treat them as two
- # distinct requests - clearing the cached states for the first request
- # and handling the second as a new request.
- removed_req_indices: list[int] = []
- for req_id in scheduler_output.finished_req_ids:
- req_index = self.input_batch.remove_request(req_id)
- if req_index is not None:
- removed_req_indices.append(req_index)
-
- # Free the cached encoder outputs.
- for mm_hash in scheduler_output.free_encoder_mm_hashes:
- self.encoder_cache.pop(mm_hash, None)
-
- # Remove the unscheduled requests from the persistent batch.
- # NOTE(woosuk): The unscheduled requests are either preempted requests
- # or running requests that are not scheduled in this step. We remove
- # them from the persistent batch but keep their cached states since
- # they will be scheduled again sometime in the future.
- scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
- cached_req_ids = self.input_batch.req_id_to_index.keys()
- unscheduled_req_ids = cached_req_ids - scheduled_req_ids
- # NOTE(woosuk): The persistent batch optimization assumes that
- # consecutive batches contain mostly the same requests. If batches
- # have low request overlap (e.g., alternating between two distinct
- # sets of requests), this optimization becomes very inefficient.
- for req_id in unscheduled_req_ids:
- req_index = self.input_batch.remove_request(req_id)
- assert req_index is not None
- removed_req_indices.append(req_index)
-
- req_ids_to_add: list[str] = []
- # Add new requests to the cached states.
- for new_req_data in scheduler_output.scheduled_new_reqs:
- assert new_req_data.sampling_params is not None, (
- "Pooling is not supported in TPU yet"
- )
- req_id = new_req_data.req_id
- sampling_params = new_req_data.sampling_params
-
- self.requests[req_id] = CachedRequestState(
- req_id=req_id,
- prompt_token_ids=new_req_data.prompt_token_ids,
- prompt_embeds=new_req_data.prompt_embeds,
- mm_features=new_req_data.mm_features,
- sampling_params=sampling_params,
- pooling_params=None,
- generator=None,
- block_ids=new_req_data.block_ids,
- num_computed_tokens=new_req_data.num_computed_tokens,
- output_token_ids=[],
- lora_request=new_req_data.lora_request,
- )
-
- if sampling_params and sampling_params.prompt_logprobs is not None:
- self.num_prompt_logprobs[req_id] = (
- self.input_batch.vocab_size
- if sampling_params.prompt_logprobs == -1
- else sampling_params.prompt_logprobs
- )
-
- req_ids_to_add.append(req_id)
-
- # Update the states of the running/resumed requests.
- req_data = scheduler_output.scheduled_cached_reqs
- for i, req_id in enumerate(req_data.req_ids):
- req_state = self.requests[req_id]
- num_computed_tokens = req_data.num_computed_tokens[i]
- new_block_ids = req_data.new_block_ids[i]
- resumed_from_preemption = req_id in req_data.resumed_req_ids
-
- # Update the cached states.
- req_state.num_computed_tokens = num_computed_tokens
- if not resumed_from_preemption:
- if new_block_ids is not None:
- # Append the new blocks to the existing block IDs.
- for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
- block_ids.extend(new_ids)
- else:
- assert new_block_ids is not None
- # The request is resumed from preemption.
- # Replace the existing block IDs with the new ones.
- req_state.block_ids = new_block_ids
-
- req_index = self.input_batch.req_id_to_index.get(req_id)
- if req_index is None:
- # The request is not in the persistent batch.
- # The request was either preempted and resumed later, or was not
- # scheduled in the previous step and needs to be added again.
- req_ids_to_add.append(req_id)
- continue
-
- # Update the persistent batch.
- self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
- if new_block_ids is not None:
- self.input_batch.block_table.append_row(new_block_ids, req_index)
-
- # Add the new or resumed requests to the persistent batch.
- # The smaller empty indices are filled first.
- removed_req_indices = sorted(removed_req_indices, reverse=True)
- for req_id in req_ids_to_add:
- req_state = self.requests[req_id]
- # Fill the empty index or append to the end
- req_index = removed_req_indices.pop() if removed_req_indices else None
- self.input_batch.add_request(req_state, req_index)
-
- # Condense the batched states if there are empty indices.
- if removed_req_indices:
- self.input_batch.condense(removed_req_indices)
-
- return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
-
- def get_model(self) -> nn.Module:
- return self.model
-
- def get_supported_generation_tasks(self) -> list[GenerationTask]:
- model = self.get_model()
- supported_tasks = list[GenerationTask]()
-
- if is_text_generation_model(model):
- supported_tasks.append("generate")
-
- if supports_transcription(model):
- if model.supports_transcription_only:
- return ["transcription"]
-
- supported_tasks.append("transcription")
-
- return supported_tasks
-
- def get_supported_pooling_tasks(self) -> list[PoolingTask]:
- model = self.get_model()
- if not is_pooling_model(model):
- return []
-
- return list(model.pooler.get_supported_tasks())
-
- def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
- tasks = list[SupportedTask]()
-
- if self.model_config.runner_type == "generate":
- tasks.extend(self.get_supported_generation_tasks())
- if self.model_config.runner_type == "pooling":
- tasks.extend(self.get_supported_pooling_tasks())
-
- return tuple(tasks)
-
- def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
- """
- Generates the KVCacheSpec by parsing the kv cache format from each
- Attention module in the static forward context.
- Returns:
- KVCacheSpec: A dictionary mapping layer names to their KV cache
- format. Layers that do not need KV cache are not included.
- """
-
- layers = get_layers_from_vllm_config(
- self.vllm_config,
- AttentionLayerBase, # type: ignore[type-abstract]
- )
- block_size = self.vllm_config.cache_config.block_size
- cache_dtype_str = self.vllm_config.cache_config.cache_dtype
-
- kv_cache_spec: dict[str, KVCacheSpec] = {}
- for layer_name, attn_module in layers.items():
- # Classic Attention path
- if isinstance(attn_module, Attention):
- if (
- kv_tgt_layer := attn_module.kv_sharing_target_layer_name
- ) is not None:
- # The layer doesn't need its own KV cache and will use that of
- # the target layer. We skip creating a KVCacheSpec for it, so
- # that KV cache management logic will act as this layer does
- # not exist, and doesn't allocate KV cache for the layer. This
- # enables the memory saving of cross-layer kv sharing, allowing
- # a given amount of memory to accommodate longer context lengths
- # or enable more requests to be processed simultaneously.
- self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
- continue
-
- if attn_module.attn_type == AttentionType.DECODER:
- if isinstance(attn_module, ChunkedLocalAttention):
- logger.warning_once(
- "Using irope in Pallas is not supported yet, it "
- "will fall back to global attention for long context."
- )
- if attn_module.sliding_window is not None:
- kv_cache_spec[layer_name] = SlidingWindowSpec(
- block_size=block_size,
- num_kv_heads=attn_module.num_kv_heads,
- head_size=attn_module.head_size,
- dtype=self.kv_cache_dtype,
- sliding_window=attn_module.sliding_window,
- )
- else:
- kv_cache_spec[layer_name] = FullAttentionSpec(
- block_size=block_size,
- num_kv_heads=attn_module.num_kv_heads,
- head_size=attn_module.head_size,
- dtype=self.kv_cache_dtype,
- )
- elif attn_module.attn_type in (
- AttentionType.ENCODER,
- AttentionType.ENCODER_ONLY,
- ):
- # encoder-only attention does not need KV cache.
- continue
- elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
- raise NotImplementedError
- else:
- raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
- # MLAAttention path
- elif isinstance(attn_module, MLAAttention):
- if layer_name in kv_cache_spec:
- continue
- kv_cache_spec[layer_name] = MLAAttentionSpec(
- block_size=block_size,
- num_kv_heads=1,
- head_size=attn_module.head_size,
- dtype=self.kv_cache_dtype,
- cache_dtype_str=cache_dtype_str,
- )
- else:
- continue
-
- return kv_cache_spec
-
- def _get_slot_mapping_metadata(
- self, num_reqs, num_scheduled_tokens_per_req
- ) -> np.ndarray:
- """
- Computes metadata for mapping slots to blocks in the key-value (KV)
- cache for a batch of requests.
-
- This function determines, for each request in the batch, how the
- scheduled tokens are distributed across memory blocks, and generates
- metadata needed to map slices of tokens to their corresponding positions
- in the KV cache.
-
- Args:
- num_reqs (int): Number of requests in the current batch.
- num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens
- to be scheduled for each request.
-
- Returns:
- np.ndarray: A 2D array of shape (total_block_len, 3), where each row
- contains:
- - kv_cache_start_index (int): The starting index in the KV cache
- for the corresponding slice.
- - new_kv_start_index (int): The starting index in the new KV
- cache for the corresponding slice.
- - slice_len (int): The length of the slice.
- """
- slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs]
- slices_end = (
- self.input_batch.num_computed_tokens_cpu[:num_reqs]
- + num_scheduled_tokens_per_req
- )
- local_block_start_idx = slices_start // self.block_size
- local_block_end_idx = (slices_end - 1) // self.block_size
- no_repeat_req_indices = self.arange_np[:num_reqs]
- global_block_start_idx = (
- no_repeat_req_indices * self.max_num_blocks_per_req + local_block_start_idx
- )
- block_lens = local_block_end_idx - local_block_start_idx + 1
- global_block_start_idx = np.repeat(global_block_start_idx, block_lens)
- slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens])
- global_block_indices = global_block_start_idx + slice_arange
- block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
- block_numbers = block_table_cpu.flatten()[global_block_indices].numpy()
- total_block_len = np.sum(block_lens)
- slot_mapping_slices = np.repeat(
- np.array([[0, self.block_size]], dtype=np.int32), total_block_len, axis=0
- )
- cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32)
- np.cumsum(block_lens, out=cu_block_lens[1:])
- for req_idx in range(num_reqs):
- slot_mapping_slices[cu_block_lens[req_idx]][0] = (
- slices_start[req_idx] % self.block_size
- )
- slot_mapping_slices[cu_block_lens[req_idx + 1] - 1][1] = (
- slices_end[req_idx] - 1
- ) % self.block_size + 1
- slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0]
- cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32)
- np.cumsum(slice_lens, out=cu_slices_lens[1:])
- kv_cache_start_indices = slot_mapping_slices[:, 0] + (
- block_numbers * self.block_size
- )
- new_kv_start_indices = cu_slices_lens[:-1]
- slot_mapping_metadata = np.stack(
- [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1
- )
- return slot_mapping_metadata
-
- def _prepare_inputs(self, scheduler_output: "SchedulerOutput", start_index: int):
- assert scheduler_output.total_num_scheduled_tokens > 0
- num_reqs = self.input_batch.num_reqs
- assert num_reqs > 0
- assert start_index < num_reqs
-
- # Get the number of scheduled tokens for each request.
- use_max_model_len = self.most_model_len is None
- num_scheduled_tokens_per_req = []
- max_num_scheduled_tokens_all_reqs = 0
- end_index = start_index
-
- # Use either most_model_len or max_model_len depending on request size.
- for i in range(start_index, num_reqs):
- req_id = self.input_batch.req_ids[i]
- assert req_id is not None
- num_tokens = scheduler_output.num_scheduled_tokens[req_id]
- if (
- not use_max_model_len
- and self.most_model_len is not None
- and num_tokens > self.most_model_len
- ):
- use_max_model_len = True
- num_scheduled_tokens_per_req.append(num_tokens)
- if use_max_model_len:
- if len(num_scheduled_tokens_per_req) > self.num_reqs_max_model_len:
- num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[
- : self.num_reqs_max_model_len
- ]
- end_index = start_index + self.num_reqs_max_model_len
- else:
- end_index = num_reqs
- else:
- assert self.num_reqs_most_model_len is not None
- if len(num_scheduled_tokens_per_req) > self.num_reqs_most_model_len:
- num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[
- : self.num_reqs_most_model_len
- ]
- end_index = start_index + self.num_reqs_most_model_len
- else:
- end_index = num_reqs
- max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_per_req)
- num_scheduled_tokens_per_req = np.array(
- num_scheduled_tokens_per_req, dtype=np.int32
- )
- total_num_scheduled_tokens = sum(num_scheduled_tokens_per_req)
- assert max_num_scheduled_tokens_all_reqs > 0
-
- num_reqs = len(num_scheduled_tokens_per_req)
-
- # Get request indices.
- # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
- # For each scheduled token, what are the corresponding req index.
- req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens_per_req)
-
- # Get batched arange.
- # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
- # For each scheduled token, what is its position in corresponding req.
- arange = np.concatenate(
- [self.arange_np[:n] for n in num_scheduled_tokens_per_req]
- )
-
- # Get positions.
- positions_np = self.positions_np[:total_num_scheduled_tokens]
- np.add(
- self.input_batch.num_computed_tokens_cpu[req_indices],
- arange,
- out=positions_np,
- )
-
- # Get token indices.
- # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
- # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
- # where M is the max_model_len.
- token_indices = (
- positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]
- )
-
- # NOTE(woosuk): We use torch.index_select instead of np.take here
- # because torch.index_select is much faster than np.take for large
- # tensors.
- torch.index_select(
- self.input_batch.token_ids_cpu_tensor.flatten(),
- 0,
- torch.from_numpy(token_indices),
- out=self.input_ids_cpu[:total_num_scheduled_tokens],
- )
-
- # Prepare the attention metadata.
- self.query_start_loc_np[0] = 0
- np.cumsum(
- num_scheduled_tokens_per_req, out=self.query_start_loc_np[1 : num_reqs + 1]
- )
- self.query_start_loc_np[num_reqs + 1 :] = 1
-
- self.seq_lens_np[:num_reqs] = (
- self.input_batch.num_computed_tokens_cpu[:num_reqs]
- + num_scheduled_tokens_per_req
- )
-
- # Do the padding and copy the tensors to the TPU.
- padded_total_num_scheduled_tokens = _get_padded_token_len(
- self.num_tokens_paddings, total_num_scheduled_tokens
- )
- # Zero out to avoid spurious values from prev iteration (last cp chunk)
- self.input_ids_cpu[
- total_num_scheduled_tokens:padded_total_num_scheduled_tokens
- ] = 0
- self.input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens].to(
- self.device
- )
- self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to(
- self.device
- )
- if use_max_model_len:
- block_tables = self.block_table_cpu[
- : self.num_reqs_max_model_len, : self.max_num_blocks_per_req
- ]
- block_tables[:num_reqs, : self.max_num_blocks_per_req] = (
- self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs]
- )
- query_start_loc = self.query_start_loc_cpu[
- : self.num_reqs_max_model_len + 1
- ].to(self.device)
- seq_lens = self.seq_lens_cpu[: self.num_reqs_max_model_len].to(self.device)
- else:
- assert self.num_reqs_most_model_len is not None
- block_tables = self.block_table_cpu[
- : self.num_reqs_most_model_len, : self.num_blocks_per_most_len_req
- ]
- block_tables[:num_reqs, : self.num_blocks_per_most_len_req] = (
- self.input_batch.block_table[0].get_cpu_tensor()[
- :num_reqs, : self.num_blocks_per_most_len_req
- ]
- )
- query_start_loc = self.query_start_loc_cpu[
- : self.num_reqs_most_model_len + 1
- ].to(self.device)
- seq_lens = self.seq_lens_cpu[: self.num_reqs_most_model_len].to(self.device)
- block_tables = block_tables.to(self.device)
-
- # Calculate the slot mapping
- slot_mapping_metadata = self._get_slot_mapping_metadata(
- num_reqs, num_scheduled_tokens_per_req
- )
- num_kv_update_slices = slot_mapping_metadata.shape[0]
- padded_num_slices = _get_padded_num_kv_cache_update_slices(
- padded_total_num_scheduled_tokens, self.max_num_reqs, self.block_size
- )
- slot_mapping_metadata = np.pad(
- slot_mapping_metadata,
- [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
- constant_values=0,
- )
- slot_mapping_metadata = np.transpose(slot_mapping_metadata)
- slot_mapping_metadata = torch.tensor(slot_mapping_metadata, device=self.device)
-
- if self.lora_config is not None:
- # We need to respect padding when activating LoRA adapters
- padded_num_scheduled_tokens_per_req = np.copy(
- num_scheduled_tokens_per_req
- ) # Copying to avoid accidental state corruption bugs
- padded_num_scheduled_tokens_per_req[-1] += (
- padded_total_num_scheduled_tokens - total_num_scheduled_tokens
- )
-
- self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req)
-
- attn_metadata = PallasMetadata(
- slot_mapping=slot_mapping_metadata,
- block_tables=block_tables,
- context_lens=seq_lens,
- query_start_loc=query_start_loc,
- num_seqs=torch.tensor([num_reqs], dtype=torch.int32, device=self.device),
- num_kv_update_slices=torch.tensor(
- [num_kv_update_slices], dtype=torch.int32, device=self.device
- ),
- num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block,
- )
- # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
- # request in the batch. While we should not sample any token from this
- # partial request, we do so for simplicity. We will ignore the sampled
- # token from the partial request.
- # TODO: Support prompt logprobs.
- padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
- num_reqs, self.max_num_reqs
- )
- # Indices at which we sample (positions of last token in the sequence).
- # Padded to avoid recompiling when `num_reqs` varies.
- logits_indices = self.query_start_loc_cpu[1 : padded_num_reqs + 1] - 1
- logits_indices = logits_indices.to(self.device)
-
- if self.lora_config is not None:
- # We need to respect padding when activating LoRA adapters
- padded_num_scheduled_tokens_per_req = np.copy(
- num_scheduled_tokens_per_req
- ) # Copying to avoid accidental state corruption bugs
- padded_num_scheduled_tokens_per_req[-1] += (
- padded_total_num_scheduled_tokens - total_num_scheduled_tokens
- )
-
- self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req)
-
- layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys()
- per_layer_attn_metadata = {
- layer_name: attn_metadata for layer_name in layer_names
- }
- return (
- per_layer_attn_metadata,
- logits_indices,
- padded_num_reqs,
- num_reqs,
- end_index,
- )
-
- def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
- scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
- if not scheduled_encoder_inputs:
- return
-
- # Batch the multi-modal inputs.
- mm_kwargs = list[MultiModalKwargsItem]()
- # List of tuple (mm_hash, pos_info)
- mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
- for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
- req_state = self.requests[req_id]
-
- for mm_input_id in encoder_input_ids:
- mm_feature = req_state.mm_features[mm_input_id]
- if mm_feature.data is None:
- continue
- mm_hash = mm_feature.identifier
- mm_kwargs.append(mm_feature.data)
- mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
-
- # Batch mm inputs as much as we can: if a request in the batch has
- # multiple modalities or a different modality than the previous one,
- # we process it separately to preserve item order.
- # FIXME(ywang96): This is a hacky way to deal with multiple modalities
- # in the same batch while still being able to benefit from batching
- # multimodal inputs. The proper solution should be reordering the
- # encoder outputs.
- model = cast(SupportsMultiModal, self.model)
- encoder_outputs = []
- for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
- mm_kwargs,
- device=self.device,
- pin_memory=self.pin_memory,
- ):
- # Run the encoder.
- # `curr_group_outputs` is either of the following:
- # 1. A tensor of shape (num_items, feature_size, hidden_size)
- # in case feature_size is fixed across all multimodal items.
- # 2. A list or tuple (length: num_items) of tensors, each of shape
- # (feature_size, hidden_size) in case the feature size is dynamic
- # depending on the input multimodal items.
- torch_xla.sync(wait=False)
- curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
- torch_xla.sync(wait=False)
-
- sanity_check_mm_encoder_outputs(
- curr_group_outputs,
- expected_num_items=num_items,
- )
-
- if isinstance(curr_group_outputs, torch.Tensor):
- encoder_outputs.append(curr_group_outputs)
- else:
- assert isinstance(curr_group_outputs, (list, tuple))
- for output in curr_group_outputs:
- encoder_outputs.append(output)
-
- # Cache the encoder outputs.
- # NOTE (NickLucche) here we diverge from logic in other runners, as we
- # assume to only have whole mm items to process. Hence we avoid the
- # intrinsic dynamism that `scatter_mm_placeholders` introduces.
- for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
- assert pos_info.is_embed is None, (
- "Expected all positions to be contiguous and embeddings."
- )
- self.encoder_cache[mm_hash] = output
-
- def _gather_mm_embeddings(
- self,
- scheduler_output: "SchedulerOutput",
- ) -> tuple[list[torch.Tensor], torch.Tensor]:
- total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
- padded_total_num_scheduled_tokens = _get_padded_token_len(
- self.num_tokens_paddings, total_num_scheduled_tokens
- )
-
- is_mm_embed = self.is_mm_embed_cpu
- is_mm_embed[:padded_total_num_scheduled_tokens] = False
- mm_embeds = list[torch.Tensor]()
- req_start_idx = 0
-
- for req_id in self.input_batch.req_ids:
- num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
- req_state = self.requests[req_id]
- num_computed_tokens = req_state.num_computed_tokens
-
- # TODO unroll loop and assume/enforce --disable_chunked_mm_input
- # NOTE (NickLucche) here we diverge from logic in other runners, as
- # we assume to only have whole mm items to process. Hence we avoid
- # the intrinsic dynamism that `gather_mm_placeholders` introduces.
- for mm_feature in req_state.mm_features:
- pos_info = mm_feature.mm_position
- start_pos = pos_info.offset
- num_encoder_tokens = pos_info.length
-
- # The encoder output is needed if the two ranges overlap:
- # [num_computed_tokens,
- # num_computed_tokens + num_scheduled_tokens) and
- # [start_pos, start_pos + num_encoder_tokens)
- if start_pos >= num_computed_tokens + num_scheduled_tokens:
- # The encoder output is not needed in this step.
- break
- if start_pos + num_encoder_tokens <= num_computed_tokens:
- # The encoder output is already processed and stored
- # in the decoder's KV cache.
- continue
-
- start_idx = max(num_computed_tokens - start_pos, 0)
- end_idx = min(
- num_computed_tokens - start_pos + num_scheduled_tokens,
- num_encoder_tokens,
- )
- assert start_idx < end_idx
-
- mm_hash = mm_feature.identifier
- encoder_output = self.encoder_cache.get(mm_hash, None)
- assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
-
- assert pos_info.is_embed is None, (
- "Expected all positions to be contiguous and embeddings."
- )
-
- req_start_pos = req_start_idx + start_pos - num_computed_tokens
- is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = True
-
- # Only whole mm items are processed
- mm_embeds.append(encoder_output)
-
- req_start_idx += num_scheduled_tokens
-
- is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens].to(self.device)
-
- return mm_embeds, is_mm_embed
-
- def _get_model_inputs(
- self,
- input_ids: torch.Tensor,
- mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
- ):
- if self.supports_mm_inputs:
- mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
-
- # NOTE(woosuk): To unify token ids and soft tokens (vision
- # embeddings), we always use embeddings (rather than token ids)
- # as input to the multimodal model, even when the input is text.
- inputs_embeds = self.model.embed_input_ids(
- input_ids,
- multimodal_embeddings=mm_embeds,
- is_multimodal=is_mm_embed,
- )
-
- return None, inputs_embeds
- else:
- # For text-only models, we use token ids as input.
- # While it is possible to use embeddings as input just like the
- # multimodal models, it is not desirable for performance since
- # then the embedding layer is not included in the CUDA graph.
- return input_ids, None
-
- @torch.no_grad()
- def execute_model(
- self,
- scheduler_output: "SchedulerOutput",
- intermediate_tensors: IntermediateTensors | None = None,
- ) -> ModelRunnerOutput | None:
- if self.scheduler_output is not None:
- raise RuntimeError(
- "State error: sample_tokens() must be called "
- "after execute_model() returns None."
- )
- # Update cached state
- self._update_states(scheduler_output)
- if not scheduler_output.total_num_scheduled_tokens:
- if not has_kv_transfer_group():
- # Return empty ModelRunnerOutput if there's no work to do.
- return EMPTY_MODEL_RUNNER_OUTPUT
-
- return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
-
- mm_embed_inputs = None
- if self.supports_mm_inputs:
- # Run the multimodal encoder if any.
- self._execute_mm_encoder(scheduler_output)
- mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)
-
- torch_xla.sync(wait=False)
-
- self.scheduler_output = scheduler_output
- self.mm_embed_inputs = mm_embed_inputs
- return None
-
- @torch.no_grad()
- def sample_tokens(
- self, grammar_output: "GrammarOutput | None"
- ) -> ModelRunnerOutput:
- if self.scheduler_output is None:
- # Nothing to do (PP non-final rank case), output isn't used.
- return None # type: ignore[return-value]
- scheduler_output = self.scheduler_output
- mm_embed_inputs = self.mm_embed_inputs
- self.scheduler_output = None
- self.mm_embed_inputs = None
-
- # Prepare inputs, the requests might be split into multiple
- # executions, combine the result of each execution.
- start_index = 0
- combined_selected_tokens: list[torch.Tensor] = []
- combined_logprobs: list[LogprobsLists] = []
-
- # NOTE: setup current batch's metadata for kv connector.
- # Currently, only verified with NixlConnector
- with set_forward_context(None, self.vllm_config):
- self.maybe_setup_kv_connector(scheduler_output)
-
- while start_index < self.input_batch.num_reqs:
- attn_metadata, logits_indices, padded_num_reqs, num_reqs, end_index = (
- self._prepare_inputs(scheduler_output, start_index)
- )
- input_ids, inputs_embeds = self._get_model_inputs(
- self.input_ids, mm_embed_inputs
- )
- torch_xla.sync(wait=False)
- # Run the decoder
- with set_forward_context(
- attn_metadata,
- self.vllm_config,
- num_tokens=scheduler_output.total_num_scheduled_tokens,
- ):
- hidden_states = self.model(
- input_ids=input_ids,
- positions=self.position_ids,
- inputs_embeds=inputs_embeds,
- )
- hidden_states = self.select_hidden_states(hidden_states, logits_indices)
- logits = self.compute_logits(hidden_states)
- tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
- self.input_batch, padded_num_reqs, self.device
- )
- if grammar_output is not None:
- require_struct_decoding, grammar_bitmask_padded, arange = (
- self.prepare_structured_decoding_input(logits, grammar_output)
- )
- logits = self.structured_decode(
- require_struct_decoding, grammar_bitmask_padded, logits, arange
- )
- selected_token_ids = self.sample_from_logits_func(
- logits, tpu_sampling_metadata
- )
- # NOTE (NickLucche) Use the original logits (before any penalties or
- # temperature scaling) for the top-k logprobs. We can't enforce it
- # due to recompilations outside torch.compiled code, so just make
- # sure `sample_from_logits` does not modify the logits in-place.
- logprobs = (
- self.gather_logprobs(logits, selected_token_ids)
- if tpu_sampling_metadata.logprobs
- else None
- )
-
- # Remove padding on cpu and keep dynamic op outside of xla graph.
- selected_token_ids = selected_token_ids.cpu()[:num_reqs]
-
- combined_selected_tokens.append(selected_token_ids)
- if tpu_sampling_metadata.logprobs:
- combined_logprobs.append(logprobs.tolists())
-
- start_index = end_index
-
- # NOTE: current kv load and save get h2d/d2h copies involved.
- # Those copies are blocking. Once they become async., kv_save
- # should be called right after each single forward pass,
- # instead of the forwards of the entire input batch.
- self.maybe_wait_for_kv_save()
- finished_sending, finished_recving = self.get_finished_kv_transfers(
- scheduler_output
- )
-
- selected_token_ids = torch.cat(combined_selected_tokens, dim=0)
- if tpu_sampling_metadata.logprobs:
-
- def concat_lists(input_lists):
- result = []
- for input_list in input_lists:
- result.extend(input_list)
- return result
-
- logprobs_lists = LogprobsLists(
- logprob_token_ids=concat_lists(
- [lp.logprob_token_ids for lp in combined_logprobs]
- ),
- logprobs=concat_lists([lp.logprobs for lp in combined_logprobs]),
- sampled_token_ranks=concat_lists(
- [lp.sampled_token_ranks for lp in combined_logprobs]
- ),
- )
- else:
- logprobs_lists = None
-
- # Update the cache state concurrently. Code above will not block until
- # we use `selected_token_ids`. Add mark_step if post-processing changes
- request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
- discard_sampled_tokens_req_indices = []
- num_reqs = self.input_batch.num_reqs
- for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
- assert req_id is not None
- req_state = self.requests[req_id]
- seq_len = (
- req_state.num_computed_tokens
- + scheduler_output.num_scheduled_tokens[req_id]
- )
- if seq_len >= req_state.num_tokens:
- request_seq_lens.append((i, req_state, seq_len))
- else:
- # Ignore the sampled token from the partial request.
- # Rewind the generator state as if the token was not sampled.
- generator = self.input_batch.generators.get(i)
- if generator is not None:
- # This relies on cuda-specific torch-internal impl details
- generator.set_offset(generator.get_offset() - 4)
-
- # Record the index of the request that should not be sampled,
- # so that we could clear the sampled tokens before returning.
- discard_sampled_tokens_req_indices.append(i)
-
- assert all(
- req_id is not None for req_id in self.input_batch.req_ids[:num_reqs]
- ), "req_ids contains None"
- req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])
-
- prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
- for req_id in self.input_batch.req_ids[:num_reqs]:
- prompt_logprobs_dict[req_id] = None
-
- max_gen_len = selected_token_ids.shape[-1]
- if max_gen_len == 1:
- valid_sampled_token_ids = selected_token_ids.tolist()
-
- # Mask out the sampled tokens that should not be sampled.
- # TODO: Keep in sync with gpu_model_runner.py, in particular
- # the "else" case here
- for i in discard_sampled_tokens_req_indices:
- valid_sampled_token_ids[i].clear()
-
- # Append sampled tokens
- for i, req_state, seq_len in request_seq_lens:
- token_id = valid_sampled_token_ids[i][0]
- self.input_batch.token_ids_cpu[i, seq_len] = token_id
- req_state.output_token_ids.append(token_id)
- self.input_batch.num_tokens_no_spec[i] += 1
-
- else:
- valid_mask = selected_token_ids != INVALID_TOKEN_ID
- gen_lens = valid_mask.sum(dim=1).tolist()
- valid_sampled_token_ids = [
- seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens)
- ]
- self.input_batch.num_tokens_no_spec[:num_reqs] += gen_lens
- for i, req_state, seq_len in request_seq_lens:
- target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
- self.input_batch.token_ids_cpu[i, target_slice] = (
- valid_sampled_token_ids[i]
- )
- req_state.output_token_ids.extend(valid_sampled_token_ids[i])
-
- kv_connector_output = (
- None
- if (finished_sending is None and finished_recving is None)
- else KVConnectorOutput(
- finished_sending=finished_sending,
- finished_recving=finished_recving,
- )
- )
-
- model_runner_output = ModelRunnerOutput(
- req_ids=req_ids,
- req_id_to_index=self.input_batch.req_id_to_index,
- sampled_token_ids=valid_sampled_token_ids,
- logprobs=logprobs_lists,
- prompt_logprobs_dict=prompt_logprobs_dict,
- pooler_output=[],
- kv_connector_output=kv_connector_output,
- )
-
- # Check there are no new graphs compiled - all the graphs should be
- # captured and compiled during warm up.
- self._verify_num_xla_graphs("execute_model")
-
- return model_runner_output
-
- def update_config(self, overrides: dict[str, Any]) -> None:
- # TODO: TPU config may need extra validation
- # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754
- allowed_config_names = {"load_config", "model_config"}
- for config_name, config_overrides in overrides.items():
- assert config_name in allowed_config_names, (
- f"Config `{config_name}` not supported. "
- f"Allowed configs: {allowed_config_names}"
- )
- config = getattr(self, config_name)
- new_config = update_config(config, config_overrides)
- setattr(self, config_name, new_config)
-
- def load_model(self) -> None:
- self.device = self.device_config.device
-
- # NOTE(woosuk): While the executor assigns the TP ranks to the worker
- # process, the ranks can be different from the ranks internally assigned
- # by the xm runtime. Therefore, there is a mismatch in the rank
- # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
- # This is not a problem in linear layers because all-reduce is
- # rank-agnostic. However, it matters for all-gather as the ranks
- # determine the order of concatenating the output tensors.
- # As a workaround, we use the xm's rank assignment only when loading
- # the embedding weights.
- xm_tp_rank = xr.global_ordinal()
- with patch(
- "vllm.model_executor.layers.vocab_parallel_embedding."
- "get_tensor_model_parallel_rank",
- return_value=xm_tp_rank,
- ):
- try:
- if self.use_spmd:
- tpu_loader = TPUModelLoader(
- load_config=self.vllm_config.load_config
- )
- model = tpu_loader.load_model(
- vllm_config=self.vllm_config,
- model_config=self.vllm_config.model_config,
- mesh=self.mesh,
- )
- else:
- model_loader = get_model_loader(self.load_config)
- logger.info("Loading model from scratch...")
- model = model_loader.load_model(
- vllm_config=self.vllm_config, model_config=self.model_config
- )
- except RuntimeError as e:
- raise RuntimeError(
- f"Unable to load model, a likely reason is the model is "
- "too large for the current device's HBM memory. "
- "Consider switching to a smaller model "
- "or sharding the weights on more chips. "
- f"See the detailed error: {e}"
- ) from e
- if self.lora_config is not None:
- model = self.load_lora_model(model, self.vllm_config, self.device)
- replace_set_lora(model)
-
- # Sync all pending XLA execution during model initialization and weight
- # loading.
- torch_xla.sync(wait=False)
- xm.wait_device_ops()
- if not hasattr(self, "model"):
- self.model = model
- self.sampler = TPUSampler()
-
- def reload_weights(self) -> None:
- assert getattr(self, "model", None) is not None, (
- "Cannot reload weights before model is loaded."
- )
- model_loader = get_model_loader(self.load_config)
- logger.info("Reloading weights inplace...")
- model_loader.load_weights(self.model, model_config=self.model_config)
-
- @torch.no_grad()
- def _dummy_run(self, num_tokens: int, num_reqs: int, num_blocks: int) -> None:
- if self.supports_mm_inputs:
- input_ids = None
- inputs_embeds = torch.zeros(
- (num_tokens, self.inputs_embeds_size),
- dtype=self.dtype,
- device=self.device,
- )
- else:
- input_ids = torch.zeros((num_tokens), dtype=torch.int32).to(self.device)
- inputs_embeds = None
- actual_num_reqs = min(num_tokens, num_reqs)
- position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(self.device)
- padded_num_slices = _get_padded_num_kv_cache_update_slices(
- num_tokens, self.max_num_reqs, self.block_size
- )
- num_kv_update_slices = torch.tensor([padded_num_slices], dtype=torch.int32).to(
- self.device
- )
- slot_mapping = torch.zeros((3, padded_num_slices), dtype=torch.int32).to(
- self.device
- )
- block_tables = torch.zeros((num_reqs, num_blocks), dtype=torch.int32).to(
- self.device
- )
- query_lens = [1] * num_reqs
- query_start_loc = torch.cumsum(
- torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32
- ).to(self.device)
- context_lens = torch.ones((num_reqs,), dtype=torch.int32).to(self.device)
- num_seqs = torch.tensor([actual_num_reqs], dtype=torch.int32).to(self.device)
- attn_metadata = PallasMetadata(
- slot_mapping=slot_mapping,
- block_tables=block_tables,
- context_lens=context_lens,
- query_start_loc=query_start_loc,
- num_seqs=num_seqs,
- num_kv_update_slices=num_kv_update_slices,
- num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block,
- )
-
- if self.supports_mm_inputs:
- torch._dynamo.mark_dynamic(inputs_embeds, 0)
- else:
- torch._dynamo.mark_dynamic(input_ids, 0)
- torch._dynamo.mark_dynamic(position_ids, 0)
- torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
- torch._dynamo.mark_dynamic(attn_metadata.block_tables, (0, 1))
- torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
- torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
-
- layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys()
- per_layer_attn_metadata = {
- layer_name: attn_metadata for layer_name in layer_names
- }
-
- with (
- self.maybe_select_dummy_loras(
- self.lora_config, np.array([num_tokens], dtype=np.int32)
- ),
- set_forward_context(per_layer_attn_metadata, self.vllm_config, 0),
- ):
- out = self.model(
- input_ids=input_ids, positions=position_ids, inputs_embeds=inputs_embeds
- )
- self._hidden_states_dtype = out.dtype
-
- def _set_active_loras(
- self, prompt_lora_mapping, token_lora_mapping, lora_requests
- ) -> None:
- torch_xla.sync(wait=False) # Captures input updates
- super()._set_active_loras(
- prompt_lora_mapping, token_lora_mapping, lora_requests
- )
- torch_xla.sync(wait=False) # Captures metadata updates
-
- def _precompile_mm_encoder(self) -> None:
- if not self.supports_mm_inputs:
- return
-
- # Pre-compile MM encoder for all supported data modalities.
- hf_config = self.vllm_config.model_config.hf_config
-
- mm_budget = self.mm_budget
- assert mm_budget is not None
-
- max_items_per_seq_by_modality = mm_budget.max_items_per_batch_by_modality # noqa: E501
-
- for mode, max_items_per_seq in max_items_per_seq_by_modality.items():
- logger.info(
- "Compiling Multimodal %s Encoder with different input shapes.", mode
- )
- start = time.perf_counter()
- # No padding for MM encoder just yet.
- for num_items in range(1, max_items_per_seq + 1):
- logger.info(" -- mode: %s items: %d", mode, num_items)
- batched_dummy_mm_inputs = self._get_mm_dummy_batch(
- mode,
- num_items,
- )
- # Run multimodal encoder.
- torch_xla.sync(wait=False)
- mm_embeds = self.model.embed_multimodal(**batched_dummy_mm_inputs)
- torch_xla.sync(wait=False)
- num_patches = mm_embeds[0].shape[0]
- items_size = num_patches * num_items
-
- # NOTE (NickLucche) pre-compile `embed_input_ids` when mm
- # embeddings are present. We assume `--disable-mm-chunked`,
- # hence only whole items can be scheduled. This implies we just
- # need to compile when `num_items` fit the (padded) `input_ids`
- for num_tokens in self.num_tokens_paddings:
- if num_tokens >= items_size:
- # XLA Workaround: if torch.zeros(..device) is used, XLA
- # compiles a scalar+expansion op, which won't match
- # the graph generated at runtime. CPU->TPU must be used
- placeholders_ids = torch.zeros(
- num_tokens, dtype=torch.int32, device="cpu"
- )
- # Align placeholders and actual num mm_embeddings.
- placeholders_ids[:items_size] = hf_config.image_token_index
-
- placeholders_ids = placeholders_ids.to(self.device)
-
- mm_mask = torch.tensor([False] * num_tokens)
- mm_mask[:items_size] = True
- mm_mask = mm_mask.to(self.device)
- # Assign outputs or the graph will be cut short.
- a, b = self._get_model_inputs(
- placeholders_ids,
- mm_embed_inputs=([mm_embeds], mm_mask),
- )
- assert a is None
- torch_xla.sync(wait=False)
-
- # Pre-compile `embed_input_ids` when mm_embeddings are not
- # present. Chunk is only made of text, no mm_placeholders.
- for num_tokens in self.num_tokens_paddings:
- placeholders_ids = torch.zeros(
- num_tokens, dtype=torch.int32, device="cpu"
- )
- placeholders_ids = placeholders_ids.to(self.device)
- a, b = self._get_model_inputs(
- placeholders_ids,
- mm_embed_inputs=None,
- )
- assert a is None
- torch_xla.sync(wait=False)
-
- xm.wait_device_ops()
- end = time.perf_counter()
- logger.info(
- "Multimodal %s Encoder compilation finished in in %.2f [secs].",
- mode,
- end - start,
- )
-
- def _precompile_backbone(self) -> None:
- logger.info("Compiling the model with different input shapes.")
- start = time.perf_counter()
- for num_tokens in self.num_tokens_paddings:
- logger.info(" -- num_tokens: %d", num_tokens)
- self._dummy_run(
- num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req
- )
- if self.most_model_len is not None:
- self._dummy_run(
- num_tokens,
- self.num_reqs_most_model_len,
- self.num_blocks_per_most_len_req,
- )
- xm.wait_device_ops()
- end = time.perf_counter()
- logger.info("Compilation finished in %.2f [secs].", end - start)
- self._update_num_xla_graphs("model backbone")
-
- def _precompile_select_hidden_states(self) -> None:
- # Compile hidden state selection function for bucketed
- # n_tokens x max_num_reqs. Graph is really small so this is fine.
- logger.info("Compiling select_hidden_states with different input shapes.")
- start = time.perf_counter()
- hsize = self.model_config.get_hidden_size()
- for num_tokens in self.num_tokens_paddings:
- dummy_hidden = torch.zeros(
- (num_tokens, hsize), device=self.device, dtype=self._hidden_states_dtype
- )
- torch._dynamo.mark_dynamic(dummy_hidden, 0)
- for num_reqs in self.num_reqs_paddings:
- indices = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
- torch._dynamo.mark_dynamic(indices, 0)
- self.select_hidden_states(dummy_hidden, indices)
- logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, num_reqs)
- # Requests can't be more than tokens. But do compile for the
- # next bigger value in case num_tokens uses bucketed padding.
- if num_reqs >= min(num_tokens, self.max_num_reqs):
- break
- xm.wait_device_ops()
- end = time.perf_counter()
- logger.info("Compilation finished in %.2f [secs].", end - start)
- self._update_num_xla_graphs("select_hidden_states")
-
- def _precompile_compute_logits(self) -> None:
- logger.info("Compiling compute_logits with different input shapes.")
- start = time.perf_counter()
- hsize = self.model_config.get_hidden_size()
- for num_reqs in self.num_reqs_paddings:
- dummy_hidden = torch.zeros(
- (num_reqs, hsize), device=self.device, dtype=self._hidden_states_dtype
- )
- torch._dynamo.mark_dynamic(dummy_hidden, 0)
- self.compute_logits(dummy_hidden)
- logger.info(" -- num_seqs: %d", num_reqs)
- xm.wait_device_ops()
- end = time.perf_counter()
- logger.info("Compilation finished in %.2f [secs].", end - start)
- self._update_num_xla_graphs("compute_logits")
-
- def _precompile_structured_decoding(self) -> None:
- logger.info("Compiling structured_decoding with different input shapes.")
- start = time.perf_counter()
- for num_reqs in self.num_reqs_paddings:
- dummy_logits = torch.zeros(
- (num_reqs, self.vocab_size),
- device=self.device,
- dtype=self._hidden_states_dtype,
- )
- dummy_require_struct_decoding = self.require_structured_out_cpu[
- :num_reqs
- ].to(self.device)
- dummy_grammar_bitmask = self.grammar_bitmask_cpu[:num_reqs].to(self.device)
- # The first dimension of the above 3 dummy tensors cannot be
- # mark_dynamic because some operations in structured_decode require
- # them to be static.
- arange = self.structured_decode_arange.to(self.device)
- self.structured_decode(
- dummy_require_struct_decoding,
- dummy_grammar_bitmask,
- dummy_logits,
- arange,
- )
- logger.info(" -- num_seqs: %d", num_reqs)
- xm.wait_device_ops()
- end = time.perf_counter()
- logger.info("Compilation finished in %.2f [secs].", end - start)
- self._update_num_xla_graphs("structured_decoding")
-
- def _precompile_sample_from_logits(self) -> None:
- logger.info("Compiling sample_from_logits with different input shapes.")
- start = time.perf_counter()
- for num_reqs in self.num_reqs_paddings:
- dummy_logits = torch.zeros(
- (num_reqs, self.vocab_size),
- device=self.device,
- dtype=self._hidden_states_dtype,
- )
- # The first dimension of dummy_logits cannot be mark_dynamic
- # because some operations in the sampler require it to be static.
- for all_greedy in [False, True]:
- generate_params_if_all_greedy = not all_greedy
- sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
- self.input_batch,
- num_reqs,
- self.device,
- generate_params_if_all_greedy,
- )
- sampling_metadata.all_greedy = all_greedy
- with self.maybe_select_dummy_loras(
- self.lora_config, np.array([num_reqs], dtype=np.int32)
- ):
- self.sample_from_logits_func(dummy_logits, sampling_metadata)
- logger.info(" -- num_seqs: %d", num_reqs)
- xm.wait_device_ops()
- end = time.perf_counter()
- logger.info("Compilation finished in %.2f [secs].", end - start)
- self._update_num_xla_graphs("sample_from_logits")
-
- def _precompile_gather_logprobs(self) -> None:
- logger.info("Compiling gather_logprobs with different input shapes.")
- start = time.perf_counter()
- for num_reqs in self.num_reqs_paddings:
- dummy_logits = torch.zeros(
- (num_reqs, self.vocab_size),
- device=self.device,
- dtype=self._hidden_states_dtype,
- )
- dummy_tokens = torch.zeros((num_reqs, 1), dtype=torch.int64).to(self.device)
- with self.maybe_select_dummy_loras(
- self.lora_config, np.array([num_reqs], dtype=np.int32)
- ):
- self.gather_logprobs(dummy_logits, dummy_tokens)
- logger.info(" -- num_seqs: %d", num_reqs)
- xm.wait_device_ops()
- end = time.perf_counter()
- logger.info("Compilation finished in %.2f [secs].", end - start)
- self._update_num_xla_graphs("gather_logprobs")
-
- def capture_model(self) -> None:
- """
- Precompile all the subgraphs with possible input shapes.
- """
- with self.maybe_setup_dummy_loras(self.lora_config):
- self._precompile_mm_encoder()
- self._precompile_backbone()
- self._precompile_select_hidden_states()
- self._precompile_compute_logits()
- self._precompile_structured_decoding()
- self._precompile_sample_from_logits()
- self._precompile_gather_logprobs()
-
- def profile_run(
- self,
- num_tokens: int,
- ) -> None:
- # Profile with multimodal encoder & encoder cache.
- if self.supports_mm_inputs:
- mm_config = self.model_config.multimodal_config
- if mm_config is not None and mm_config.skip_mm_profiling:
- logger.info(
- "Skipping memory profiling for multimodal encoder and "
- "encoder cache."
- )
- else:
- mm_budget = self.mm_budget
- assert mm_budget is not None
-
- # TODO: handle encoder-decoder models once we support them.
- if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
- # NOTE: Currently model is profiled with a single non-text
- # modality with the max possible input tokens even when
- # it supports multiple.
- dummy_modality = mm_budget.get_modality_with_max_tokens()
- max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[
- dummy_modality
- ]
-
- logger.info(
- "Encoder cache will be initialized with a budget of "
- "%s tokens, and profiled with %s %s items of the "
- "maximum feature size.",
- encoder_budget,
- max_mm_items_per_batch,
- dummy_modality,
- )
-
- # Create dummy batch of multimodal inputs.
- batched_dummy_mm_inputs = self._get_mm_dummy_batch(
- dummy_modality,
- max_mm_items_per_batch,
- )
-
- # Run multimodal encoder.
- # Isolate encoder graph from post-processing to minimize
- # impact of recompilation until it's fixed.
- start = time.perf_counter()
- torch_xla.sync(wait=False)
- dummy_encoder_outputs = self.model.embed_multimodal(
- **batched_dummy_mm_inputs
- )
- torch_xla.sync(wait=False)
- xm.wait_device_ops()
- end = time.perf_counter()
- logger.info(
- "Multimodal Encoder profiling finished in %.2f [secs].",
- end - start,
- )
-
- sanity_check_mm_encoder_outputs(
- dummy_encoder_outputs,
- expected_num_items=max_mm_items_per_batch,
- )
-
- # Cache the dummy encoder outputs.
- self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
-
- # Trigger compilation for general shape.
- self._dummy_run(
- num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req
- )
- if self.most_model_len is not None:
- self._dummy_run(
- num_tokens,
- self.num_reqs_most_model_len,
- self.num_blocks_per_most_len_req,
- )
-
- torch_xla.sync(wait=False)
- xm.wait_device_ops()
- self.encoder_cache.clear()
- gc.collect()
-
- def maybe_setup_cross_layer_kv_sharing(
- self,
- kv_caches: dict[str, torch.Tensor],
- kv_cache_config: KVCacheConfig,
- ) -> None:
- """
- Add layers that re-use KV cache to KV cache group of its target layer.
- Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
- """
- if not self.shared_kv_cache_layers:
- # No cross-layer KV sharing, return
- return
-
- add_kv_sharing_layers_to_kv_cache_groups(
- self.shared_kv_cache_layers,
- kv_cache_config.kv_cache_groups,
- )
-
- for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
- logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
- kv_caches[layer_name] = kv_caches[target_layer_name]
-
- def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
- """
- Initialize KV cache based on `kv_cache_config`.
- Args:
- kv_cache_config: Configuration for the KV cache, including the KV
- cache size of each layer
- """
- if len(kv_cache_config.kv_cache_groups) > 1:
- raise NotImplementedError(
- "Hybrid models with more than one KV cache type are not supported yet."
- )
-
- if (
- kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
- != self.block_size
- ):
- self.input_batch = InputBatch(
- max_num_reqs=self.max_num_reqs,
- max_model_len=self.max_model_len,
- max_num_batched_tokens=self.max_num_tokens,
- device=self.device,
- pin_memory=self.pin_memory,
- vocab_size=self.model_config.get_vocab_size(),
- block_sizes=[
- kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
- ],
- kernel_block_sizes=[
- kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
- ],
- )
- # Verify dtype compatibility between block_table_cpu and input_batch
- assert (
- self.block_table_cpu.dtype
- == self.input_batch.block_table[0].get_cpu_tensor().dtype
- )
-
- kv_cache_sizes = {}
- for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
- assert len(kv_cache_tensor.shared_by) == 1, (
- "KV cache tensor shared by multiple layers is not supported in TPU."
- )
- kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
-
- kv_caches: dict[str, torch.Tensor] = {}
- for kv_cache_group in kv_cache_config.kv_cache_groups:
- kv_cache_spec = kv_cache_group.kv_cache_spec
- for layer_name in kv_cache_group.layer_names:
- tensor_size = kv_cache_sizes[layer_name]
- assert tensor_size % kv_cache_spec.page_size_bytes == 0
- num_blocks = tensor_size // kv_cache_spec.page_size_bytes # noqa
- if isinstance(kv_cache_spec, AttentionSpec):
- if self.use_spmd:
- num_kv_heads = kv_cache_spec.num_kv_heads
- assert self.original_parallel_config is not None
- tp_size = self.original_parallel_config.tensor_parallel_size
- # TODO: Handle kv cache duplication under SPMD mode.
- assert num_kv_heads % tp_size == 0, (
- f"num_kv_heads {num_kv_heads} must be divisible by "
- f"tp_size {tp_size} under SPMD mode"
- )
- kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
- num_blocks,
- kv_cache_spec.block_size,
- kv_cache_spec.num_kv_heads,
- kv_cache_spec.head_size,
- )
- dtype = kv_cache_spec.dtype
-
- tpu_kv_cache = torch.zeros(kv_cache_shape, dtype=dtype).to(
- self.device
- )
-
- kv_caches[layer_name] = tpu_kv_cache
- else:
- raise NotImplementedError
-
- # Set up cross-layer KV cache sharing if needed
- self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config)
-
- bind_kv_cache(
- kv_caches,
- self.vllm_config.compilation_config.static_forward_context,
- self.kv_caches,
- )
-
- if self.use_spmd:
- # Shard KV Cache
- for cache in self.kv_caches:
- xs.mark_sharding(cache, self.mesh, (None, "x", None, None))
-
- if has_kv_transfer_group():
- get_kv_transfer_group().register_kv_caches(kv_caches)
- get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks)
-
- def reset_dynamo_cache(self):
- # NOTE: We check `is_multimodal_model` instead of `supports_mm_inputs`
- # since the compiled model object of the language backbone of a
- # multimodal model needs to be extracted via `get_language_model`.
- if self.model_config.is_multimodal_model:
- compiled_model = self.model.get_language_model().model
- else:
- compiled_model = self.model.model
- if isinstance(compiled_model, TorchCompileWithNoGuardsWrapper):
- logger.info("Clear dynamo cache and cached dynamo bytecode.")
- torch._dynamo.eval_frame.remove_from_cache(
- compiled_model.original_code_object()
- )
- # Reset the wrapper to re-initialize.
- compiled_model.compiled = False
- TorchCompileWithNoGuardsWrapper.__init__(compiled_model)
-
- @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
- def select_hidden_states(self, hidden_states, indices_do_sample):
- return hidden_states[indices_do_sample]
-
- @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
- def compute_logits(self, sample_hidden_states: torch.Tensor) -> torch.Tensor:
- return self.model.compute_logits(sample_hidden_states)
-
- # TODO: Under SPMD mode, sample_from_logits has correctness issue.
- # Re-enable the torch.compile once the issue is fixed in torchxla.
- # @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
- def sample_from_logits(
- self, logits: torch.Tensor, sampling_metadata: TPUSupportedSamplingMetadata
- ) -> torch.Tensor:
- """
- Sample with xla-friendly function. This function is to be traced
- separately from `forward` for lighter compilation overhead.
- """
- if sampling_metadata.all_greedy:
- out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
- else:
- out_tokens = self.sampler(logits, sampling_metadata).sampled_token_ids
- return out_tokens
-
- @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
- def gather_logprobs(
- self, logits: torch.Tensor, sampled_tokens: torch.Tensor
- ) -> LogprobsTensors:
- """
- Gather the top_logprobs with corresponding tokens. Use a fixed number
- of logprobs as an alternative to having multiple pre-compiled graphs.
- Select the number of logprobs actually demanded by each request on CPU.
- """
- logprobs = self.sampler.compute_logprobs(logits)
- return self.sampler.gather_logprobs(
- logprobs,
- self.model_config.max_logprobs,
- token_ids=sampled_tokens.squeeze(-1),
- )
-
- @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
- def structured_decode(
- self,
- require_struct_decoding: torch.Tensor,
- grammar_bitmask: torch.Tensor,
- logits: torch.Tensor,
- arange: torch.Tensor,
- ) -> torch.Tensor:
- return torch.where(
- require_struct_decoding,
- self.apply_grammar_bitmask(logits, grammar_bitmask, arange),
- logits,
- )
-
- def apply_grammar_bitmask(
- self, logits: torch.Tensor, grammar_bitmask: torch.Tensor, arange: torch.Tensor
- ):
- assert logits.shape[0] == grammar_bitmask.shape[0]
- logits_cloned = logits.clone()
- for i in range(logits.shape[0]):
- unpacked_bitmask = (
- torch.bitwise_right_shift(grammar_bitmask[i][:, None], arange[None, :])
- & 1
- ) == 0
- unpacked_bitmask = unpacked_bitmask.reshape(-1)[: self.vocab_size]
- logits_cloned[i] = logits_cloned[i].masked_fill(
- unpacked_bitmask, -float("inf")
- )
- return logits_cloned
-
- def embed_multimodal(self, *args, **kwargs):
- return self.model.embed_multimodal(*args, **kwargs)
-
- def embed_input_ids(self, *args, **kwargs):
- return self.model.embed_input_ids(*args, **kwargs)
-
- def prepare_structured_decoding_input(
- self, logits: torch.Tensor, grammar_output: "GrammarOutput"
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- grammar_bitmask = grammar_output.grammar_bitmask
- num_reqs, _ = logits.shape
-
- # Reset pre-allocated tensors
- self.grammar_bitmask_cpu.zero_()
- self.require_structured_out_cpu.zero_()
-
- cumulative_mask_idx = 0
- for req_id in grammar_output.structured_output_request_ids:
- if req_id not in self.input_batch.req_id_to_index:
- continue
- batch_index = self.input_batch.req_id_to_index[req_id]
- self.grammar_bitmask_cpu[batch_index] = torch.from_numpy(
- grammar_bitmask[cumulative_mask_idx]
- )
- # It's not guaranteed that all requests in this batch require
- # structured output, so create a bool tensor to represent
- # the requests that need structured output.
- self.require_structured_out_cpu[batch_index] = True
- cumulative_mask_idx += 1
-
- return (
- self.require_structured_out_cpu[:num_reqs].to(logits.device),
- self.grammar_bitmask_cpu[:num_reqs].to(logits.device),
- self.structured_decode_arange.to(logits.device),
- )
-
- def _get_mm_dummy_batch(
- self,
- modality: str,
- max_items_per_batch: int,
- ) -> BatchedTensorInputs:
- """Dummy data for profiling and precompiling multimodal models."""
- assert self.mm_budget is not None
-
- dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
- model_config=self.model_config,
- seq_len=self.max_model_len,
- mm_counts={modality: 1},
- cache=self.mm_budget.cache,
- )
- dummy_mm_data = dummy_decoder_data.multi_modal_data
-
- # Result in the maximum GPU consumption of the model
- dummy_mm_item = dummy_mm_data[modality][0]
- dummy_mm_items = [dummy_mm_item] * max_items_per_batch
-
- return next(
- grouped_mm_kwargs
- for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
- dummy_mm_items,
- device=self.device,
- pin_memory=self.pin_memory,
- )
- )
-
-
-def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
- logger.info("Preparing request paddings:")
- # assert min_req_size is power of 2
- assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0
- paddings: list = []
- num = max(MIN_NUM_SEQS, min_req_size)
- while num <= max_req_size and (len(paddings) == 0 or paddings[-1] != num):
- paddings.append(num)
- logger.info(" %d", num)
- num = _get_padded_num_reqs_with_upper_limit(num + 1, max_req_size)
- return paddings
-
-
-def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
- res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
- return min(res, upper_limit)
-
-
-def _get_token_paddings(
- min_token_size: int, max_token_size: int, padding_gap: int
-) -> list[int]:
- """Generate a list of padding size, starting from min_token_size,
- ending with a number that can cover max_token_size
-
- If padding_gap == 0 then:
- increase 2X each time (exponential)
- else:
- first increase the size to twice,
- then increase the padding size by padding_gap.
- """
- # assert min_token_size is power of 2
- assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0
- paddings = []
- num = min_token_size
-
- if padding_gap == 0:
- logger.info("Using exponential token paddings:")
- while True:
- logger.info(" %d", num)
- paddings.append(num)
- if num >= max_token_size:
- break
- num *= 2
- else:
- logger.info("Using incremental token paddings:")
- while num <= padding_gap:
- logger.info(" %d", num)
- paddings.append(num)
- num *= 2
- num //= 2
- while num < max_token_size:
- num += padding_gap
- logger.info(" %d", num)
- paddings.append(num)
-
- return paddings
-
-
-def _get_padded_token_len(paddings: list[int], x: int) -> int:
- """Return the first element in paddings list greater or equal to x."""
- index = bisect.bisect_left(paddings, x)
- assert index < len(paddings)
- return paddings[index]
-
-
-def _get_padded_num_kv_cache_update_slices(
- num_tokens: int, max_num_reqs: int, page_size: int
-) -> int:
- """Calculates the padded number of KV cache update slices to avoid
- recompilation."""
- # NOTE(chengjiyao): let's say R_i is the token num for i-th request,
- # so it occupies most 2 + R_i // page_size pages. The total maximum
- # possible number of pages needed is sum(2 + R_i // page_size), which
- # is <= 2 * max_num_reqs + sum(R_i) // page_size
- # = 2 * max_num_reqs + num_tokens // page_size
- padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
- padded_num_slices = min(padded_num_slices, num_tokens)
- return padded_num_slices
-
-
-def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
- """Find the optimum number of slices to copy per Pallas program instance.
-
- Increasing the number of slices copied in one instance of the kernel program
- will increase HBM bandwidth utilization via more in-flight DMAs.
-
- However, it will also use more VMEM, and experimentally, we observed
- performance regression at 128 slices on v6e, likely due to running
- out of scalar registers. Thus this function will limit the number of
- slices to 64.
- """
- # The default vmem_limit_bytes of a pallas kernel is 32MB. Here we
- # calculate num_slices_per_block based on 16MB in case any register spills.
- vmem_limit = 16 * 1024 * 1024
- num_slices_per_block = vmem_limit // page_size_bytes
- assert num_slices_per_block > 0, "Number of slices should be positive"
- num_slices_per_block = prev_power_of_2(num_slices_per_block)
- if num_slices_per_block > 64:
- num_slices_per_block = 64
- return num_slices_per_block
-
-
-def replace_set_lora(model):
- def _tpu_set_lora(
- self,
- index: int,
- lora_a: torch.Tensor,
- lora_b: torch.Tensor,
- embeddings_tensor: torch.Tensor | None,
- ):
- # TODO: The integer index leads to a recompilation, but converting it
- # to a tensor doesn't seem to work anymore. This might be fixed with a
- # later release of torch_xla.
- self._original_set_lora(index, lora_a, lora_b, embeddings_tensor)
- torch_xla.sync(wait=False)
-
- def _tpu_reset_lora(self, index: int):
- self._original_reset_lora(index)
- torch_xla.sync(wait=False)
-
- for _, module in model.named_modules():
- if isinstance(module, BaseLayerWithLoRA):
- module._original_set_lora = module.set_lora
- module._original_reset_lora = module.reset_lora
- module.set_lora = _tpu_set_lora.__get__( # type: ignore[method-assign]
- module, module.__class__
- )
- module.reset_lora = _tpu_reset_lora.__get__( # type: ignore[method-assign]
- module, module.__class__
- )
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 5f6136b178b46..4c73d6c92d391 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -2,350 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A TPU worker class."""
-import os
-from collections.abc import Callable
-from typing import Any, TypeVar
+from typing import TypeVar
-import torch
-import torch.nn as nn
-
-import vllm.envs as envs
-from vllm.config import VllmConfig, set_current_vllm_config
-from vllm.distributed import (
- ensure_model_parallel_initialized,
- init_distributed_environment,
-)
-from vllm.distributed.kv_transfer import (
- ensure_kv_transfer_initialized,
-)
from vllm.logger import init_logger
-from vllm.lora.request import LoRARequest
-from vllm.model_executor import set_random_seed
-from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_INFERENCE
-from vllm.tasks import SupportedTask
-from vllm.utils.math_utils import cdiv
-from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
-from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
-from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec
-from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.utils import report_usage_stats
-from vllm.v1.worker.utils import bind_kv_cache
logger = init_logger(__name__)
_R = TypeVar("_R")
-if not USE_TPU_INFERENCE:
- logger.info("tpu_inference not found, using vLLM's TPUWorker.")
- import torch_xla.core.xla_model as xm
- import torch_xla.debug.profiler as xp
- import torch_xla.runtime as xr
-
- from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
- from vllm.v1.worker.tpu_model_runner import TPUModelRunner
-
-
-class TPUWorker:
- def __init__(
- self,
- vllm_config: VllmConfig,
- local_rank: int,
- rank: int,
- distributed_init_method: str,
- is_driver_worker: bool = False,
- ):
- self.is_driver_worker = is_driver_worker
- self.vllm_config = vllm_config
- self.model_config = vllm_config.model_config
- self.cache_config = vllm_config.cache_config
- self.lora_config = vllm_config.lora_config
- self.load_config = vllm_config.load_config
- self.parallel_config = vllm_config.parallel_config
- self.use_spmd = envs.VLLM_XLA_USE_SPMD
- self.original_parallel_config = None
- if self.use_spmd:
- # Under SPMD mode, distributed env is initialized as if there is
- # only one worker/device.
- self.original_parallel_config = self.parallel_config
- self.parallel_config.tensor_parallel_size = 1
- self.parallel_config.pipeline_parallel_size = 1
- self.parallel_config.world_size = 1
- self.scheduler_config = vllm_config.scheduler_config
- self.device_config = vllm_config.device_config
- self.speculative_config = vllm_config.speculative_config
- self.observability_config = vllm_config.observability_config
-
- self.parallel_config.rank = rank
- self.local_rank = local_rank
- self.rank = rank
- self.distributed_init_method = distributed_init_method
-
- if self.cache_config.cache_dtype == "auto":
- self.cache_dtype = self.model_config.dtype
- else:
- self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype]
-
- if self.model_config.trust_remote_code:
- # note: lazy import to avoid importing torch before initializing
- from vllm.utils.import_utils import init_cached_hf_modules
-
- init_cached_hf_modules()
-
- # Delay profiler initialization to the start of the profiling.
- # This is because in vLLM V1, MP runtime is initialized before the
- # TPU Worker is initialized. The profiler server needs to start after
- # MP runtime is initialized.
- self.profiler = None
- self.profile_dir = None
- if vllm_config.profiler_config.profiler == "torch" and self.rank < 1:
- # For TPU, we can only have 1 active profiler session for 1 profiler
- # server. So we only profile on rank0.
- self.profile_dir = vllm_config.profiler_config.torch_profiler_dir
- logger.info(
- "Profiling enabled. Traces will be saved to: %s", self.profile_dir
- )
-
- def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
- self.cache_config.num_gpu_blocks = num_gpu_blocks
- self.cache_config.num_cpu_blocks = num_cpu_blocks
-
- def init_device(self):
- os.environ["PJRT_DEVICE"] = "TPU"
- # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
- # ring, the xla tpu compiler flag
- # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
- # fix this. It will be removed after the bug in XLA compiler is fixed.
- os.environ["LIBTPU_INIT_ARGS"] = (
- os.environ.get("LIBTPU_INIT_ARGS", "")
- + " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
- " --xla_jf_conv_input_fusion=False"
- )
- # --xla_jf_conv_input_fusion=False is used to improve the perf of
- # quantized matmul.
- torch.set_grad_enabled(False)
- torch.set_default_dtype(self.model_config.dtype)
-
- # Initialize the distributed environment.
- self._init_tpu_worker_distributed_environment(
- self.vllm_config, self.rank, self.distributed_init_method, self.local_rank
- )
-
- # Device initialization should happen after initializing
- # the distributed runtime.
- self.device = xm.xla_device()
- self.device_config.device = self.device
-
- # Set random seed.
- set_random_seed(self.model_config.seed)
- xm.set_rng_state(self.model_config.seed, self.device)
-
- # Increase the cache size limit, which is the maximum number of
- # dynamo graphs that can be compiled.
- # TODO (NickLucche) On gsm we compile 80+ graphs.
- # Re-evaluate limit, with MM we may get close to this limit.
- torch._dynamo.config.cache_size_limit = 128
- # Use persistent cache to avoid XLA recompilation.
- # NOTE(woosuk): Set per-rank cache path since different ranks
- # can have slightly different XLA graphs.
- world_size = self.parallel_config.world_size
- rank = xr.global_ordinal()
- # The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
- # Consequently, changes in optimization flags, which affect compilation
- # results, don't change the cache key. This can result in the wrong
- # compilation being used. To prevent this, disabling the XLA compilation
- # cache during development is recommended.We can disable it by
- # `export VLLM_XLA_CACHE_PATH=`
- if envs.VLLM_XLA_CACHE_PATH:
- per_rank_path = os.path.join(
- envs.VLLM_XLA_CACHE_PATH, f"tp{world_size}_rank{rank}"
- )
- xr.initialize_cache(per_rank_path, readonly=False)
-
- # Init ModelRunner here, so that we have access to self.device.
- self.model_runner = TPUModelRunner(
- self.vllm_config, self.device, self.original_parallel_config
- )
-
- if rank == 0:
- # If usage stat is enabled, collect relevant info.
- report_usage_stats(self.vllm_config)
-
- def determine_available_memory(self) -> int:
- kv_caches: dict[str, torch.Tensor] = {}
- kv_cache_spec = self.model_runner.get_kv_cache_spec()
- for layer_name, layer_spec in kv_cache_spec.items():
- if isinstance(layer_spec, AttentionSpec):
- dtype = layer_spec.dtype
-
- # Use an empty tensor instead of `None` to force Dynamo to pass
- # it by reference, rather by specializing on the value `None`.
- tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
- kv_caches[layer_name] = tpu_kv_cache
- else:
- raise NotImplementedError(
- f"Unsupported KV cache spec '{type(layer_spec)}'"
- )
-
- runner_kv_caches: list[torch.Tensor] = []
- bind_kv_cache(
- kv_caches,
- self.vllm_config.compilation_config.static_forward_context,
- runner_kv_caches,
- )
-
- # `max_num_tokens >= max_num_batched_tokens` due to padding.
- with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
- self.model_runner.profile_run(self.model_runner.max_num_tokens)
-
- # Synchronize before measuring the memory usage.
- xm.wait_device_ops()
-
- # During the profiling run, the model runs without KV cache. After
- # the profiling run, the model always runs with KV cache. Here we clear
- # the dynamo cache and cached bytecode to ensure the model always has
- # one compiled bytecode. Having one FX graph/cached bytecode per
- # compiled model is required for `support_torch_compile` decorator to
- # skip dynamo guard.
- with set_current_vllm_config(self.vllm_config):
- self.model_runner.reset_dynamo_cache()
-
- # Get the maximum amount of memory used by the model weights and
- # intermediate activations.
- if self.use_spmd:
- # This is a workaround for the TPU SPMD mode. The get_memory_info
- # API doesn't work with SPMD mode in PyTorch/XLA.
- # TODO: use xm.get_memory_info for SPMD once it's supported in
- # PyTorch/XLA.
- import tpu_info
-
- chip_type, _ = tpu_info.device.get_local_chips()
- device_usage = tpu_info.metrics.get_chip_usage(chip_type)
- total_memory_size = device_usage[0].total_memory
- current_mem = device_usage[0].memory_usage
- else:
- m = xm.get_memory_info(self.device)
- total_memory_size = m["bytes_limit"]
- current_mem = m["bytes_used"]
- # Ideally we would use profiled = m["peak_bytes_used"] to
- # get weights + activations. But there is memory used during
- # compilation / weight loading that impacts the peak and
- # there is no way to reset peak memory in XLA, So we
- # use the heuristic of 2% of weights.
- profiled = current_mem * 1.02
-
- # Calculate the TPU KV cache size based on profiling.
- usable_memory_size = int(
- total_memory_size * self.cache_config.gpu_memory_utilization
- )
- tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
- head_size = self.model_config.get_head_size()
- if head_size > 0:
- padded_head_size = (
- cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
- )
- if padded_head_size != head_size:
- logger.warning_once("head size is padded to %d", padded_head_size)
- # We adjust the usable memory size for the KV cache to prevent OOM
- # errors, even after padding the head_size.
- tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
- return int(tpu_kv_cache_bytes)
-
- def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput:
- return self.model_runner.sample_tokens(grammar_output)
-
- def execute_model(
- self, scheduler_output: "SchedulerOutput"
- ) -> ModelRunnerOutput | None:
- return self.model_runner.execute_model(scheduler_output)
-
- def profile(self, is_start: bool = True):
- if self.rank < 1:
- if self.profile_dir is None:
- raise RuntimeError("Profiler is not enabled.")
- if is_start:
- if self.profiler is None:
- self.profiler = xp.start_server(9012)
- xp.start_trace(self.profile_dir)
- else:
- xp.stop_trace()
-
- def add_lora(self, lora_request: LoRARequest) -> bool:
- return self.model_runner.add_lora(lora_request)
-
- def load_model(self) -> None:
- self.model_runner.load_model()
-
- def update_config(self, overrides: dict[str, Any]) -> None:
- self.model_runner.update_config(overrides)
-
- def reload_weights(self) -> None:
- self.model_runner.reload_weights()
-
- def compile_or_warm_up_model(self) -> None:
- if not self.model_config.enforce_eager:
- self.model_runner.capture_model()
-
- # Reset the seed to ensure that the random state is not affected by
- # the model initialization and profiling.
- set_random_seed(self.model_config.seed)
-
- def reset_mm_cache(self) -> None:
- self.model_runner.reset_mm_cache()
-
- def get_model(self) -> nn.Module:
- return self.model_runner.get_model()
-
- def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
- return self.model_runner.get_supported_tasks()
-
- def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
- return self.model_runner.get_kv_cache_spec()
-
- def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
- """Allocate GPU KV cache with the specified kv_cache_config."""
- self.model_runner.initialize_kv_cache(kv_cache_config)
-
- def check_health(self) -> None:
- # worker will always be healthy as long as it's running.
- return
-
- def _init_tpu_worker_distributed_environment(
- self,
- vllm_config: VllmConfig,
- rank: int,
- distributed_init_method: str | None = None,
- local_rank: int = -1,
- ) -> None:
- """Initialize the distributed environment."""
- if self.use_spmd:
- xr.use_spmd()
- # NOTE(woosuk): This is just to initialize the TP group and broadcast
- # the input objects on CPU. The all-reduce and all-gather ops on TPU
- # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
- # own context.
- parallel_config = vllm_config.parallel_config
- init_distributed_environment(
- world_size=parallel_config.world_size,
- rank=rank,
- local_rank=local_rank,
- distributed_init_method=distributed_init_method or "env://",
- backend=current_platform.dist_backend,
- )
- ensure_model_parallel_initialized(
- parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
- )
-
- ensure_kv_transfer_initialized(vllm_config)
-
- def shutdown(self) -> None:
- self.model_runner.ensure_kv_transfer_shutdown()
-
- def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
- """Apply a function on the model inside this worker."""
- return fn(self.get_model())
-
-
+# TODO(weiyulin) Remove this file after adding an official way to use hardware plugin
if USE_TPU_INFERENCE:
from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker