mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-20 22:53:38 +08:00
Merge 9aaed80cc85f65438f859bdd19fe90d6b712be5c into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
737b3079ad
@ -92,7 +92,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
|
||||
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
|
||||
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
|
||||
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
|
||||
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
|
||||
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
|
||||
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
|
||||
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
|
||||
|
||||
@ -1,139 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
from torch_xla._internal import tpu
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
# This file contains tests to ensure that LoRA works correctly on the TPU
|
||||
# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct
|
||||
# for this. The adapters are:
|
||||
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
|
||||
# from 1 to 4.
|
||||
|
||||
# These adapters are trained using a standard huggingface peft training script,
|
||||
# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run
|
||||
# 100 training iterations with a training batch size of 100.
|
||||
|
||||
|
||||
def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
    """Build a small Qwen2.5-3B engine with LoRA support enabled.

    Args:
        num_loras: Maximum number of LoRA adapters resident at once.
        tp: Tensor-parallel size to shard the model across chips.
    """
    engine_kwargs = {
        "model": "Qwen/Qwen2.5-3B-Instruct",
        "max_model_len": 256,
        "max_num_seqs": 8,
        "tensor_parallel_size": tp,
        "enable_lora": True,
        "max_loras": num_loras,
        "max_lora_rank": 8,
    }
    return vllm.LLM(**engine_kwargs)
|
||||
|
||||
|
||||
# Probe the TPU topology once instead of calling num_available_chips() twice;
# TP > 1 is only worth exercising when more than one chip is attached.
_num_chips = tpu.num_available_chips()
TPU_TENSOR_PARALLEL_SIZES = [1, _num_chips] if _num_chips > 1 else [1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_single_lora(tp: int):
    """
    This test ensures we can run a single LoRA adapter on the TPU backend.
    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=1.
    """
    llm = setup_vllm(1, tp)

    prompt = "What is 1+1? \n"
    lora_request = LoRARequest(
        "lora_adapter_1",
        1,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter",
    )

    # Greedy decoding so the adapter's memorized answer is reproducible.
    completions = llm.generate(
        prompt,
        sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
        lora_request=lora_request,
    )
    output = completions[0].outputs[0].text

    answer = output.strip()[0]

    assert answer.isdigit()
    assert int(answer) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_lora_hotswapping(tp: int):
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, even
    if we only have space to store 1.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """
    lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{idx}", idx, lora_name_template.format(idx))
        for idx in range(1, 5)
    ]

    # Only one LoRA slot, so every request below forces a hot swap.
    llm = setup_vllm(1, tp)

    prompt = "What is 1+1? \n"

    for idx, request in enumerate(lora_requests):
        completions = llm.generate(
            prompt,
            sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
            lora_request=request,
        )
        text = completions[0].outputs[0].text
        answer = text.strip()[0]

        assert answer.isdigit()
        assert int(answer) == idx + 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_multi_lora(tp: int):
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, when
    we have enough space to store all of them.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """
    lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    # max_loras=4 so all adapters stay resident simultaneously.
    llm = setup_vllm(4, tp)

    prompt = "What is 1+1? \n"

    for i, req in enumerate(lora_requests):
        output = (
            llm.generate(
                prompt,
                sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
                lora_request=req,
            )[0]
            .outputs[0]
            .text
        )

        answer = output.strip()[0]

        assert answer.isdigit()
        # Fix: compare the already-extracted `answer` instead of re-slicing
        # output.strip()[0] again — keeps this consistent with the sibling
        # tests above (same value, no duplicated work).
        assert int(answer) == i + 1
|
||||
@ -1,86 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import glob
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import depyf
|
||||
|
||||
|
||||
def test_tpu_compilation():
    """End-to-end check that TPU inference triggers exactly two Dynamo
    compilations and that the dumped graphs contain the expected ops."""
    # Write every Dynamo-transformed/compiled code dump into a scratch dir
    # so the files can be inspected after generation.
    temp_dir = tempfile.mkdtemp()
    with depyf.prepare_debug(temp_dir):
        from vllm import LLM, SamplingParams

        prompts = [
            "A robot may not injure a human being",
            "It is only with the heart that one can see rightly;",
            "The greatest glory in living lies not in never falling,",
        ]
        # Expected continuations, index-aligned with `prompts`.
        answers = [
            " or, through inaction",
            " what is essential ",
            " but in rising ",
        ]

        # Currently, top-p sampling is disabled. `top_p` should be 1.0.
        N = 1
        sampling_params = SamplingParams(temperature=0.7, top_p=1.0, n=N, max_tokens=16)

        llm = LLM(
            model="Qwen/Qwen2-1.5B-Instruct",
            max_num_batched_tokens=256,
            max_model_len=256,
            max_num_seqs=32,
            enforce_eager=False,
        )

        outputs = llm.generate(prompts, sampling_params)
        for output, answer in zip(outputs, answers):
            prompt = output.prompt
            generated_text = output.outputs[0].text
            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
            assert generated_text.startswith(answer)

    # depyf dumps one `__transformed_code*for_forward.py` per Dynamo
    # compilation of the forward function.
    compiled_codes = sorted(
        glob.glob(os.path.join(temp_dir, "__transformed_code*for_forward.py"))
    )

    for i, compiled_code in enumerate(compiled_codes):
        print("{} file: {}".format(i + 1, compiled_code))

    # We should only trigger Dynamo compilation 2 times:
    # 1. Forward pass without kv_caches
    # 2. Forward pass with kv_caches
    # Check we have 2 compiled codes
    assert len(compiled_codes) == 2

    kv_cache_prefix = "kv_cache"
    attn_prefix = "ragged_paged_attention"

    def extract_compiled_index(s):
        # Dump file names embed the compilation index as the first integer
        # token; pull it out so the files sort in compilation order.
        parts = s.replace(".", "_").split("_")
        numbers = [int(part) for part in parts if part.isdigit()]
        return numbers[0]

    # Check all the compilations are as expected. The dump files include the
    # captured graph for the forward function of the nn.Module.
    compiled_fns = sorted(
        glob.glob(os.path.join(temp_dir, "__compiled_fn*Forward_graph*.py")),
        key=lambda s: extract_compiled_index(s),
    )

    for i, compiled_fn in enumerate(compiled_fns):
        print("{} file: {}".format(i + 1, compiled_fn))

    # The first compilation should not have any kv_caches
    with open(compiled_fns[0]) as f:
        content = f.read()
        assert kv_cache_prefix not in content

    # The second compilation should have kv_caches and the
    # ragged_paged_attention
    with open(compiled_fns[1]) as f:
        content = f.read()
        assert kv_cache_prefix in content and attn_prefix in content
|
||||
@ -1,34 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import CompilationMode
|
||||
|
||||
from ..utils import compare_two_settings
|
||||
|
||||
# --enforce-eager on TPU causes graph compilation
|
||||
# this times out default Health Check in the MQLLMEngine,
|
||||
# so we set the timeout here to 30s
|
||||
|
||||
|
||||
def test_custom_dispatcher(monkeypatch: pytest.MonkeyPatch):
    """Compare DYNAMO_TRACE_ONCE against STOCK_TORCH_COMPILE on the same model."""
    # Both runs share everything except the compilation mode flag.
    common_args = [
        "--max-model-len=256",
        "--max-num-seqs=32",
        "--enforce-eager",
    ]
    with monkeypatch.context() as m:
        # --enforce-eager on TPU causes graph compilation, which times out the
        # default health check in the MQLLMEngine, so raise the timeout to 30s.
        m.setenv("VLLM_RPC_TIMEOUT", "30000")
        compare_two_settings(
            "Qwen/Qwen2.5-1.5B-Instruct",
            arg1=common_args + [f"-O{CompilationMode.DYNAMO_TRACE_ONCE}"],
            arg2=common_args + [f"-O{CompilationMode.STOCK_TORCH_COMPILE}"],
            env1={},
            env2={},
        )
|
||||
@ -1,88 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the Pallas MOE implementation.
|
||||
|
||||
Run `pytest tests/kernels/moe/test_moe_pallas.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch_xla
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe as pallas_moe
|
||||
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
||||
fused_moe as torch_moe,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# The Pallas fused-MoE kernel only exists on TPU; skip the whole module
# elsewhere.
if not current_platform.is_tpu():
    pytest.skip("This test needs a TPU.", allow_module_level=True)

NUM_EXPERTS = [8, 64]
# Expert-parallel sizes; only EP=1 is exercised today (EP>1 is skipped in
# the test body below).
EP_SIZE = [1]
TOP_KS = [2, 6]
|
||||
|
||||
|
||||
# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_pallas_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
):
    """Compare the Pallas fused-MoE kernel against the iterative torch MoE
    reference on random inputs of shape (m tokens, k hidden, e experts)."""
    import torch_xla.core.xla_model as xm

    with torch.device(xm.xla_device()):
        # Divide by 10 to keep bf16 matmul magnitudes inside the atol below.
        a = torch.randn((m, k), dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
        w2 = torch.randn((e, k, n), dtype=dtype) / 10

        # Router logits: one score per (token, expert) pair.
        score = torch.randn((m, e), dtype=dtype)

        # TODO: Support ep
        if ep_size > 1:
            pytest.skip("No support for ep_size > 1 yet")
        else:
            e_map = None

        # Run both implementations
        torch_output = torch_moe(
            hidden_states=a,
            w1=w1,
            w2=w2,
            gating_output=score,
            topk=topk,
            global_num_experts=e,
            expert_map=e_map,
            renormalize=False,
        )

        pallas_output = pallas_moe(
            hidden_states=a,
            w1=w1,
            w2=w2,
            gating_output=score,
            topk=topk,
            global_num_experts=e,
            expert_map=e_map,
            renormalize=False,
        )
        torch_xla.sync(wait=False)

        # Compare outputs
        torch.testing.assert_close(
            pallas_output.cpu(),
            torch_output.cpu(),
            atol=2e-2,
            rtol=0,
        )
|
||||
@ -1,52 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import lm_eval
|
||||
import pytest
|
||||
|
||||
# lm-eval task name and the metric filter the score is read from.
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
# Absolute tolerance allowed around each model's expected accuracy.
RTOL = 0.03
|
||||
|
||||
|
||||
@dataclass
class GSM8KAccuracyTestConfig:
    """Pairs a model checkpoint with the GSM8K accuracy it should achieve."""

    model_name: str
    expected_value: float

    def get_model_args(self) -> str:
        # lm-eval consumes vLLM engine options as one comma-separated string.
        return (
            f"pretrained={self.model_name},"
            "max_model_len=4096,max_num_seqs=32"
        )
|
||||
|
||||
|
||||
# NOTE: Accuracy scores measured on GPUs.
# Each entry pairs a quantized checkpoint with its measured GSM8K score.
ACCURACY_CONFIGS = [
    GSM8KAccuracyTestConfig(
        model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
        expected_value=0.76,
    ),  # no bias
    # NOTE(rob): We cannot re-initialize vLLM in the same process for TPU,
    # so only one of these tests can run in a single call to pytest. As
    # a follow-up, move this into the LM-EVAL section of the CI.
    # GSM8KAccuracyTestConfig(
    #     model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
    #     expected_value=0.66), # bias in QKV layers
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
    """Run lm-eval GSM8K through the vLLM backend and check the accuracy
    lands within RTOL of the config's expected value."""
    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=config.get_model_args(),
        # Fix: use the TASK constant rather than a second hard-coded "gsm8k"
        # so the evaluated task always matches the result lookup below.
        tasks=TASK,
        batch_size="auto",
    )

    EXPECTED_VALUE = config.expected_value
    measured_value = results["results"][TASK][FILTER]
    # |measured - expected| < RTOL is algebraically identical to the former
    # two-sided (measured - RTOL < expected and measured + RTOL > expected).
    assert abs(measured_value - EXPECTED_VALUE) < RTOL, (
        f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
    )
|
||||
@ -1,177 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A basic correctness check for TPUs
|
||||
|
||||
Run `pytest tests/v1/tpu/test_basic.py`.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from torch_xla._internal import tpu
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Import VllmRunner for annotations only; at runtime the pytest fixture
# supplies the real class, so a plain ``object`` placeholder keeps the
# annotations valid without importing test infrastructure.
if TYPE_CHECKING:
    from tests.conftest import VllmRunner
else:
    VllmRunner = object

MODELS = [
    "Qwen/Qwen2.5-1.5B-Instruct",
    # TODO: Enable this model when fixed.
    # "Qwen/Qwen1.5-MoE-A2.7B",
    # TODO: Enable this models with v6e
    # "Qwen/Qwen2-7B-Instruct",
    # "meta-llama/Llama-3.1-8B",
]

TENSOR_PARALLEL_SIZES = [1]
# Exercise both a small and a large scheduler batch configuration.
MAX_NUM_REQS = [16, 1024]

# TODO: Enable when CI/CD will have a multi-tpu instance
# TENSOR_PARALLEL_SIZES = [1, 4]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not current_platform.is_tpu(), reason="This is a basic test for TPU only"
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
def test_basic(
    vllm_runner: type[VllmRunner],
    model: str,
    max_tokens: int,
    tensor_parallel_size: int,
    max_num_seqs: int,
) -> None:
    """Greedy-decode a long counting prompt and sanity-check the output."""
    # A 1024-number counting prompt; the model should continue the sequence.
    counting = ", ".join(str(i) for i in range(1024))
    example_prompts = [f"The next numbers of the sequence {counting} are:"]

    with vllm_runner(
        model,
        # Note: max_num_batched_tokens == 1024 is needed here to
        # actually test chunked prompt
        max_num_batched_tokens=1024,
        max_model_len=8192,
        gpu_memory_utilization=0.7,
        max_num_seqs=max_num_seqs,
        tensor_parallel_size=tensor_parallel_size,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
        output = vllm_outputs[0][1]

        assert "1024" in output or "0, 1" in output
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Temporarily disabled due to timeout")
@pytest.mark.skipif(
    not current_platform.is_tpu(), reason="This is a basic test for TPU only"
)
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("max_num_seqs", [16])
def test_phi3(
    vllm_runner: type[VllmRunner],
    max_tokens: int,
    max_num_seqs: int,
) -> None:
    """Greedy-decode three prompts with Phi-3 and check each continuation."""
    prompts = [
        "A robot may not injure a human being",
        "It is only with the heart that one can see rightly;",
        "The greatest glory in living lies not in never falling,",
    ]
    # Expected continuations, index-aligned with `prompts`.
    answers = [
        " or, by violating privacy",
        " what is essential is love.",
        " but in rising every time we fall.",
    ]
    # test head dim = 96
    model = "microsoft/Phi-3-mini-128k-instruct"

    with vllm_runner(
        model, max_num_batched_tokens=256, max_num_seqs=max_num_seqs
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
        # vllm_outputs is a list of tuples whose first element is the token id
        # and the second element is the output (including the prompt).
        for (_, generated_text), answer in zip(vllm_outputs, answers):
            assert answer in generated_text
|
||||
|
||||
|
||||
# Tensor-parallel size required by the Gemma-3 27B test below.
TP_SIZE_8 = 8


@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only")
@pytest.mark.skipif(
    tpu.num_available_chips() < TP_SIZE_8,
    reason=f"This test requires {TP_SIZE_8} TPU chips.",
)
def test_gemma3_27b_with_text_input_and_tp(
    vllm_runner: type[VllmRunner],
) -> None:
    """Greedy-decode Gemma-3 27B with TP=8 and check each continuation."""
    model = "google/gemma-3-27b-it"
    max_tokens = 16
    tensor_parallel_size = TP_SIZE_8
    max_num_seqs = 4
    prompts = [
        "A robot may not injure a human being",
        "It is only with the heart that one can see rightly;",
        "The greatest glory in living lies not in never falling,",
    ]
    # Expected continuations, index-aligned with `prompts`.
    answers = [
        " or, through inaction, allow a human being to come to harm.",
        " what is essential is invisible to the eye.",
        " but in rising every time we fall.",
    ]

    with vllm_runner(
        model,
        max_num_batched_tokens=256,
        max_num_seqs=max_num_seqs,
        tensor_parallel_size=tensor_parallel_size,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
        # vllm_outputs is a list of tuples whose first element is the token id
        # and the second element is the output (including the prompt).
        for output, answer in zip(vllm_outputs, answers):
            generated_text = output[1]
            assert answer in generated_text
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not current_platform.is_tpu(), reason="This is a basic test for TPU only"
)
def test_w8a8_quantization(
    vllm_runner: type[VllmRunner],
) -> None:
    """Smoke-test a w8a8-quantized Llama checkpoint with greedy decoding."""
    model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
    max_tokens = 5
    tensor_parallel_size = 1
    max_num_seqs = 4

    # A 1024-number counting prompt; the model should continue the sequence.
    counting = ", ".join(str(i) for i in range(1024))
    example_prompts = [f"The next numbers of the sequence {counting} are:"]

    runner_kwargs = dict(
        max_num_batched_tokens=64,
        max_model_len=4096,
        gpu_memory_utilization=0.7,
        max_num_seqs=max_num_seqs,
        tensor_parallel_size=tensor_parallel_size,
    )
    with vllm_runner(model, **runner_kwargs) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
        output = vllm_outputs[0][1]

        assert "1024" in output or "0, 1" in output
|
||||
@ -1,78 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch_xla
|
||||
|
||||
import vllm.v1.attention.backends.pallas # noqa: F401
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only")
@pytest.mark.parametrize("page_size", [32, 33])
@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("num_slices_per_block", [4, 8])
def test_kv_cache_update_kernel(
    page_size: int, combined_kv_head_num: int, head_dim: int, num_slices_per_block: int
):
    """Check the XLA kv_cache_update op against a CPU slice-copy reference."""
    page_num = 1000
    padded_num_tokens = 128
    # The CPU tensors serve as the reference; their XLA copies go through the
    # kernel under test.
    kv_cache_cpu = torch.zeros(
        (page_num * page_size, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu",
    )
    kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
    new_kv_cpu = torch.randn(
        (padded_num_tokens, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu",
    )
    new_kv_xla = new_kv_cpu.to(torch_xla.device())
    # Slice lengths cover partial pages, exactly-one-page, and single tokens.
    slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], dtype=np.int32)
    num_kv_update_slices = len(slice_lens)
    kv_cache_start_indices = np.array(
        [
            page_size * 2 - 7,
            page_size * 2,
            page_size * 3,
            page_size * 4 + 6,
            page_size * 5 + 7,
            page_size * 6 + 8,
            page_size * 15 + 3,
        ],
        dtype=np.int32,
    )
    # Source offsets are the exclusive running sum of the slice lengths.
    new_kv_cache_indices = np.concatenate(
        [np.array([0], dtype=np.int32), np.cumsum(slice_lens[:-1])]
    )
    # slot_mapping columns (after transpose): one slice per column of
    # [kv_cache_start, new_kv_start, slice_len].
    slot_mapping = np.stack(
        [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1
    )
    slot_mapping = np.transpose(slot_mapping)
    slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu", dtype=torch.int32)
    slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
    num_kv_update_slices_xla = torch.tensor(
        [num_kv_update_slices], device=torch_xla.device(), dtype=torch.int32
    )
    torch_xla.sync()

    # Mark the cache as a donated buffer so the op may reuse its storage.
    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
    new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
        new_kv_xla,
        slot_mapping_xla,
        kv_cache_xla,
        num_kv_update_slices_xla,
        page_size,
        num_slices_per_block,
    )
    kv_cache_xla.copy_(new_kv_cache_xla)
    torch_xla.sync()

    # Reference: apply the same slice copies directly on the CPU tensors.
    for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, slice_lens):
        kv_cache_cpu[ci : ci + sl, :, :] = new_kv_cpu[ni : ni + sl, :, :]

    assert torch.allclose(kv_cache_xla.cpu(), kv_cache_cpu, atol=1e-4, rtol=1e-4)
|
||||
@ -1,94 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test:
|
||||
|
||||
* Tests for MMEncoderAttention layer
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch_xla
|
||||
import torch_xla.core
|
||||
import torch_xla.core.xla_model
|
||||
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.attention.selector import _cached_get_attn_backend
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clear_cache():
    """Clear lru cache to ensure each test case runs without caching."""
    # The attention-backend selector memoizes its result; reset it so each
    # parametrized case re-selects the backend from scratch.
    _cached_get_attn_backend.cache_clear()
|
||||
|
||||
|
||||
def ref_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """
    Native implementation of scaled dot product attention without mask:
    - query, key, value: [batch_size, seq_len, num_heads, head_size]
    - returns a tensor in the same [batch_size, seq_len, num_heads, head_size]
      layout as the inputs.
    """
    # Put heads in front of the sequence dimension: [b, h, s, d].
    q_t = query.transpose(1, 2)
    k_t = key.transpose(1, 2)
    v_t = value.transpose(1, 2)
    # Scaled dot-product scores over all key positions: [b, h, s, s].
    scores = scale * torch.matmul(q_t, k_t.transpose(2, 3))
    probs = torch.softmax(scores, dim=-1).to(v_t.dtype)
    # Weighted sum of values, then restore the [b, s, h, d] layout.
    return torch.matmul(probs, v_t).transpose(1, 2)
|
||||
|
||||
|
||||
BATCH_SIZES = [1, 16]
|
||||
SEQ_LENS = [1]
|
||||
NUM_HEADS = [1, 16]
|
||||
NUM_KV_HEADS = [1]
|
||||
HEAD_SIZES = [64, 80]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("device", [torch_xla.core.xla_model.xla_device()])
def test_mha_attn_forward(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    device: str,
):
    """Compare MMEncoderAttention's forward pass against the native reference."""
    current_platform.seed_everything(0)
    # These are expected to be f32
    q = torch.randn(batch_size, seq_len, num_heads * head_size, device=device)
    k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
    v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
    scale = 1.0 / head_size**0.5
    attn = MMEncoderAttention(
        num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
    )
    output = attn(q, k, v)

    assert num_heads % num_kv_heads == 0
    group_size = num_heads // num_kv_heads

    # Expand to [batch, seq, heads, head_size] and replicate KV heads so the
    # reference sees one KV head per query head (GQA -> MHA).
    q = q.reshape(batch_size, seq_len, num_heads, head_size)
    k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
    v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
    if group_size > 1:
        k = torch.repeat_interleave(k, group_size, dim=2)
        v = torch.repeat_interleave(v, group_size, dim=2)

    ref_output = ref_attention(q, k, v, scale=scale)
    ref_output = ref_output.reshape(batch_size, seq_len, num_heads * head_size)
    # torch_xla flash_attn kernel is less accurate but much faster
    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-3)
|
||||
@ -1,76 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal.utils import encode_image_url
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...entrypoints.openai.test_vision import TEST_IMAGE_ASSETS
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def url_encoded_image(local_asset_server) -> dict[str, str]:
    """Map each test image asset name to its URL-encoded form.

    Values come from ``encode_image_url`` — presumably base64 data URLs;
    verify against the helper if the format matters to a new test.
    """
    return {
        image_asset: encode_image_url(local_asset_server.get_image_asset(image_asset))
        for image_asset in TEST_IMAGE_ASSETS
    }
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"])
async def test_basic_vision(model_name: str, url_encoded_image: dict[str, str]):
    """Serve a vision model on TPU and run one chat completion per test image."""
    pytest.skip("Skip this test until it's fixed.")

    def whats_in_this_image_msg(url):
        # Minimal OpenAI-style multimodal chat message referencing `url`.
        return [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's in this image?"},
                    {"type": "image_url", "image_url": {"url": url}},
                ],
            }
        ]

    server_args = [
        "--max-model-len",
        "1024",
        "--max-num-seqs",
        "16",
        "--gpu-memory-utilization",
        "0.95",
        "--trust-remote-code",
        "--max-num-batched-tokens",
        "576",
        # NOTE: max-num-batched-tokens>=mm_item_size
        "--disable_chunked_mm_input",
    ]

    # Server will pre-compile on first startup (takes a long time).
    with RemoteOpenAIServer(
        model_name, server_args, max_wait_seconds=600
    ) as remote_server:
        client: openai.AsyncOpenAI = remote_server.get_async_client()

        # Other requests now should be much faster
        for image_asset in TEST_IMAGE_ASSETS:
            image_url = url_encoded_image[image_asset]
            result = await client.chat.completions.create(
                model=model_name,
                messages=whats_in_this_image_msg(image_url),
                max_completion_tokens=24,
                temperature=0.0,
            )
            assert result
            choice = result.choices[0]
            assert choice.finish_reason == "length"

            # Fix: the original assigned `message` twice in a row
            # (choice.message and result.choices[0].message are the same
            # object); keep a single binding.
            message = choice.message
            assert message.content is not None and len(message.content) >= 10
            assert message.role == "assistant"
|
||||
@ -1,100 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.v1.attention.backends.pallas import PallasAttentionBackendImpl, PallasMetadata
|
||||
|
||||
|
||||
def test_ragged_paged_attention():
|
||||
# We verify that the kernel inputs such as sliding_window, etc. are passed
|
||||
# in from the model correctly.
|
||||
# The correctness of the paged attention kernel is tested in the kernel
|
||||
# library.
|
||||
num_heads = 4
|
||||
head_size = 128
|
||||
scale = 1.0
|
||||
num_kv_heads = 4
|
||||
sliding_window = 128
|
||||
logits_soft_cap = 50.0
|
||||
attn_impl = PallasAttentionBackendImpl(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
alibi_slopes=None,
|
||||
sliding_window=sliding_window,
|
||||
kv_cache_dtype="auto",
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
attn_type=AttentionType.DECODER,
|
||||
)
|
||||
|
||||
class FakeAttentionLayer:
|
||||
_q_scale_float: float
|
||||
_k_scale_float: float
|
||||
_v_scale_float: float
|
||||
|
||||
layer = FakeAttentionLayer()
|
||||
layer._q_scale_float = 1.0
|
||||
layer._k_scale_float = 1.0
|
||||
layer._v_scale_float = 1.0
|
||||
|
||||
num_tokens = 16
|
||||
num_blocks = 1024
|
||||
block_size = 16
|
||||
query = torch.zeros(num_tokens, num_heads * head_size)
|
||||
key = torch.zeros(num_tokens, num_kv_heads * head_size)
|
||||
value = torch.zeros(num_tokens, num_kv_heads * head_size)
|
||||
kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
|
||||
slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64)
|
||||
max_num_reqs = 8
|
||||
max_num_blocks_per_req = 8
|
||||
num_kv_update_slices = torch.tensor([num_tokens], dtype=torch.int32)
|
||||
block_tables = torch.zeros(
|
||||
(max_num_reqs, max_num_blocks_per_req), dtype=torch.int32
|
||||
)
|
||||
context_lens = torch.ones((max_num_reqs,), dtype=torch.int32)
|
||||
query_lens = [1] * max_num_reqs
|
||||
query_start_loc = torch.cumsum(
|
||||
torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32
|
||||
)
|
||||
num_seqs = torch.tensor([max_num_reqs], dtype=torch.int32)
|
||||
attn_metadata = PallasMetadata(
|
||||
slot_mapping=slot_mapping,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
query_start_loc=query_start_loc,
|
||||
num_seqs=num_seqs,
|
||||
num_kv_update_slices=num_kv_update_slices,
|
||||
num_slices_per_kv_cache_update_block=8,
|
||||
)
|
||||
|
||||
with patch("torch.ops.xla.ragged_paged_attention") as mock_ragged_paged_attention:
|
||||
attn_impl.forward(
|
||||
layer=layer,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
mock_ragged_paged_attention.assert_called_once_with(
|
||||
ANY, # query
|
||||
ANY, # kv_cache
|
||||
ANY, # context_lens
|
||||
ANY, # block_tables
|
||||
ANY, # query_start_loc
|
||||
ANY, # num_seqs
|
||||
num_kv_pages_per_block=None,
|
||||
num_queries_per_block=None,
|
||||
vmem_limit_bytes=None,
|
||||
use_kernel=True,
|
||||
sm_scale=scale,
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=logits_soft_cap,
|
||||
k_scale=1.0,
|
||||
v_scale=1.0,
|
||||
)
|
||||
@ -1,150 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A basic performance regression test for TPUs
|
||||
|
||||
Run `pytest tests/v1/tpu/test_perf.py`.
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tests.conftest import VllmRunner
|
||||
else:
|
||||
VllmRunner = object
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestParams:
|
||||
model: str
|
||||
num_prompts: int
|
||||
prefix_len: int
|
||||
decode_len: int
|
||||
expected_avg_time: float
|
||||
err_tol: float
|
||||
|
||||
|
||||
TEST_PARAMS = [
|
||||
# TODO: Cannot run a series of tests because:
|
||||
# RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed:
|
||||
# open(/dev/vfio/0): Device or resource busy: Device or resource busy;
|
||||
# Couldn't open iommu group /dev/vfio/0
|
||||
# => Investigate
|
||||
# TestParams(
|
||||
# model="Qwen/Qwen2.5-1.5B-Instruct",
|
||||
# num_prompts=1,
|
||||
# prefix_len=10,
|
||||
# decode_len=5,
|
||||
# expected_avg_time=0.03,
|
||||
# err_tol=0.01,
|
||||
# ),
|
||||
# TestParams(
|
||||
# model="Qwen/Qwen2.5-1.5B-Instruct",
|
||||
# num_prompts=10,
|
||||
# prefix_len=100,
|
||||
# decode_len=50,
|
||||
# expected_avg_time=0.234,
|
||||
# err_tol=0.020,
|
||||
# ),
|
||||
TestParams(
|
||||
model="Qwen/Qwen2.5-1.5B-Instruct",
|
||||
num_prompts=64,
|
||||
prefix_len=500,
|
||||
decode_len=50,
|
||||
# commit id: ccb246776d93ef105904a8ec015b3587240a1183
|
||||
# tpu: v5lite (old vllm CI/CD)
|
||||
# expected_avg_time=1.4,
|
||||
# err_tol=0.30,
|
||||
# (This is the active CI/CD instance)
|
||||
# commit id: ccb246776d93ef105904a8ec015b3587240a1183
|
||||
# tpu: v6e (current vllm CI/CD)
|
||||
expected_avg_time=1.7, # measured with VLLM_XLA_CACHE_PATH=
|
||||
err_tol=0.20,
|
||||
),
|
||||
]
|
||||
|
||||
NUM_WARMUPS = 5
|
||||
NUM_RUNS = 10
|
||||
|
||||
MAX_MODEL_LEN = 1024
|
||||
MAX_NUM_SEQS = 32
|
||||
GPU_UTIL = 0.9
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_tpu(),
|
||||
reason="This is a basic performance test for TPU only",
|
||||
)
|
||||
@pytest.mark.parametrize("params", TEST_PARAMS)
|
||||
def test_perf(
|
||||
vllm_runner: type[VllmRunner],
|
||||
params: TestParams,
|
||||
) -> None:
|
||||
tokenizer = get_tokenizer(
|
||||
params.model, tokenizer_mode="auto", trust_remote_code=True
|
||||
)
|
||||
|
||||
prompts = []
|
||||
for i in range(params.num_prompts):
|
||||
prefix_token_ids = np.random.randint(
|
||||
0, tokenizer.vocab_size, size=params.prefix_len
|
||||
).tolist()
|
||||
prompt = tokenizer.decode(prefix_token_ids)
|
||||
prompts.append(prompt)
|
||||
|
||||
print(
|
||||
"-- Running: num_prompts = {} prefix_len = {} decode_len = {}".format(
|
||||
len(prompts), params.prefix_len, params.decode_len
|
||||
)
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=params.decode_len, temperature=1.0, min_p=0.0
|
||||
)
|
||||
|
||||
with vllm_runner(
|
||||
params.model,
|
||||
max_num_batched_tokens=MAX_MODEL_LEN,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
gpu_memory_utilization=GPU_UTIL,
|
||||
enforce_eager=False,
|
||||
tensor_parallel_size=1,
|
||||
) as vllm_model:
|
||||
print(" -- Warmup / Compile")
|
||||
for i in range(NUM_WARMUPS):
|
||||
_ = vllm_model.generate(prompts, sampling_params)
|
||||
|
||||
print(" -- Benchmarking... ")
|
||||
times = []
|
||||
for i in range(NUM_RUNS):
|
||||
start_time = time.time()
|
||||
_ = vllm_model.generate(prompts, sampling_params)
|
||||
times.append(time.time() - start_time)
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
|
||||
print(" -- avg_time = {}".format(avg_time))
|
||||
print(
|
||||
" -- expected_avg_time = {} with err_tol = {}".format(
|
||||
params.expected_avg_time, params.err_tol
|
||||
)
|
||||
)
|
||||
diff = avg_time - params.expected_avg_time
|
||||
ok = diff < params.err_tol
|
||||
if diff < -params.err_tol:
|
||||
print(
|
||||
" !! WARNING !! Performance has improved by {}, "
|
||||
"it may be necessary to fine-tune the "
|
||||
"expected_avg_time = {}".format(-diff, params.expected_avg_time)
|
||||
)
|
||||
|
||||
assert ok, " !! ERROR !! Regression detected"
|
||||
@ -1,105 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
|
||||
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
|
||||
def test_sampler_different(model_name: str):
|
||||
"""
|
||||
Test significantly different sampling params to assert the model produces
|
||||
different results.
|
||||
"""
|
||||
llm = LLM(
|
||||
model_name,
|
||||
enforce_eager=False,
|
||||
max_num_seqs=1,
|
||||
max_model_len=512,
|
||||
max_num_batched_tokens=256,
|
||||
)
|
||||
prompts = ["Write a short story about a robot that dreams for the first time."]
|
||||
sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
|
||||
output = llm.generate(prompts, sampling_params)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
|
||||
output2 = llm.generate(prompts, sampling_params)
|
||||
assert output[0].outputs[0].text != output2[0].outputs[0].text
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# Unsupported `seed` param.
|
||||
sampling_params = SamplingParams(temperature=0.3, seed=42)
|
||||
output2 = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Batch-case with TopK/P
|
||||
for B in [4, 16]:
|
||||
p = prompts * B
|
||||
sampling_params = [
|
||||
SamplingParams(
|
||||
temperature=0.1,
|
||||
min_p=0.8,
|
||||
max_tokens=64,
|
||||
# Vary number of ks
|
||||
top_k=random.randint(4, 12),
|
||||
top_p=random.random(),
|
||||
)
|
||||
for _ in range(B)
|
||||
]
|
||||
# Make sure first two reqs have the same K/P
|
||||
sampling_params[0] = sampling_params[1]
|
||||
output = llm.generate(p, sampling_params)
|
||||
# There are natural numerical instabilities that make it difficult
|
||||
# to have deterministic results over many tokens, tests the first ~20
|
||||
# tokens match.
|
||||
assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
|
||||
# TODO TPU will appear busy if we fan-out test params here
|
||||
@pytest.mark.parametrize("n_prompts", [1])
|
||||
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
|
||||
def test_logprobs(model_name: str, n_prompts: int):
|
||||
"""
|
||||
Request top logprobs with different sampling settings and check
|
||||
that results contains the requested number, ordered ascendingly.
|
||||
"""
|
||||
|
||||
def check_num_logprobs(logprobs, expected_num: int):
|
||||
for step in logprobs:
|
||||
prev_logp = 1.0
|
||||
# order by rank
|
||||
sorted_step = dict(sorted(step.items(), key=lambda item: item[1].rank))
|
||||
|
||||
# Can contain the sampled token
|
||||
assert len(step) == expected_num or len(step) == expected_num + 1
|
||||
# Check results are ordered by prob value
|
||||
for rankno, (tid, logp) in enumerate(sorted_step.items()):
|
||||
assert logp.logprob <= prev_logp
|
||||
prev_logp = logp.logprob
|
||||
assert logp.rank == rankno + 1
|
||||
|
||||
llm = LLM(
|
||||
model_name,
|
||||
enforce_eager=False,
|
||||
max_num_seqs=1,
|
||||
max_model_len=128,
|
||||
max_num_batched_tokens=128,
|
||||
)
|
||||
prompts = [
|
||||
"Write a short story about a robot that dreams for the first time."
|
||||
] * n_prompts
|
||||
greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64, logprobs=4)
|
||||
regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64, logprobs=4)
|
||||
topkp_sampling_params = SamplingParams(
|
||||
temperature=0.4, max_tokens=64, logprobs=4, top_k=12, top_p=0.5
|
||||
)
|
||||
|
||||
for sp in [greedy_sampling_params, regular_sampling_params, topkp_sampling_params]:
|
||||
output = llm.generate(prompts, sp)
|
||||
for o in output:
|
||||
check_num_logprobs(o.outputs[0].logprobs, 4)
|
||||
@ -1,78 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch_xla.distributed.spmd as xs
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.model_loader.tpu import TPUModelLoader
|
||||
|
||||
|
||||
def _setup_environment(model):
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
init_distributed_environment(
|
||||
1,
|
||||
0,
|
||||
local_rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
backend="gloo",
|
||||
)
|
||||
# Under single worker mode, full model is init first and then
|
||||
# partitioned using GSPMD.
|
||||
ensure_model_parallel_initialized(1, 1)
|
||||
return vllm_config
|
||||
|
||||
|
||||
MESH = None
|
||||
|
||||
|
||||
def _get_spmd_mesh():
|
||||
global MESH
|
||||
if MESH is None:
|
||||
xr.use_spmd()
|
||||
num_devices = xr.global_runtime_device_count()
|
||||
mesh_shape = (num_devices, 1)
|
||||
device_ids = np.array(range(num_devices))
|
||||
MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
|
||||
return MESH
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"Qwen/Qwen2-1.5B-Instruct",
|
||||
# Skip large models due to CI runner disk space limitations
|
||||
# "meta-llama/Llama-3.1-8B-Instruct",
|
||||
# "meta-llama/Llama-3.1-70B-Instruct",
|
||||
],
|
||||
)
|
||||
def test_tpu_model_loader(model):
|
||||
# Skip the 70B test if there are less than 8 chips
|
||||
# TODO: Query using torch xla API, the query API is not working
|
||||
# with SPMD now. However, This test is running under SPMD mode.
|
||||
if "70B" in model and xr.global_runtime_device_count() < 8:
|
||||
pytest.skip(
|
||||
"Skipping 70B model if the TPU VM has less than 8 chips to \
|
||||
avoid OOM."
|
||||
)
|
||||
|
||||
vllm_config = _setup_environment(model)
|
||||
loader = TPUModelLoader(load_config=vllm_config.load_config)
|
||||
mesh = _get_spmd_mesh()
|
||||
model = loader.load_model(vllm_config, vllm_config.model_config, mesh)
|
||||
del model
|
||||
gc.collect()
|
||||
@ -1,149 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch_xla
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.sample.tpu.sampler import apply_top_k_top_p as apply_top_k_top_p_tpu
|
||||
|
||||
if not current_platform.is_tpu():
|
||||
pytest.skip("This test needs a TPU.", allow_module_level=True)
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
BATCH_SIZE = 1024
|
||||
VOCAB_SIZE = 128 * 1024
|
||||
TOLERANCE = 1e-6
|
||||
|
||||
|
||||
def test_topk_equivalence_to_native_impl():
|
||||
with torch.device(xm.xla_device()):
|
||||
xm.set_rng_state(seed=33)
|
||||
|
||||
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
|
||||
|
||||
# Random top-k values between 1 and 10.
|
||||
k = torch.randint(1, 10, (BATCH_SIZE,))
|
||||
|
||||
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
|
||||
k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), VOCAB_SIZE)
|
||||
|
||||
result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
|
||||
|
||||
result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
|
||||
assert torch.allclose(result_native, result_tpu)
|
||||
|
||||
|
||||
def test_topp_result_sums_past_p():
|
||||
with torch.device(xm.xla_device()):
|
||||
xm.set_rng_state(seed=33)
|
||||
|
||||
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
|
||||
probs = logits.softmax(dim=-1)
|
||||
|
||||
# Random top-p values between 0 and 1.
|
||||
p = torch.rand((BATCH_SIZE,))
|
||||
|
||||
# Set p=1 for ~50% of requests in the batch (top-p disabled).
|
||||
p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), 1)
|
||||
|
||||
no_op_k = torch.tensor([VOCAB_SIZE])
|
||||
logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(), k=no_op_k, p=p)
|
||||
|
||||
# Verify that the masked logit's probability sums to at least p.
|
||||
probs.masked_fill_(logits_masked.isinf(), 0)
|
||||
masked_prob_sum = probs.sum(dim=-1)
|
||||
|
||||
torch_xla.sync()
|
||||
|
||||
# Perform assertion on CPU.
|
||||
assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
|
||||
|
||||
|
||||
def test_topp_basic():
|
||||
with torch.device(xm.xla_device()):
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[math.log(0.2), math.log(0.3), math.log(0.5)],
|
||||
[math.log(0.5), math.log(0.1), math.log(0.4)],
|
||||
]
|
||||
)
|
||||
|
||||
result = apply_top_k_top_p_tpu(
|
||||
logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([0.79, 0.79])
|
||||
)
|
||||
|
||||
torch_xla.sync()
|
||||
|
||||
# Expect the smallest elements to be dropped.
|
||||
expected_result = logits.clone().cpu()
|
||||
expected_result[0, 0] = float("-inf")
|
||||
expected_result[1, 1] = float("-inf")
|
||||
assert torch.allclose(expected_result, result.cpu())
|
||||
|
||||
|
||||
def test_topp_select_all():
|
||||
with torch.device(xm.xla_device()):
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[math.log(0.2), math.log(0.3), math.log(0.5)],
|
||||
[math.log(0.5), math.log(0.1), math.log(0.4)],
|
||||
]
|
||||
)
|
||||
|
||||
result = apply_top_k_top_p_tpu(
|
||||
logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([1.0, 1.0])
|
||||
)
|
||||
|
||||
torch_xla.sync()
|
||||
|
||||
assert torch.allclose(logits.cpu(), result.cpu())
|
||||
|
||||
|
||||
def test_topp_with_ties():
|
||||
with torch.device(xm.xla_device()):
|
||||
# Input has multiple math.log(0.3).
|
||||
logits = torch.tensor(
|
||||
[[math.log(0.3), math.log(0.3), math.log(0.3), math.log(0.1)]]
|
||||
)
|
||||
|
||||
result = apply_top_k_top_p_tpu(
|
||||
logits=logits.clone(), k=torch.tensor([4]), p=torch.tensor([0.2])
|
||||
)
|
||||
|
||||
torch_xla.sync()
|
||||
|
||||
# All tie values are included in the top-p set. Tie breaking is left
|
||||
# to be done during final sampling (all tie tokens have equal
|
||||
# probability of being chosen).
|
||||
expected_result = logits.clone().cpu()
|
||||
expected_result[0, 3] = float("-inf")
|
||||
assert torch.allclose(expected_result, result.cpu())
|
||||
|
||||
|
||||
def test_both_topk_topp():
|
||||
with torch.device(xm.xla_device()):
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[math.log(0.2), math.log(0.3), math.log(0.5)],
|
||||
[math.log(0.5), math.log(0.1), math.log(0.4)],
|
||||
]
|
||||
)
|
||||
|
||||
# Set k=1 for the first batch.
|
||||
result = apply_top_k_top_p_tpu(
|
||||
logits=logits.clone(), k=torch.tensor([1, 3]), p=torch.tensor([0.79, 0.79])
|
||||
)
|
||||
|
||||
torch_xla.sync()
|
||||
|
||||
# Since for the first batch k=1, expect only the largest element gets
|
||||
# selected.
|
||||
expected_result = logits.clone().cpu()
|
||||
expected_result[0, 0] = float("-inf")
|
||||
expected_result[0, 1] = float("-inf")
|
||||
expected_result[1, 1] = float("-inf")
|
||||
assert torch.allclose(expected_result, result.cpu())
|
||||
@ -1,78 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests whether TPU Int8 computation is enabled correctly.
|
||||
|
||||
Run `pytest tests/quantization/test_tpu_int8.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.tpu_int8 import TPUInt8LinearMethod
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...models.registry import HF_EXAMPLE_MODELS
|
||||
|
||||
MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_tpu(), reason="TPU Int8 is only enabled for TPUs."
|
||||
)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [10])
|
||||
@pytest.mark.parametrize(
|
||||
"hf_overrides",
|
||||
[
|
||||
# w8a8 dynamic activation
|
||||
{
|
||||
"quantization_config": {
|
||||
"quant_method": "tpu_int8",
|
||||
"activation_scheme": "dynamic",
|
||||
}
|
||||
}
|
||||
],
|
||||
)
|
||||
def test_model_tpu_int8(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
hf_overrides: dict,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
activation_scheme = hf_overrides.get("quantization_config", {}).get(
|
||||
"activation_scheme"
|
||||
)
|
||||
quantize_activation = activation_scheme == "dynamic"
|
||||
|
||||
# Allows using apply_model
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
# Prevent error from re-initializing cache
|
||||
monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")
|
||||
|
||||
prompts = [
|
||||
"A robot may not injure a human being",
|
||||
]
|
||||
answers = [
|
||||
"or kill a human being",
|
||||
]
|
||||
|
||||
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
|
||||
|
||||
def check_model(model):
|
||||
for name, module in model.named_modules():
|
||||
if not isinstance(module, LinearBase):
|
||||
continue
|
||||
quant_method = module.quant_method
|
||||
assert isinstance(quant_method, TPUInt8LinearMethod)
|
||||
assert quant_method.quantize_activation == quantize_activation
|
||||
|
||||
vllm.apply_model(check_model)
|
||||
outputs = vllm.generate_greedy(prompts, max_tokens)
|
||||
for (_, output), answer in zip(outputs, answers):
|
||||
assert answer in output
|
||||
@ -1,93 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch_xla.distributed.spmd as xs
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
)
|
||||
from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_environment():
|
||||
# This is a fake config used for init dist env.
|
||||
# QKVParallelLinear needs dist env to be initialized.
|
||||
engine_args = EngineArgs(
|
||||
model="Qwen/Qwen2-1.5B-Instruct",
|
||||
max_model_len=64,
|
||||
max_num_batched_tokens=64,
|
||||
max_num_seqs=4,
|
||||
)
|
||||
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
init_distributed_environment(
|
||||
1,
|
||||
0,
|
||||
local_rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
backend="gloo",
|
||||
)
|
||||
ensure_model_parallel_initialized(1, 1)
|
||||
yield
|
||||
|
||||
|
||||
MESH = None
|
||||
|
||||
|
||||
def _get_spmd_mesh():
|
||||
global MESH
|
||||
if MESH is None:
|
||||
xr.use_spmd()
|
||||
num_devices = xr.global_runtime_device_count()
|
||||
mesh_shape = (num_devices, 1)
|
||||
device_ids = np.array(range(num_devices))
|
||||
MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
|
||||
return MESH
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bias", [False, True])
|
||||
# `xr.use_spmd()` will set a global state, and this state is not reversible.
|
||||
# Therefore, non-SPMD tests should be run before SPMD tests.
|
||||
@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()])
|
||||
@pytest.mark.parametrize("device", ["cpu", "xla"])
|
||||
@torch.no_grad()
|
||||
def test_xla_qkv_linear(bias, mesh, device):
|
||||
torch.manual_seed(123)
|
||||
|
||||
qkv_linear = QKVParallelLinear(
|
||||
hidden_size=4096,
|
||||
head_size=128,
|
||||
total_num_heads=32,
|
||||
total_num_kv_heads=8,
|
||||
bias=bias,
|
||||
params_dtype=torch.bfloat16,
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
qkv_linear.weight.data = torch.rand_like(qkv_linear.weight.data) / 10
|
||||
if bias:
|
||||
qkv_linear.bias.data = torch.rand_like(qkv_linear.bias.data)
|
||||
|
||||
xla_qkv_linear = XlaQKVParallelLinear(qkv_linear, mesh=mesh)
|
||||
|
||||
qkv_linear = qkv_linear.to(device)
|
||||
xla_qkv_linear = xla_qkv_linear.to(device)
|
||||
input_tensor = torch.rand(10, 4096, dtype=torch.bfloat16) / 10
|
||||
input_tensor = input_tensor.to(device)
|
||||
|
||||
output = qkv_linear(input_tensor)
|
||||
xla_output = xla_qkv_linear(input_tensor)
|
||||
assert torch.allclose(output.cpu(), xla_output.cpu())
|
||||
@ -1,587 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
ModelConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
|
||||
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
|
||||
from vllm.v1.worker.tpu_model_runner import (
|
||||
TPUModelRunner,
|
||||
_get_padded_num_reqs_with_upper_limit,
|
||||
_get_padded_token_len,
|
||||
_get_req_paddings,
|
||||
_get_token_paddings,
|
||||
)
|
||||
|
||||
|
||||
def get_vllm_config():
|
||||
model_config = ModelConfig(
|
||||
model="facebook/opt-125m",
|
||||
dtype="bfloat16", # TPUs typically use bfloat16
|
||||
seed=42,
|
||||
)
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=10,
|
||||
max_num_batched_tokens=512,
|
||||
max_model_len=512,
|
||||
is_encoder_decoder=model_config.is_encoder_decoder,
|
||||
)
|
||||
cache_config = CacheConfig(
|
||||
block_size=16,
|
||||
gpu_memory_utilization=0.9,
|
||||
swap_space=0,
|
||||
cache_dtype="auto",
|
||||
)
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
scheduler_config=scheduler_config,
|
||||
)
|
||||
return vllm_config
|
||||
|
||||
|
||||
def get_model_runner(vllm_config):
|
||||
device = "xla:0" # Mocking TPU device
|
||||
return TPUModelRunner(vllm_config, device)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_runner():
|
||||
# Patchers have already been started at module level.
|
||||
vllm_config = get_vllm_config()
|
||||
return get_model_runner(vllm_config)
|
||||
|
||||
|
||||
def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
||||
new_reqs = []
|
||||
num_scheduled_tokens = {}
|
||||
total_num_scheduled_tokens = 0
|
||||
for req_id in req_ids:
|
||||
new_reqs.append(
|
||||
NewRequestData(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=[1, 2, 3],
|
||||
mm_features=[],
|
||||
sampling_params=SamplingParams(),
|
||||
pooling_params=PoolingParams(),
|
||||
block_ids=([0],), # block_ids should be tuple[list[int]]
|
||||
num_computed_tokens=0,
|
||||
lora_request=None,
|
||||
)
|
||||
)
|
||||
num_scheduled_tokens[req_id] = 3
|
||||
total_num_scheduled_tokens += num_scheduled_tokens[req_id]
|
||||
|
||||
return SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs,
|
||||
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
)
|
||||
|
||||
|
||||
def _is_req_scheduled(model_runner, req_id: str) -> bool:
|
||||
return req_id in model_runner.input_batch.req_id_to_index
|
||||
|
||||
|
||||
def _is_req_added(model_runner, req_id: str) -> bool:
|
||||
return req_id in model_runner.requests
|
||||
|
||||
|
||||
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
|
||||
"""Check if the request state block IDs match the block table.
|
||||
|
||||
This function handles both legacy BlockTable and new MultiGroupBlockTable
|
||||
structures for backward compatibility.
|
||||
"""
|
||||
|
||||
req_index = model_runner.input_batch.req_id_to_index[req_id]
|
||||
multi_group_block_table = model_runner.input_batch.block_table
|
||||
req_state = model_runner.requests[req_id]
|
||||
|
||||
# Access the first block table from MultiGroupBlockTable
|
||||
# This is safe since we currently only use single KV cache groups
|
||||
block_table = multi_group_block_table[0]
|
||||
|
||||
# req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable
|
||||
# Extract the first group's block IDs
|
||||
if isinstance(req_state.block_ids[0], list):
|
||||
# New format: tuple[list[int], ...] - extract first group
|
||||
req_block_ids = req_state.block_ids[0]
|
||||
else:
|
||||
# Legacy format: list[int] - use directly
|
||||
req_block_ids = req_state.block_ids
|
||||
|
||||
if block_table.num_blocks_per_row[req_index] != len(req_block_ids):
|
||||
return False
|
||||
|
||||
num_blocks = block_table.num_blocks_per_row[req_index]
|
||||
block_table_values = block_table.block_table.np[req_index, :num_blocks]
|
||||
return (block_table_values == req_block_ids).all()
|
||||
|
||||
|
||||
def test_update_states_new_request(model_runner):
|
||||
req_id = "req_0"
|
||||
|
||||
# new req
|
||||
scheduler_output = _schedule_new_request(req_id)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
||||
|
||||
|
||||
def test_update_states_request_finished(model_runner):
|
||||
req_id = "req_0"
|
||||
|
||||
# new req
|
||||
scheduler_output = _schedule_new_request(req_id)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
# finish req
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
||||
num_scheduled_tokens={},
|
||||
total_num_scheduled_tokens=0,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids={req_id},
|
||||
free_encoder_mm_hashes=[],
|
||||
)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert not _is_req_added(model_runner, req_id)
|
||||
assert not _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
|
||||
def test_update_states_request_resumed(model_runner):
|
||||
req_id = "req_0"
|
||||
|
||||
# new req
|
||||
scheduler_output = _schedule_new_request(req_id)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
# unschedule req
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
||||
num_scheduled_tokens={},
|
||||
total_num_scheduled_tokens=0,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert not _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
# resume req
|
||||
cached_req_data = CachedRequestData(
|
||||
req_ids=[req_id],
|
||||
resumed_req_ids={req_id},
|
||||
new_token_ids=[[]],
|
||||
all_token_ids={req_id: scheduler_output.scheduled_new_reqs[0].prompt_token_ids},
|
||||
new_block_ids=[([],)],
|
||||
num_computed_tokens=[0],
|
||||
num_output_tokens=[0],
|
||||
)
|
||||
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=cached_req_data,
|
||||
num_scheduled_tokens={req_id: 1},
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
||||
|
||||
|
||||
def test_update_states_no_changes(model_runner):
|
||||
req_id = "req_0"
|
||||
|
||||
# new req
|
||||
scheduler_output = _schedule_new_request(req_id)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
# schedule req
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
||||
num_scheduled_tokens={req_id: 1},
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
||||
|
||||
|
||||
def test_update_states_request_unscheduled(model_runner):
|
||||
req_ids = ("req_0", "req_1")
|
||||
|
||||
# new reqs
|
||||
scheduler_output = _schedule_new_request(*req_ids)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
|
||||
assert _is_req_added(model_runner, req_ids[0])
|
||||
assert _is_req_scheduled(model_runner, req_ids[0])
|
||||
|
||||
assert _is_req_added(model_runner, req_ids[1])
|
||||
assert _is_req_scheduled(model_runner, req_ids[1])
|
||||
|
||||
# unschedule req_1
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
||||
num_scheduled_tokens={req_ids[0]: 1},
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
|
||||
assert _is_req_added(model_runner, req_ids[0])
|
||||
assert _is_req_scheduled(model_runner, req_ids[0])
|
||||
|
||||
assert _is_req_added(model_runner, req_ids[1])
|
||||
assert not _is_req_scheduled(model_runner, req_ids[1])
|
||||
|
||||
|
||||
def test_get_paddings():
|
||||
# Bucketed padding
|
||||
min_token_size, max_token_size, padding_gap = 16, 512, 64
|
||||
expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
|
||||
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
|
||||
|
||||
# Bucketed padding with max_token_size not a power of two.
|
||||
max_token_size = 317
|
||||
expected_paddings = [16, 32, 64, 128, 192, 256, 320]
|
||||
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
|
||||
assert actual_paddings == expected_paddings
|
||||
|
||||
# Exponential padding.
|
||||
max_token_size, padding_gap = 1024, 0
|
||||
expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
|
||||
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
|
||||
assert actual_paddings == expected_paddings
|
||||
# Exponential padding with max_token_size not a power of two.
|
||||
max_token_size = 317
|
||||
expected_paddings = [16, 32, 64, 128, 256, 512]
|
||||
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
|
||||
assert actual_paddings == expected_paddings
|
||||
|
||||
|
||||
def test_get_padded_token_len():
|
||||
min_token_size, max_token_size, padding_gap = 16, 512, 64
|
||||
paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
|
||||
assert _get_padded_token_len(paddings, 1) == 16
|
||||
assert _get_padded_token_len(paddings, 16) == 16
|
||||
assert _get_padded_token_len(paddings, 20) == 32
|
||||
assert _get_padded_token_len(paddings, 300) == 320
|
||||
assert _get_padded_token_len(paddings, 512) == 512
|
||||
|
||||
|
||||
def test_get_padded_num_reqs_with_upper_limit():
|
||||
assert _get_padded_num_reqs_with_upper_limit(3, 32) == 8
|
||||
assert _get_padded_num_reqs_with_upper_limit(9, 32) == 16
|
||||
assert _get_padded_num_reqs_with_upper_limit(19, 32) == 32
|
||||
assert _get_padded_num_reqs_with_upper_limit(17, 28) == 28
|
||||
|
||||
|
||||
def test_get_req_paddings():
|
||||
assert _get_req_paddings(1, 32) == [8, 16, 32]
|
||||
assert _get_req_paddings(8, 32) == [8, 16, 32]
|
||||
assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
|
||||
|
||||
|
||||
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(model_runner):
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
error_msg = f"{layer_1} must come before the current layer"
|
||||
vllm_config = model_runner.vllm_config
|
||||
with (
|
||||
pytest.raises(ValueError, match=error_msg),
|
||||
set_current_vllm_config(vllm_config),
|
||||
):
|
||||
fwd_context = {
|
||||
# initialization below will fail because target layer is invalid;
|
||||
# the target layer needs to come before layer 1
|
||||
layer_0: Attention(
|
||||
num_heads=8,
|
||||
head_size=128,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
kv_sharing_target_layer_name=layer_1,
|
||||
),
|
||||
layer_1: Attention(
|
||||
num_heads=8,
|
||||
head_size=128,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
),
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
|
||||
|
||||
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner):
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
invalid_layer = "model.layers.0.cross_attn.attn"
|
||||
error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
|
||||
vllm_config = model_runner.vllm_config
|
||||
with (
|
||||
pytest.raises(ValueError, match=error_msg),
|
||||
set_current_vllm_config(vllm_config),
|
||||
):
|
||||
fwd_context = {
|
||||
layer_0: Attention(
|
||||
num_heads=8,
|
||||
head_size=128,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
),
|
||||
layer_1: Attention(
|
||||
num_heads=8,
|
||||
head_size=128,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
# invalid layer: cross_attn.atn doesn't exist!
|
||||
kv_sharing_target_layer_name=invalid_layer,
|
||||
),
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
|
||||
|
||||
def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner):
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
error_msg = f"{layer_1} cannot be the same as the current layer"
|
||||
vllm_config = model_runner.vllm_config
|
||||
with (
|
||||
pytest.raises(ValueError, match=error_msg),
|
||||
set_current_vllm_config(vllm_config),
|
||||
):
|
||||
fwd_context = {
|
||||
# initialization below will fail because target layer is invalid;
|
||||
# the target layer needs to come before layer 1
|
||||
layer_0: Attention(
|
||||
num_heads=8,
|
||||
head_size=128,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
),
|
||||
layer_1: Attention(
|
||||
num_heads=8,
|
||||
head_size=128,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
kv_sharing_target_layer_name=layer_1,
|
||||
),
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
|
||||
|
||||
def test_init_kv_cache_without_kv_sharing():
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
vllm_config = get_vllm_config()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
fwd_context = {
|
||||
layer_0: Attention(
|
||||
num_heads=8,
|
||||
head_size=128,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
),
|
||||
layer_1: Attention(
|
||||
num_heads=8,
|
||||
head_size=128,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
),
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
# Set high context length to test max context length estimation
|
||||
vllm_config.model_config.max_model_len = 1_000_000
|
||||
vllm_ctx = vllm_config.compilation_config.static_forward_context
|
||||
model_runner = get_model_runner(vllm_config)
|
||||
kv_cache_spec = model_runner.get_kv_cache_spec()
|
||||
assert len(kv_cache_spec) == 2
|
||||
assert len(model_runner.shared_kv_cache_layers) == 0
|
||||
|
||||
available_memory = 20 * GiB_bytes
|
||||
# page size for each layer KV can be calculated as
|
||||
# 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
|
||||
# * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
|
||||
num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers)
|
||||
kv_cache_config = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_spec], [available_memory]
|
||||
)[0]
|
||||
assert kv_cache_config.num_blocks == num_expected_blocks
|
||||
assert len(kv_cache_config.kv_cache_tensors) == 2
|
||||
assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
|
||||
assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2
|
||||
|
||||
max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
|
||||
# max context len with KV sharing should be 2x as large as without
|
||||
# max_context_len = available_memory / (page_size / block_size) / num_caches
|
||||
# max_context_len = 5GB / (512KB / 128) / 2 = 655360
|
||||
assert max_context_len == 655360
|
||||
|
||||
# important: override tensor size to prevent large mem alloc during test
|
||||
# this will only allocate 2 block worth of memory (2 * 512kb)
|
||||
kv_cache_config.num_blocks = 1
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
kv_cache_tensor.size = kv_cache_spec[
|
||||
kv_cache_tensor.shared_by[0]
|
||||
].page_size_bytes
|
||||
|
||||
model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
|
||||
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
|
||||
# check layer 1 kv cache does NOT share memory with layer 0
|
||||
assert id(layer_1_kv) != id(layer_0_kv)
|
||||
|
||||
# check layer 1 added to kv cache group's layer names
|
||||
assert len(kv_cache_config.kv_cache_groups) == 1
|
||||
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
|
||||
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
|
||||
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
|
||||
|
||||
|
||||
def test_init_kv_cache_with_kv_sharing_valid():
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
vllm_config = get_vllm_config()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
fwd_context = {
|
||||
layer_0: Attention(
|
||||
num_heads=8,
|
||||
head_size=128,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
),
|
||||
layer_1: Attention(
|
||||
num_heads=8,
|
||||
head_size=128,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
|
||||
),
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
# Set high context length to test max context length estimation
|
||||
vllm_config.model_config.max_model_len = 3_000_000
|
||||
vllm_ctx = vllm_config.compilation_config.static_forward_context
|
||||
model_runner = get_model_runner(vllm_config)
|
||||
kv_cache_spec = model_runner.get_kv_cache_spec()
|
||||
assert len(kv_cache_spec) == 1
|
||||
assert layer_0 in kv_cache_spec
|
||||
assert model_runner.shared_kv_cache_layers[layer_1] == layer_0
|
||||
|
||||
available_memory = 20 * GiB_bytes
|
||||
# page size for layer 0's kv_cache_spec is 512KB
|
||||
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
|
||||
# which is twice as many as without KV sharing
|
||||
num_expected_blocks = 2 * 20480 # 20GB / 512KB
|
||||
kv_cache_config = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_spec], [available_memory]
|
||||
)[0]
|
||||
assert kv_cache_config.num_blocks == num_expected_blocks
|
||||
assert len(kv_cache_config.kv_cache_tensors) == 1
|
||||
# Each layer now has twice the available memory for KV cache
|
||||
# compared to no KV sharing
|
||||
assert kv_cache_config.kv_cache_tensors[0].size == available_memory
|
||||
|
||||
max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
|
||||
# max context len with KV sharing should be 2x as large as without
|
||||
assert max_context_len == (2 * 655360)
|
||||
|
||||
# important: override tensor size to prevent large mem alloc during test
|
||||
# this will only allocate 1 block worth of memory (512kb)
|
||||
kv_cache_config.num_blocks = 1
|
||||
kv_cache_config.kv_cache_tensors[0].size = kv_cache_spec[layer_0].page_size_bytes
|
||||
|
||||
model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
|
||||
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
|
||||
# check layer 1 kv cache shares memory with layer 0
|
||||
assert id(layer_1_kv) == id(layer_0_kv)
|
||||
|
||||
# check layer 1 added to kv cache group's layer names
|
||||
assert len(kv_cache_config.kv_cache_groups) == 1
|
||||
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
|
||||
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
|
||||
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
|
||||
|
||||
|
||||
def test_most_model_len(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048")
|
||||
vllm_config = get_vllm_config()
|
||||
vllm_config.model_config.max_model_len = 32000
|
||||
vllm_config.scheduler_config.max_num_seqs = 1200
|
||||
model_runner = get_model_runner(vllm_config)
|
||||
|
||||
# verify model runner will adjust num_reqs to avoid SMEM OOM.
|
||||
assert model_runner.num_reqs_most_model_len == 1200
|
||||
# num_page_per_req = 32k // 128
|
||||
# num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
|
||||
assert model_runner.num_reqs_max_model_len == 524
|
||||
@ -66,7 +66,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
|
||||
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
|
||||
)
|
||||
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
|
||||
PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
|
||||
IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
|
||||
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
|
||||
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
|
||||
|
||||
@ -227,28 +227,3 @@ class MMEncoderAttention(CustomOp):
|
||||
"XPU only supports FLASH_ATTN for vision attention."
|
||||
)
|
||||
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
||||
|
||||
def forward_tpu(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
) -> torch.Tensor:
|
||||
assert self.attn_backend == AttentionBackendEnum.PALLAS, (
|
||||
f"MMEncoderAttention on TPU only supports PALLAS backend, "
|
||||
f"but got {self.attn_backend}."
|
||||
)
|
||||
if cu_seqlens is None:
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
from torch_xla.experimental.custom_kernel import flash_attention
|
||||
|
||||
out = flash_attention(query, key, value, sm_scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
return out
|
||||
logger.warning_once(
|
||||
"PALLAS backend with cu_seqlens is not supported for ViT yet. ",
|
||||
"Falling back to SDPA implementation.",
|
||||
)
|
||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||
|
||||
@ -1,99 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.tpu import USE_TPU_INFERENCE
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
USE_RAY = parallel_config = (
|
||||
get_current_vllm_config().parallel_config.distributed_executor_backend == "ray"
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if not USE_TPU_INFERENCE:
|
||||
logger.info("tpu_inference not found, using vLLM's TpuCommunicator")
|
||||
if current_platform.is_tpu():
|
||||
import torch_xla
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.runtime as xr
|
||||
from torch_xla._internal import pjrt
|
||||
from torch_xla.distributed.xla_multiprocessing import (
|
||||
create_optimized_replica_groups,
|
||||
)
|
||||
|
||||
if USE_RAY:
|
||||
from vllm.v1.executor import ray_utils
|
||||
|
||||
|
||||
class TpuCommunicator(DeviceCommunicatorBase):
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
|
||||
# NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
|
||||
# must be used together. Therefore, the local rank and world size can
|
||||
# be simply calculated as follows.
|
||||
global_rank = self.global_rank
|
||||
global_world_size = self.global_world_size
|
||||
|
||||
if USE_RAY:
|
||||
logger.info("TpuCommunicator initialized with RAY")
|
||||
# Calculate how many TPU nodes are in the current deployment. This
|
||||
# is the Ray placement group if it is deployed with Ray. Default
|
||||
# to the number of TPU nodes in the Ray cluster. The number of TPU
|
||||
# nodes is computed by the total number of TPUs divided by the
|
||||
# number of TPU accelerators per node, to account for clusters
|
||||
# with both CPUs and TPUs.
|
||||
num_nodes = ray_utils.get_num_tpu_nodes()
|
||||
num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
|
||||
if num_nodes_in_pg > 0:
|
||||
num_nodes = num_nodes_in_pg
|
||||
|
||||
local_world_size = global_world_size // num_nodes
|
||||
local_rank = global_rank % local_world_size
|
||||
else:
|
||||
logger.info("TpuCommunicator initialized with MP")
|
||||
# Sanity: Verify we run on a single host
|
||||
num_hosts = torch_xla.tpu.num_tpu_workers()
|
||||
assert num_hosts == 1
|
||||
|
||||
# Get the current number of TPUs (we have locally)
|
||||
local_world_size = torch_xla.tpu.num_available_chips()
|
||||
|
||||
# Get current rank
|
||||
local_rank = global_rank % local_world_size
|
||||
|
||||
# Ensure environment variables are set for multihost deployments.
|
||||
# On GKE, this is needed for libtpu and TPU driver to know which TPU
|
||||
# chip is actually visible. Otherwise the TPU driver will fail to
|
||||
# initialize because the number of devices would be different from
|
||||
# the number of visible worker addresses.
|
||||
os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
|
||||
os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
|
||||
|
||||
pjrt.initialize_multiprocess(local_rank, local_world_size)
|
||||
xr._init_world_size_ordinal()
|
||||
self.groups = create_optimized_replica_groups()
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
# TODO: Remove the groups specification after XLA compiler can support
|
||||
# auto-reordering the ring order for all-reduce.
|
||||
return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
assert dim == -1, "TPUs only support dim=-1 for all-gather."
|
||||
return xm.all_gather(input_, dim=dim)
|
||||
@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Literal
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.logger import init_logger
|
||||
@ -251,9 +250,6 @@ class TpKVTopology:
|
||||
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
|
||||
)
|
||||
|
||||
attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
|
||||
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
|
||||
|
||||
@property
|
||||
def is_kv_layout_blocks_first(self) -> bool:
|
||||
return self._is_kv_layout_blocks_first
|
||||
@ -261,7 +257,7 @@ class TpKVTopology:
|
||||
@property
|
||||
def split_k_and_v(self) -> bool:
|
||||
# Whether to register regions for K and V separately (when present).
|
||||
return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first)
|
||||
return not (self.is_mla or self.is_kv_layout_blocks_first)
|
||||
|
||||
@property
|
||||
def tp_size(self) -> int:
|
||||
|
||||
@ -499,7 +499,6 @@ class MooncakeConnectorWorker:
|
||||
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
||||
attn_backend=backend,
|
||||
)
|
||||
self._use_pallas = self.kv_topo._use_pallas
|
||||
|
||||
self.zmq_ctx = zmq.Context()
|
||||
self.async_zmq_ctx = zmq.asyncio.Context()
|
||||
|
||||
@ -983,7 +983,6 @@ class NixlConnectorWorker:
|
||||
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
||||
attn_backend=backend,
|
||||
)
|
||||
self._use_pallas = self.kv_topo._use_pallas
|
||||
self._physical_blocks_per_logical_kv_block = 1
|
||||
|
||||
def _nixl_handshake(
|
||||
@ -1641,9 +1640,6 @@ class NixlConnectorWorker:
|
||||
# Num kv_heads > tp_size and P TP > D TP case, not supported
|
||||
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
|
||||
|
||||
assert not self._use_pallas or tp_ratio == 1, (
|
||||
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
|
||||
)
|
||||
kv_cache_layout = (
|
||||
self.kv_cache_layout
|
||||
if not self.use_host_buffer
|
||||
@ -1814,9 +1810,7 @@ class NixlConnectorWorker:
|
||||
|
||||
if len(self.device_kv_caches) == 0:
|
||||
return
|
||||
split_k_and_v = not (
|
||||
self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first
|
||||
)
|
||||
split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)
|
||||
sample_cache = list(self.device_kv_caches.values())[0][0]
|
||||
for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
|
||||
assert block_size_ratio > 1, "Only nP < nD supported currently."
|
||||
|
||||
@ -1,188 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import OrderedDict
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_xla.distributed.spmd as xs
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XlaQKVParallelLinear(nn.Module):
|
||||
def __init__(self, qkv_linear: nn.Module, mesh: Optional["xs.Mesh"] = None):
|
||||
super().__init__()
|
||||
assert isinstance(qkv_linear, QKVParallelLinear)
|
||||
self.skip_bias_add = qkv_linear.skip_bias_add
|
||||
self.return_bias = qkv_linear.return_bias
|
||||
assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD."
|
||||
|
||||
self.q_weight: Parameter
|
||||
self.k_weight: Parameter
|
||||
self.v_weight: Parameter
|
||||
self.q_bias: Parameter | None
|
||||
self.k_bias: Parameter | None
|
||||
self.v_bias: Parameter | None
|
||||
self._load_weights_from_qkv_linear(qkv_linear)
|
||||
if mesh is not None:
|
||||
self._shard_weight(mesh)
|
||||
|
||||
def _shard_weight(self, mesh: "xs.Mesh"):
|
||||
self.q_weight = Parameter(self.q_weight.to("xla"), requires_grad=False)
|
||||
self.k_weight = Parameter(self.k_weight.to("xla"), requires_grad=False)
|
||||
self.v_weight = Parameter(self.v_weight.to("xla"), requires_grad=False)
|
||||
xs.mark_sharding(self.q_weight, mesh, ("x", None))
|
||||
xs.mark_sharding(self.k_weight, mesh, ("x", None))
|
||||
xs.mark_sharding(self.v_weight, mesh, ("x", None))
|
||||
if self.q_bias is not None:
|
||||
assert self.k_bias is not None and self.v_bias is not None, (
|
||||
"QKVParallelLinear should have q, k, and v biases together."
|
||||
)
|
||||
self.q_bias = Parameter(self.q_bias.to("xla"), requires_grad=False)
|
||||
xs.mark_sharding(self.q_bias, mesh, ("x",))
|
||||
self.k_bias = Parameter(self.k_bias.to("xla"), requires_grad=False)
|
||||
xs.mark_sharding(self.k_bias, mesh, ("x",))
|
||||
self.v_bias = Parameter(self.v_bias.to("xla"), requires_grad=False)
|
||||
xs.mark_sharding(self.v_bias, mesh, ("x",))
|
||||
|
||||
def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module):
|
||||
q_proj_size, k_proj_size, _ = qkv_linear.output_sizes
|
||||
# The weight of qkv linear is a concatenation of q, k, and v weights
|
||||
# along the output dimension.
|
||||
qkv_weight = qkv_linear.weight.data.cpu()
|
||||
q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False)
|
||||
k_weight = Parameter(
|
||||
qkv_weight[q_proj_size : q_proj_size + k_proj_size], requires_grad=False
|
||||
)
|
||||
v_weight = Parameter(
|
||||
qkv_weight[q_proj_size + k_proj_size :], requires_grad=False
|
||||
)
|
||||
self.register_parameter("q_weight", q_weight)
|
||||
self.register_parameter("k_weight", k_weight)
|
||||
self.register_parameter("v_weight", v_weight)
|
||||
|
||||
if qkv_linear.bias is not None:
|
||||
q_bias = Parameter(qkv_linear.bias[:q_proj_size], requires_grad=False)
|
||||
k_bias = Parameter(
|
||||
qkv_linear.bias[q_proj_size : q_proj_size + k_proj_size],
|
||||
requires_grad=False,
|
||||
)
|
||||
v_bias = Parameter(
|
||||
qkv_linear.bias[q_proj_size + k_proj_size :], requires_grad=False
|
||||
)
|
||||
self.register_parameter("q_bias", q_bias)
|
||||
self.register_parameter("k_bias", k_bias)
|
||||
self.register_parameter("v_bias", v_bias)
|
||||
else:
|
||||
self.register_parameter("q_bias", None)
|
||||
self.register_parameter("k_bias", None)
|
||||
self.register_parameter("v_bias", None)
|
||||
|
||||
def forward(self, input):
|
||||
# Same forward functionality as QKVParallelLinear, but doing qkv porj
|
||||
# separately.
|
||||
q_bias = self.q_bias if not self.skip_bias_add else None
|
||||
k_bias = self.k_bias if not self.skip_bias_add else None
|
||||
v_bias = self.v_bias if not self.skip_bias_add else None
|
||||
q_proj = F.linear(input, self.q_weight, q_bias)
|
||||
k_proj = F.linear(input, self.k_weight, k_bias)
|
||||
v_proj = F.linear(input, self.v_weight, v_bias)
|
||||
# The q/k/v projections will be split outside of the QKVParallelLinear.
|
||||
# Because we are replacing XlaQKVParallelLinear with the
|
||||
# QKVParallelLinear, we need to concatenate q, k, and v projections to
|
||||
# match the output shape of the QKVParallelLinear implementation even if
|
||||
# it seems to be redundant.
|
||||
# The concat and the following split will be noop, and should be
|
||||
# optimized away by the compiler.
|
||||
qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1)
|
||||
output_bias = (
|
||||
torch.cat([q_bias, k_bias, v_bias], dim=-1) if self.skip_bias_add else None
|
||||
)
|
||||
if not self.return_bias:
|
||||
return qkv_proj
|
||||
return qkv_proj, output_bias
|
||||
|
||||
|
||||
def partition_column_parallel_linear(
|
||||
layer: torch.nn.Module, mesh: xs.Mesh
|
||||
) -> torch.nn.Module:
|
||||
assert isinstance(layer, ColumnParallelLinear)
|
||||
xs.mark_sharding(layer.weight, mesh, ("x", None))
|
||||
logger.debug("Applied column-parallel sharding to %s", layer)
|
||||
return layer
|
||||
|
||||
|
||||
def partition_row_parallel_linear(
|
||||
layer: torch.nn.Module, mesh: xs.Mesh
|
||||
) -> torch.nn.Module:
|
||||
assert isinstance(layer, RowParallelLinear)
|
||||
xs.mark_sharding(layer.weight, mesh, (None, "x"))
|
||||
logger.debug("Applied row-parallel sharding to %s", layer)
|
||||
return layer
|
||||
|
||||
|
||||
def partition_qkv_parallel_linear(
|
||||
layer: torch.nn.Module, mesh: xs.Mesh
|
||||
) -> torch.nn.Module:
|
||||
assert isinstance(layer, QKVParallelLinear)
|
||||
xla_layer = XlaQKVParallelLinear(layer, mesh)
|
||||
logger.debug("Applied qkv parallel sharding to %s", layer)
|
||||
return xla_layer
|
||||
|
||||
|
||||
MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict(
|
||||
[
|
||||
("QKVParallelLinear", partition_qkv_parallel_linear),
|
||||
("ColumnParallelLinear", partition_column_parallel_linear),
|
||||
("RowParallelLinear", partition_row_parallel_linear),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def get_fqn(module):
|
||||
# Get the fully qualified name of the module
|
||||
return module.__class__.__qualname__
|
||||
|
||||
|
||||
def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None:
|
||||
"""
|
||||
Recursively check a PyTorch model and apply appropriate sharding based on
|
||||
the MODULE_TYPE_TO_WRAPPING_FUNC mapping.
|
||||
|
||||
Args:
|
||||
model: torch.nn.Module to process
|
||||
mesh: An XLA SPMD mesh object used for sharding
|
||||
"""
|
||||
|
||||
def _process_module(module, name=None, parent=None):
|
||||
for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items():
|
||||
if get_fqn(module) == module_type:
|
||||
wrapped_module = wrapping_func(module, mesh)
|
||||
|
||||
assert parent is not None and name is not None, (
|
||||
"Top Level module is not expected to be wrapped."
|
||||
)
|
||||
if wrapped_module is not module:
|
||||
# Wrapped module and module are different py object.
|
||||
# The original module should be replaced by the
|
||||
# wrapped_module.
|
||||
logger.debug("replace %s with %s", module, wrapped_module)
|
||||
setattr(parent, name, wrapped_module)
|
||||
|
||||
module = wrapped_module
|
||||
break
|
||||
|
||||
for child_name, child_module in list(module.named_children()):
|
||||
_process_module(child_module, child_name, module)
|
||||
|
||||
_process_module(model)
|
||||
@ -1,6 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
||||
|
||||
__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]
|
||||
@ -1,141 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_xla.core.xla_builder as xb
|
||||
from torch.library import impl
|
||||
from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard
|
||||
|
||||
|
||||
@jax.jit
|
||||
def bgmv_jax(inputs, loras, idxs):
|
||||
return jnp.einsum(
|
||||
"td,tX,Xld->tl",
|
||||
inputs,
|
||||
jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
|
||||
loras,
|
||||
)
|
||||
|
||||
|
||||
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")
|
||||
|
||||
|
||||
@impl(XLA_LIB, "bgmv", "XLA")
|
||||
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
|
||||
if len(loras.shape) == 4:
|
||||
loras = loras.squeeze(axis=1)
|
||||
|
||||
jax_import_guard()
|
||||
return xb.call_jax(bgmv_jax, (inputs, loras, idxs))
|
||||
|
||||
|
||||
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
|
||||
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
|
||||
T, _ = inputs.shape
|
||||
if len(loras.shape) == 4:
|
||||
loras = loras.squeeze(axis=1)
|
||||
_, L, _ = loras.shape
|
||||
|
||||
return torch.empty((T, L), device=inputs.device)
|
||||
|
||||
|
||||
def bgmv_expand(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
|
||||
|
||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
||||
[num_loras, lora_rank, hidden_size].
|
||||
|
||||
output_tensor (torch.Tensor): output tensor of shape
|
||||
[num_tokens, hidden_size * num_slices].
|
||||
|
||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
||||
indicating which LoRA matrix to use for each token.
|
||||
add_inputs (bool): Whether or not to add the input tensor to the output
|
||||
tensor.
|
||||
"""
|
||||
|
||||
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
|
||||
|
||||
limit = output_tensor.shape[0]
|
||||
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
|
||||
limit = 1
|
||||
|
||||
if output_tensor.shape[1] > outputs.shape[1]:
|
||||
outputs = F.pad(outputs, (0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
|
||||
|
||||
if add_inputs:
|
||||
return output_tensor + outputs[:limit, : output_tensor.shape[1]]
|
||||
else:
|
||||
return outputs[:limit, : output_tensor.shape[1]]
|
||||
|
||||
|
||||
def bgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
|
||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
||||
[num_loras, lora_rank, hidden_size].
|
||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
||||
indicating which LoRA matrix to use for each token.
|
||||
scaling (float, optional): Scalar multiplier applied to the output.
|
||||
"""
|
||||
|
||||
return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
|
||||
|
||||
|
||||
def bgmv_expand_slice(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
|
||||
|
||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
||||
[num_loras, lora_rank, hidden_size].
|
||||
|
||||
output_tensor (torch.Tensor): output tensor of shape
|
||||
[num_tokens, hidden_size * num_slices].
|
||||
|
||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
||||
indicating which LoRA matrix to use for each token.
|
||||
add_inputs (bool): Whether or not to add the input tensor to the output
|
||||
tensor.
|
||||
"""
|
||||
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
|
||||
|
||||
outputs = F.pad(
|
||||
outputs,
|
||||
(
|
||||
slice_offset,
|
||||
output_tensor.shape[1] - (slice_offset + slice_size),
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
|
||||
if add_inputs:
|
||||
return output_tensor + outputs
|
||||
else:
|
||||
return outputs
|
||||
@ -1,358 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_xla
|
||||
|
||||
from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
||||
from vllm.lora.punica_wrapper.utils import convert_mapping
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# avoid circuit import
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
|
||||
from .punica_base import PunicaWrapperBase
|
||||
|
||||
|
||||
class PunicaWrapperTPU(PunicaWrapperBase):
|
||||
"""
|
||||
PunicaWrapperTPU is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
Multi-LoRA, and to provide the interface for the pytorch punica ops.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_batched_tokens: int,
|
||||
max_batches: int,
|
||||
device: torch.device | str,
|
||||
**kwargs,
|
||||
):
|
||||
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
|
||||
|
||||
# PunicaWrapperBase defines some tensors with dtype=torch.int64, which
|
||||
# isn't supported by the TPU. So convert those tensors to int32.
|
||||
# Not all of them are used by the TPU so only convert the useful ones.
|
||||
self._token_lora_indices = self._token_lora_indices.to(dtype=torch.int32)
|
||||
self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
|
||||
self._sampler_indices_padded = self._sampler_indices_padded.to(
|
||||
dtype=torch.int32
|
||||
)
|
||||
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, True)
|
||||
|
||||
torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
|
||||
torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
|
||||
torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
|
||||
|
||||
def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
|
||||
return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
|
||||
|
||||
@property
|
||||
def embeddings_indices(self) -> torch.Tensor:
|
||||
"""
|
||||
This property provides access to the indices used for lora embeddings,
|
||||
specifically for VocabParallelEmbeddingWithLoRA.
|
||||
"""
|
||||
return self._embeddings_indices[:]
|
||||
|
||||
@property
|
||||
def sampler_indices_padded(self) -> torch.Tensor:
|
||||
"""
|
||||
This property provides access to padded sampler indices.
|
||||
"""
|
||||
return self._sampler_indices_padded[:]
|
||||
|
||||
def shrink(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale)
|
||||
|
||||
def expand(
|
||||
self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, add_inputs: bool
|
||||
):
|
||||
return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), add_inputs)
|
||||
|
||||
def expand_slice(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
y_offset: int,
|
||||
y_slice_size: int,
|
||||
add_inputs: bool,
|
||||
) -> torch.Tensor:
|
||||
return bgmv_expand_slice(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
self._get_token_lora_indices(x),
|
||||
y_offset,
|
||||
y_slice_size,
|
||||
add_inputs,
|
||||
)
|
||||
|
||||
def add_shrink(
|
||||
self,
|
||||
y: tuple[torch.Tensor, ...] | torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_a.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (x @ lora_a_stacked[i]) * scale
|
||||
|
||||
Args:
|
||||
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
|
||||
scale (float): Scaling factor for the operation
|
||||
"""
|
||||
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(y, True)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
for slice_idx in range(len(lora_a_stacked)):
|
||||
lora_s = lora_a_stacked[slice_idx]
|
||||
y_s = self.shrink(x, lora_s, scale)
|
||||
y[slice_idx, :, :] = y_s # type: ignore[index]
|
||||
return y
|
||||
|
||||
def add_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: tuple[torch.Tensor, ...] | torch.Tensor,
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
output_slices: tuple[int, ...],
|
||||
offset_start: int = 0,
|
||||
add_inputs=True,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_b.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_b_stacked)):
|
||||
slice = output_slices[i]
|
||||
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
|
||||
offset += slice
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
|
||||
output_slices (tuple[int, ...]): Every slice's size
|
||||
add_inputs (bool): Defaults to True.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
offset_left = 0
|
||||
|
||||
for slice_idx in range(len(lora_b_stacked)):
|
||||
y = self.expand_slice(
|
||||
y,
|
||||
x[slice_idx],
|
||||
lora_b_stacked[slice_idx],
|
||||
offset_left,
|
||||
output_slices[slice_idx],
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
offset_left += output_slices[slice_idx]
|
||||
return y.view_as(y_org)
|
||||
|
||||
def add_lora_embedding(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
||||
|
||||
Semantics:
|
||||
y += x @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_b_stacked (torch.Tensor): lora_b's weights.
|
||||
add_inputs (bool): Default to True.
|
||||
"""
|
||||
|
||||
# Embedding layer only needs the expand op
|
||||
return self.expand(y, x, lora_b_stacked, add_inputs)
|
||||
|
||||
def add_lora_linear(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
output_slices: tuple[int, ...],
|
||||
*,
|
||||
buffer: tuple[torch.Tensor, ...] | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Applicable to linear-related lora.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (
|
||||
x[i].unsqueeze(0)
|
||||
@ lora_a_stacked[indices[i], layer_idx, :, :]
|
||||
@ lora_b_stacked[indices[i], layer_idx, :, :]
|
||||
* scale
|
||||
).squeeze(0)
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor. Will not be changed in-place.
|
||||
x (torch.Tensor): Input tensor (T, E)
|
||||
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
|
||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
|
||||
scale (float): Scaling factor.
|
||||
output_slices (tuple[int, ...]): Every slice's size.
|
||||
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
|
||||
"""
|
||||
|
||||
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
|
||||
|
||||
if buffer is None:
|
||||
r = lora_b_stacked[0].size(-1)
|
||||
T = x.size(0)
|
||||
buffer = torch.zeros(
|
||||
(len(output_slices), T, r),
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
)
|
||||
buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
|
||||
return self.add_expand(
|
||||
y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
|
||||
)
|
||||
|
||||
def add_lora_logits(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
scale,
|
||||
*,
|
||||
buffer: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Applies lora specifically for LogitsProcessorWithLoRA.
|
||||
|
||||
Semantics:
|
||||
buffer = (x @ lora_a_stacked) * scale
|
||||
y += buffer @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_a_stacked (torch.Tensor): lora_a's weights.
|
||||
lora_b_stacked (torch.Tensor):lora_b's weights.
|
||||
scale (float): Scaling factor.
|
||||
buffer (Optional[torch.Tensor]):Default to None.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
|
||||
buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale)
|
||||
y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
|
||||
return y.view_as(y_org)
|
||||
|
||||
# This performs the same tensor ops as the base method, except it does them
|
||||
# on the CPU then transfers the results to the TPU
|
||||
def _update_base_metadata(
|
||||
self,
|
||||
mapping: "LoRAMapping",
|
||||
lora_index_to_id: list[int | None],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
# Make sure we don't accidentally collect outside operations
|
||||
torch_xla.sync()
|
||||
|
||||
# Pad the prompt mapping to avoid running into recompiles on the TPU
|
||||
# TODO: Should this happen inside mapping internally? If so how can we
|
||||
# avoid having backend specific LoRAMapping classes?
|
||||
mapping.prompt_mapping = self._pad_prompt_mapping(mapping.prompt_mapping)
|
||||
|
||||
(
|
||||
base_indices,
|
||||
sampler_indices,
|
||||
sampler_indices_padded,
|
||||
embeddings_indices,
|
||||
indices_len,
|
||||
) = convert_mapping(
|
||||
mapping,
|
||||
lora_index_to_id,
|
||||
max_loras,
|
||||
vocab_size,
|
||||
0, # extra_vocab_size
|
||||
"cpu",
|
||||
)
|
||||
self._token_lora_indices = self._pad_to_shape(
|
||||
base_indices, self._token_lora_indices.shape, dims=1
|
||||
).to(self.device)
|
||||
self._sampler_indices = self._pad_to_shape(
|
||||
sampler_indices, self._sampler_indices.shape, dims=1
|
||||
).to(self.device)
|
||||
self._sampler_indices_padded = self._pad_to_shape(
|
||||
sampler_indices_padded, self._sampler_indices_padded.shape, dims=1
|
||||
).to(self.device)
|
||||
self._embeddings_indices = self._pad_to_shape(
|
||||
embeddings_indices, self._embeddings_indices.shape, dims=2
|
||||
).to(self.device)
|
||||
self.indices_len[:] = indices_len
|
||||
|
||||
def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None:
|
||||
self.batch_size = 1
|
||||
self._lora_indices_per_batch[: self.batch_size] = token_lora_tensor[
|
||||
: self.batch_size
|
||||
]
|
||||
|
||||
def _pad_prompt_mapping(self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
|
||||
num_reqs = len(prompt_mapping)
|
||||
|
||||
# From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
|
||||
# import
|
||||
MIN_NUM_SEQS = 8
|
||||
|
||||
padded_num_reqs = max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
|
||||
pad_len = padded_num_reqs - num_reqs
|
||||
|
||||
padding = [-1] * pad_len
|
||||
return tuple(list(prompt_mapping) + padding)
|
||||
|
||||
def _pad_to_shape(self, src, target_shape, dims=1):
|
||||
if dims == 1:
|
||||
pad_len = target_shape[0] - src.shape[0]
|
||||
return F.pad(src, (0, pad_len), value=0).to(torch.int32)
|
||||
else:
|
||||
pad_rows = target_shape[0] - src.shape[0]
|
||||
pad_cols = target_shape[1] - src.shape[1]
|
||||
return F.pad(src, (0, pad_cols, 0, pad_rows), value=0).to(torch.int32)
|
||||
@ -67,21 +67,15 @@ else:
|
||||
|
||||
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
rocm_aiter_grouped_topk,
|
||||
)
|
||||
|
||||
if current_platform.is_tpu():
|
||||
from .moe_pallas import fused_moe as fused_moe_pallas
|
||||
else:
|
||||
fused_moe_pallas = None # type: ignore
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
|
||||
FusedMoEModularMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
rocm_aiter_grouped_topk,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
|
||||
UnquantizedFusedMoEMethod,
|
||||
)
|
||||
|
||||
@ -1,83 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
|
||||
"""
|
||||
Compute the histogram of an int32 tensor. The bin edges are defined by the
|
||||
min and max values, with step = 1.
|
||||
"""
|
||||
assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
|
||||
assert min <= max, "min must be less than or equal to max."
|
||||
|
||||
def searchsorted(
|
||||
sorted_sequence: torch.Tensor, values_to_search: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)
|
||||
|
||||
bin_edges = torch.linspace(min, max, max - min + 1, dtype=input.dtype).to(
|
||||
input.device
|
||||
)
|
||||
return searchsorted(bin_edges, input).to(torch.int32)
|
||||
|
||||
|
||||
def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
renormalize: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states: [*, hidden_size]
|
||||
w1: [num_experts, intermediate_size * 2, hidden_size]
|
||||
w2: [num_experts, hidden_size, intermediate_size]
|
||||
gating_output: [*, num_experts]
|
||||
"""
|
||||
assert expert_map is None, "expert_map is not supported for pallas MoE."
|
||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_size = hidden_states.shape[-1]
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
num_experts = w1.shape[0]
|
||||
intermediate_size = w2.shape[-1]
|
||||
device = hidden_states.device
|
||||
dtype = hidden_states.dtype
|
||||
assert (num_tokens * topk) % 16 == 0, (
|
||||
"The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
|
||||
f"16 but got {num_tokens * topk}"
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.view(num_tokens, hidden_size)
|
||||
gating_output = gating_output.view(num_tokens, num_experts)
|
||||
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
|
||||
topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
topk_weights = topk_weights.to(dtype)
|
||||
|
||||
topk_indices = topk_indices.flatten()
|
||||
topk_argsort_indices = topk_indices.argsort()
|
||||
topk_argsort_revert_indices = topk_argsort_indices.argsort()
|
||||
token_indices = torch.arange(num_tokens, device=device).repeat_interleave(topk)
|
||||
token_indices = token_indices[topk_argsort_indices]
|
||||
group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
|
||||
|
||||
x = hidden_states[token_indices]
|
||||
x = torch.ops.xla.gmm(x, w1, group_sizes, transpose_rhs=True)
|
||||
x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
|
||||
x = torch.ops.xla.gmm(x, w2, group_sizes, transpose_rhs=True)
|
||||
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
|
||||
|
||||
x = x * topk_weights.unsqueeze(dim=-1)
|
||||
x = x.sum(dim=-2)
|
||||
x = x.reshape(orig_shape)
|
||||
return x
|
||||
@ -38,10 +38,6 @@ if current_platform.is_cuda_alike():
|
||||
else:
|
||||
TritonExperts = None # type: ignore
|
||||
|
||||
if current_platform.is_tpu():
|
||||
from .moe_pallas import fused_moe as fused_moe_pallas
|
||||
else:
|
||||
fused_moe_pallas = None # type: ignore
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -403,53 +399,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
custom_routing_function=layer.custom_routing_function,
|
||||
)
|
||||
|
||||
def forward_tpu(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not layer.use_grouped_topk
|
||||
assert layer.num_expert_group is None
|
||||
assert layer.topk_group is None
|
||||
assert layer.custom_routing_function is None
|
||||
assert layer.apply_router_weight_on_input is False
|
||||
if layer.scoring_func != "softmax":
|
||||
raise NotImplementedError(
|
||||
"Only softmax scoring function is supported for TPU."
|
||||
)
|
||||
if layer.e_score_correction_bias is not None:
|
||||
raise NotImplementedError(
|
||||
"Expert score correction bias is not supported for TPU."
|
||||
)
|
||||
assert layer.activation == "silu", (
|
||||
f"{layer.activation} is not supported for TPU."
|
||||
)
|
||||
assert layer.routed_scaling_factor == 1.0, (
|
||||
f"routed_scaling_factor {layer.routed_scaling_factor} is "
|
||||
"not supported for TPU."
|
||||
)
|
||||
if (
|
||||
layer.enable_eplb is not False
|
||||
or layer.expert_load_view is not None
|
||||
or layer.logical_to_physical_map is not None
|
||||
or layer.logical_replica_count is not None
|
||||
):
|
||||
raise NotImplementedError("Expert load balancing is not supported for TPU.")
|
||||
return fused_moe_pallas(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk=layer.top_k,
|
||||
gating_output=router_logits,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
renormalize=layer.renormalize,
|
||||
)
|
||||
|
||||
if current_platform.is_tpu():
|
||||
forward_native = forward_tpu
|
||||
elif current_platform.is_cpu():
|
||||
if current_platform.is_cpu():
|
||||
forward_native = forward_cpu
|
||||
elif current_platform.is_xpu():
|
||||
forward_native = forward_xpu
|
||||
|
||||
@ -11,7 +11,6 @@ logger = init_logger(__name__)
|
||||
QuantizationMethods = Literal[
|
||||
"awq",
|
||||
"deepspeedfp",
|
||||
"tpu_int8",
|
||||
"fp8",
|
||||
"ptpc_fp8",
|
||||
"fbgemm_fp8",
|
||||
@ -130,12 +129,10 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
from .ptpc_fp8 import PTPCFp8Config
|
||||
from .rtn import RTNConfig
|
||||
from .torchao import TorchAOConfig
|
||||
from .tpu_int8 import Int8TpuConfig
|
||||
|
||||
method_to_config: dict[str, type[QuantizationConfig]] = {
|
||||
"awq": AWQConfig,
|
||||
"deepspeedfp": DeepSpeedFPConfig,
|
||||
"tpu_int8": Int8TpuConfig,
|
||||
"fp8": Fp8Config,
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
"fp_quant": FPQuantConfig,
|
||||
|
||||
@ -19,9 +19,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
||||
TritonScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
|
||||
XLAScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.platforms import PlatformEnum, current_platform
|
||||
|
||||
# in priority/performance order (when available)
|
||||
@ -29,7 +26,6 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
|
||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -1,106 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from functorch.experimental.control_flow import cond # noqa: F401
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_tpu():
|
||||
return False, "Requires TPU."
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_tpu():
|
||||
return False, "ScaledMMXLA requires running on TPU."
|
||||
|
||||
if c.is_static_input_scheme:
|
||||
return False, "ScaledMMXLA requires dynamic activation scales."
|
||||
|
||||
if not c.input_symmetric:
|
||||
return False, "ScaledMMXLA requires symmetric activation scales."
|
||||
|
||||
if not c.is_channelwise:
|
||||
return False, "ScaledMMXLA requires channelwise weight scales"
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# [out, in] (different than cutlass_scaled_mm)
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
replace_parameter(
|
||||
layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
|
||||
)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# XLA kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
|
||||
# [out_channel,] (different than cutlass_scaled_mm)
|
||||
weight_scale = weight_scale.squeeze(-1)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
# Only support symmetric dynamic activation quantization.
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
# Filter warning for cond usage in apply_weights. It is okay
|
||||
# to specialize the graph since bias is not dynamic.
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.", # noqa: E501
|
||||
)
|
||||
|
||||
def no_add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
|
||||
return x
|
||||
|
||||
def add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
|
||||
return x + bias
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, _, _, _ = self._get_weight_params(layer)
|
||||
|
||||
# Required to register custom ops.
|
||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||
|
||||
out = torch.ops.xla.quantized_matmul_int8(
|
||||
x,
|
||||
w_q,
|
||||
w_s,
|
||||
quantize_activation=True,
|
||||
)
|
||||
|
||||
# Explicitly capture control flow to make dynamo happy.
|
||||
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
|
||||
return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
|
||||
@ -1,139 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import (
|
||||
QuantizationConfig,
|
||||
QuantizationMethods,
|
||||
)
|
||||
from vllm.model_executor.parameter import ModelWeightParameter
|
||||
|
||||
ACTIVATION_SCHEMES = ["none", "dynamic"]
|
||||
|
||||
|
||||
class Int8TpuConfig(QuantizationConfig):
|
||||
"""Int8 Quantization Config class for TPU Backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation_scheme: str = "none",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if activation_scheme not in ACTIVATION_SCHEMES:
|
||||
raise ValueError(f"Unsupported activation scheme {activation_scheme}")
|
||||
self.activation_scheme = activation_scheme
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "tpu_int8"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError("This function should not be called with TPU Backend")
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig":
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
return cls(activation_scheme=activation_scheme)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: Module, prefix: str
|
||||
) -> Optional["TPUInt8LinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return TPUInt8LinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class TPUInt8LinearMethod(LinearMethodBase):
|
||||
"""Int8 Linear method for TPU Quant."""
|
||||
|
||||
def __init__(self, quant_config: Int8TpuConfig):
|
||||
self.quant_config = quant_config
|
||||
self.quantize_activation = False
|
||||
if self.quant_config.activation_scheme == "dynamic":
|
||||
self.quantize_activation = True
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
def _quantize_weight(
|
||||
self, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
weight_dtype = weight.dtype
|
||||
weight = weight.cpu().to(torch.float32)
|
||||
n_bit = 8
|
||||
eps = 1e-5
|
||||
max_int = 2 ** (n_bit - 1) - 1
|
||||
min_int = -(2 ** (n_bit - 1))
|
||||
max_val = weight.abs().amax(dim=-1, keepdim=True)
|
||||
max_val = max_val.clamp(min=eps)
|
||||
qscale = max_val / max_int
|
||||
qweight = torch.clamp(
|
||||
torch.round(weight * (1.0 / qscale)), min_int, max_int
|
||||
).to(torch.int8)
|
||||
qscale = qscale.squeeze().to(weight_dtype)
|
||||
return qweight, qscale
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
device = layer.weight.device
|
||||
qweight, qscale = self._quantize_weight(layer.weight)
|
||||
qweight = qweight.to(device)
|
||||
qscale = qscale.to(device)
|
||||
layer.weight = Parameter(qweight, requires_grad=False)
|
||||
layer.scale = Parameter(qscale, requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
try:
|
||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install torch_xla by following the instructions at "
|
||||
"https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501
|
||||
"to run vLLM on TPU."
|
||||
) from err
|
||||
weight = layer.weight
|
||||
scale = layer.scale
|
||||
out = torch.ops.xla.quantized_matmul_int8(
|
||||
x, weight, scale, quantize_activation=self.quantize_activation
|
||||
)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out
|
||||
@ -30,7 +30,6 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
pt_weights_iterator,
|
||||
safetensors_weights_iterator,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -241,22 +240,6 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
self.load_config.pt_load_map_location,
|
||||
)
|
||||
|
||||
if current_platform.is_tpu():
|
||||
from vllm.platforms.tpu import USE_TPU_INFERENCE
|
||||
|
||||
if not USE_TPU_INFERENCE:
|
||||
# In PyTorch XLA, we should call `torch_xla.sync`
|
||||
# frequently so that not too many ops are accumulated
|
||||
# in the XLA program.
|
||||
import torch_xla
|
||||
|
||||
def _xla_weights_iterator(iterator: Generator):
|
||||
for weights in iterator:
|
||||
yield weights
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
weights_iterator = _xla_weights_iterator(weights_iterator)
|
||||
|
||||
if self.counter_before_loading_weights == 0.0:
|
||||
self.counter_before_loading_weights = time.perf_counter()
|
||||
# Apply the prefix.
|
||||
|
||||
@ -1,118 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.distributed.spmd as xs
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model,
|
||||
process_weights_after_loading,
|
||||
)
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TPUModelLoader(DefaultModelLoader):
|
||||
"""
|
||||
A TPU model loader for model loading under SPMD mode.
|
||||
"""
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
model_config: ModelConfig,
|
||||
mesh: xs.Mesh | None = None,
|
||||
) -> nn.Module:
|
||||
# Initialize model and load weights on CPU. Then, during SPMD partition,
|
||||
# weights are sharded and transferred to TPUs.
|
||||
self.counter_before_loading_weights = time.perf_counter()
|
||||
model_config = vllm_config.model_config
|
||||
assert model_config.quantization is None, "Quantization not supported"
|
||||
target_device = torch.device("cpu")
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = initialize_model(vllm_config=vllm_config)
|
||||
|
||||
load_format = vllm_config.load_config.load_format
|
||||
if load_format != "dummy":
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
all_weights = self.get_all_weights(model_config, model)
|
||||
loaded_weights = model.load_weights(all_weights)
|
||||
self.counter_after_loading_weights = time.perf_counter()
|
||||
logger.info(
|
||||
"Loading weights took %.2f seconds",
|
||||
self.counter_after_loading_weights
|
||||
- self.counter_before_loading_weights,
|
||||
)
|
||||
# We only enable strict check for non-quantized models
|
||||
# that have loaded weights tracking currently.
|
||||
if model_config.quantization is None and loaded_weights is not None:
|
||||
weights_not_loaded = weights_to_load - loaded_weights
|
||||
if weights_not_loaded:
|
||||
raise ValueError(
|
||||
"Following weights were not initialized from "
|
||||
f"checkpoint: {weights_not_loaded}"
|
||||
)
|
||||
else:
|
||||
logger.info("Use dummy weight during weight loading.")
|
||||
|
||||
process_weights_after_loading(model, model_config, target_device)
|
||||
|
||||
counter_before_partition = time.perf_counter()
|
||||
model = model.eval()
|
||||
model = model.to("xla")
|
||||
shard_model(model, mesh)
|
||||
counter_after_partition = time.perf_counter()
|
||||
logger.info(
|
||||
"Partition model took %.2f seconds",
|
||||
counter_after_partition - counter_before_partition,
|
||||
)
|
||||
|
||||
# Ensure the model is properly loaded.
|
||||
self._check_model_is_loaded(mesh, model)
|
||||
|
||||
# Need to torch compile after model sharding are done. Because the
|
||||
# compiler hints ('xs.mark_sharding') are torch ops.
|
||||
if not model_config.is_multimodal_model:
|
||||
model.model = torch.compile(model.model, backend="openxla")
|
||||
else:
|
||||
model.language_model.model = torch.compile(
|
||||
model.language_model.model, backend="openxla"
|
||||
)
|
||||
return model
|
||||
|
||||
def _check_model_is_loaded(self, mesh: xs.Mesh | None, model: nn.Module) -> None:
|
||||
"""
|
||||
Ensure the model is properly loaded.
|
||||
1. All model parameters and buffers are on XLA device.
|
||||
2. Non-SPMD friendly layers are replaced as expected.
|
||||
"""
|
||||
device = xm.xla_device()
|
||||
device_type = str(device.type)
|
||||
|
||||
# Check parameters
|
||||
for name, param in model.named_parameters():
|
||||
assert param.device.type == device_type, (
|
||||
f"Parameter {name} is on {param.device.type} instead of {device_type}"
|
||||
)
|
||||
|
||||
# Check buffers
|
||||
for name, buffer in model.named_buffers():
|
||||
assert buffer.device.type == device_type, (
|
||||
f"Buffer {name} is on {buffer.device.type} instead of {device_type}"
|
||||
)
|
||||
|
||||
for module in model.modules():
|
||||
if (mesh is not None) and (get_fqn(module) == "QKVParallelLinear"):
|
||||
raise AssertionError(
|
||||
"QKVParallelLinear should be replaced by \
|
||||
XlaQKVParallelLinear under SPMD mode."
|
||||
)
|
||||
@ -1,287 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
|
||||
import torch
|
||||
from tpu_info import device
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeAlias
|
||||
|
||||
from vllm.attention.selector import AttentionSelectorConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import BlockSize
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
ParamsType: TypeAlias = SamplingParams | PoolingParams
|
||||
else:
|
||||
BlockSize = None
|
||||
VllmConfig = None
|
||||
PoolingParams = None
|
||||
ParamsType = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
USE_TPU_INFERENCE = False
|
||||
|
||||
|
||||
class TpuPlatform(Platform):
|
||||
_enum = PlatformEnum.TPU
|
||||
device_name: str = "tpu"
|
||||
device_type: str = "tpu"
|
||||
dispatch_key: str = "XLA"
|
||||
ray_device_key: str = "TPU"
|
||||
dist_backend: str = "gloo"
|
||||
device_control_env_var: str = "TPU_VISIBLE_CHIPS"
|
||||
simple_compile_backend: str = "openxla"
|
||||
|
||||
supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"]
|
||||
|
||||
additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]
|
||||
|
||||
@classmethod
|
||||
def import_kernels(cls) -> None:
|
||||
# Do not import vllm._C
|
||||
with contextlib.suppress(ImportError):
|
||||
import vllm._moe_C # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
) -> str:
|
||||
if attn_selector_config.use_sparse:
|
||||
raise NotImplementedError("Sparse Attention is not supported on TPU.")
|
||||
if selected_backend != AttentionBackendEnum.PALLAS:
|
||||
logger.info("Cannot use %s backend on TPU.", selected_backend)
|
||||
|
||||
logger.info("Using Pallas V1 backend.")
|
||||
return AttentionBackendEnum.PALLAS.get_path()
|
||||
|
||||
@classmethod
|
||||
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
|
||||
return [
|
||||
AttentionBackendEnum.PALLAS,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
backend: Optional["AttentionBackendEnum"] = None,
|
||||
) -> "AttentionBackendEnum":
|
||||
if backend is not None:
|
||||
assert backend in cls.get_supported_vit_attn_backends(), (
|
||||
f"Backend {backend} is not supported for vit attention"
|
||||
f"Supported backends are: {cls.get_supported_vit_attn_backends()}."
|
||||
)
|
||||
logger.info_once(f"Using backend {backend} for vit attention.")
|
||||
return backend
|
||||
|
||||
logger.info_once(
|
||||
f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention."
|
||||
)
|
||||
return AttentionBackendEnum.PALLAS
|
||||
|
||||
@classmethod
|
||||
def set_device(cls, device: torch.device) -> None:
|
||||
"""
|
||||
Set the device for the current platform.
|
||||
"""
|
||||
torch.tpu.set_device(device)
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
chip_type, _ = device.get_local_chips()
|
||||
return f"TPU {chip_type.name}"
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_punica_wrapper(cls) -> str:
|
||||
return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
|
||||
|
||||
@classmethod
|
||||
def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
|
||||
return torch.finfo(dtype).min, torch.finfo(dtype).max
|
||||
|
||||
@classmethod
|
||||
def can_update_inplace(cls):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_lora_vocab_padding_size(cls) -> int:
|
||||
return 1
|
||||
|
||||
@classmethod
|
||||
def inference_mode(cls):
|
||||
return torch.no_grad()
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
from vllm.config import CompilationMode, CUDAGraphMode
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
# For v0, the default block size is 16.
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = cast(BlockSize, 16)
|
||||
compilation_config = vllm_config.compilation_config
|
||||
|
||||
# TPU only supports DYNAMO_TRACE_ONCE compilation mode
|
||||
if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE:
|
||||
logger.info(
|
||||
"[TPU] Forcing DYNAMO_TRACE_ONCE compilation mode, and\
|
||||
disabling cudagraph."
|
||||
)
|
||||
compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
|
||||
|
||||
if (
|
||||
compilation_config.cudagraph_mode is None
|
||||
or compilation_config.cudagraph_mode.max_cudagraph_mode()
|
||||
!= CUDAGraphMode.NONE
|
||||
):
|
||||
logger.info(
|
||||
"[TPU] CUDA graph is not supported on TPU, disabling cudagraphs."
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
if compilation_config.backend == "":
|
||||
compilation_config.backend = "openxla"
|
||||
|
||||
assert vllm_config.speculative_config is None, (
|
||||
"TPU does not support speculative decoding"
|
||||
)
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
if model_config is not None and model_config.dtype in (
|
||||
torch.float16,
|
||||
torch.float32,
|
||||
):
|
||||
logger.warning(
|
||||
"The TPU backend currently does not support %s. "
|
||||
"Using bfloat16 instead.",
|
||||
model_config.dtype,
|
||||
)
|
||||
model_config.dtype = torch.bfloat16
|
||||
|
||||
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
|
||||
|
||||
cache_config.block_size = PallasAttentionBackend.get_page_size(vllm_config) # type: ignore[assignment]
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"
|
||||
|
||||
assert not vllm_config.speculative_config, (
|
||||
"Speculative decoding is not yet supported for TPU backend"
|
||||
)
|
||||
|
||||
if (
|
||||
scheduler_config.is_multimodal_model
|
||||
and not scheduler_config.disable_chunked_mm_input
|
||||
):
|
||||
logger.warning(
|
||||
"TPU does not support running Multimodal models"
|
||||
" without setting `--disable_chunked_mm_input`. "
|
||||
"Forcing --disable_chunked_mm_input."
|
||||
)
|
||||
scheduler_config.disable_chunked_mm_input = True
|
||||
|
||||
if model_config and model_config.use_mla:
|
||||
logger.info(
|
||||
"MLA is enabled on a non-GPU platform; forcing chunked "
|
||||
"prefill and prefix caching to be disabled."
|
||||
)
|
||||
vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||
vllm_config.scheduler_config.max_num_batched_tokens = max(
|
||||
vllm_config.model_config.max_model_len,
|
||||
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls):
|
||||
logger.warning("Pin memory is not supported on TPU.")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_device_communicator_cls(cls) -> str:
|
||||
return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa
|
||||
|
||||
@classmethod
|
||||
def validate_request(
|
||||
cls,
|
||||
prompt: PromptType,
|
||||
params: ParamsType,
|
||||
processed_inputs: ProcessorInputs,
|
||||
) -> None:
|
||||
"""Raises if this request is unsupported on this platform"""
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
|
||||
if (
|
||||
isinstance(params, SamplingParams)
|
||||
and params.sampling_type == SamplingType.RANDOM_SEED
|
||||
):
|
||||
raise ValueError("Torch XLA does not support per-request seed.")
|
||||
|
||||
@classmethod
|
||||
@torch.compile(backend="openxla")
|
||||
def insert_blocks_to_device(
|
||||
cls,
|
||||
src_cache: torch.Tensor,
|
||||
dst_cache: torch.Tensor,
|
||||
src_block_indices: torch.Tensor,
|
||||
dst_block_indices: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
|
||||
dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device)
|
||||
|
||||
@classmethod
|
||||
@torch.compile(backend="openxla")
|
||||
def swap_out_blocks_to_host(
|
||||
cls,
|
||||
src_cache: torch.Tensor,
|
||||
dst_cache: torch.Tensor,
|
||||
src_block_indices: torch.Tensor,
|
||||
dst_block_indices: torch.Tensor,
|
||||
) -> None:
|
||||
"""tpu blocks to cpu blocks"""
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
|
||||
dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
|
||||
|
||||
@classmethod
|
||||
def use_sync_weight_loader(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def check_max_model_len(cls, max_model_len: int) -> int:
|
||||
"""
|
||||
Check max_model_len for the current platform.
|
||||
"""
|
||||
logger.warning(
|
||||
"--max-model-len is not specified, "
|
||||
"it's currently using model's default length %d, "
|
||||
"which might be too large."
|
||||
"Please input with --max-model-len based on your "
|
||||
"request input length and output length, to avoid "
|
||||
"unnecessary degradation.",
|
||||
max_model_len,
|
||||
)
|
||||
return max_model_len
|
||||
|
||||
|
||||
try:
|
||||
from tpu_inference.platforms import (
|
||||
@ -291,5 +14,7 @@ try:
|
||||
TpuPlatform = TpuInferencePlatform # type: ignore
|
||||
USE_TPU_INFERENCE = True
|
||||
except ImportError:
|
||||
logger.info("tpu_inference not found, using vLLM's TpuPlatform")
|
||||
logger.error(
|
||||
"tpu_inference not found, please install tpu_inference to run vllm on TPU"
|
||||
)
|
||||
pass
|
||||
|
||||
@ -186,20 +186,6 @@ class UsageMessage:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _report_torch_xla_usage(self) -> bool:
|
||||
try:
|
||||
import torch_xla
|
||||
|
||||
self.gpu_count = torch_xla.runtime.world_size()
|
||||
self.gpu_type = torch_xla.tpu.get_tpu_type()
|
||||
self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[
|
||||
"bytes_limit"
|
||||
]
|
||||
self.cuda_runtime = "torch_xla"
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _report_usage_once(
|
||||
self,
|
||||
model_architecture: str,
|
||||
@ -217,9 +203,7 @@ class UsageMessage:
|
||||
if current_platform.is_cuda():
|
||||
self.cuda_runtime = torch.version.cuda
|
||||
if current_platform.is_tpu(): # noqa: SIM102
|
||||
if (not self._report_tpu_inference_usage()) and (
|
||||
not self._report_torch_xla_usage()
|
||||
):
|
||||
if not self._report_tpu_inference_usage():
|
||||
logger.exception("Failed to collect TPU information")
|
||||
self.provider = _detect_cloud_provider()
|
||||
self.architecture = platform.machine()
|
||||
|
||||
@ -1,436 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv, next_power_of_2
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# TPU requires the head size to be a multiple of 128.
|
||||
TPU_HEAD_SIZE_ALIGNMENT = 128
|
||||
|
||||
# Note: TPU can fp8 as storage dtype but doesn't support converting from uint8
|
||||
# from to fp32 directly. That's why it has a dtype mapping different from GPU
|
||||
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float": torch.float,
|
||||
"fp8": torch.float8_e4m3fn,
|
||||
"fp8_e4m3": torch.float8_e4m3fn,
|
||||
"fp8_e5m2": torch.float8_e5m2,
|
||||
"int8": torch.int8,
|
||||
"uint8": torch.uint8,
|
||||
}
|
||||
|
||||
try:
|
||||
import tpu_inference # noqa: F401
|
||||
except ImportError:
|
||||
# Lazy import torch_xla
|
||||
import torch_xla.core.xla_builder as xb
|
||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||
from torch.library import impl
|
||||
from torch_xla._internal.jax_workarounds import requires_jax
|
||||
from torch_xla.experimental.custom_kernel import XLA_LIB
|
||||
|
||||
@requires_jax
|
||||
def kv_cache_update_op_impl(
|
||||
kv: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_update_slices: torch.Tensor,
|
||||
page_size: int,
|
||||
num_slices_per_block: int,
|
||||
):
|
||||
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
|
||||
|
||||
new_kv_cache = xb.call_jax(
|
||||
kv_cache_update,
|
||||
(kv, slot_mapping, kv_cache, num_kv_update_slices),
|
||||
{"page_size": page_size, "num_slices_per_block": num_slices_per_block},
|
||||
)
|
||||
return new_kv_cache
|
||||
|
||||
XLA_LIB.define(
|
||||
"kv_cache_update_op(Tensor kv, Tensor slot_mapping,"
|
||||
"Tensor kv_cache, Tensor num_kv_update_slices, int page_size,"
|
||||
"int num_slices_per_block)"
|
||||
"-> Tensor",
|
||||
)
|
||||
|
||||
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
|
||||
def kv_cache_update_op_xla(
|
||||
kv: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_update_slices: torch.Tensor,
|
||||
page_size: int,
|
||||
num_slices_per_block: int,
|
||||
) -> torch.Tensor:
|
||||
new_kv_cache = kv_cache_update_op_impl(
|
||||
kv,
|
||||
slot_mapping,
|
||||
kv_cache,
|
||||
num_kv_update_slices,
|
||||
page_size,
|
||||
num_slices_per_block,
|
||||
)
|
||||
return new_kv_cache
|
||||
|
||||
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
|
||||
def kv_cache_update_op_non_xla(
|
||||
kv: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_update_slices: torch.Tensor,
|
||||
page_size: int,
|
||||
num_slices_per_block: int,
|
||||
) -> torch.Tensor:
|
||||
return kv_cache
|
||||
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "PALLAS"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
|
||||
return PallasAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
padded_head_size = (
|
||||
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
)
|
||||
return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
raise RuntimeError("swap_blocks is not used for the TPU backend.")
|
||||
|
||||
# In recent TPU generations, up to v6e, the SMEM size is 1MB. The
|
||||
# block_tables within the PallasMetadata constitute almost the entire SMEM
|
||||
# requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
|
||||
# we simply make sure that the size is smaller than half of SMEM capacity.
|
||||
@staticmethod
|
||||
def get_min_page_size(vllm_config: VllmConfig) -> int:
|
||||
max_num_page_per_req = (
|
||||
1024 * 1024 // 2 // vllm_config.scheduler_config.max_num_seqs // 4
|
||||
)
|
||||
min_page_size = cdiv(
|
||||
vllm_config.model_config.max_model_len, max_num_page_per_req
|
||||
)
|
||||
min_page_size = 1 << (min_page_size - 1).bit_length()
|
||||
return min_page_size
|
||||
|
||||
@staticmethod
|
||||
def get_max_num_seqs(model_len: int, page_size: int) -> int:
|
||||
num_page_per_req = cdiv(model_len, page_size)
|
||||
return 1024 * 1024 // 2 // num_page_per_req // 4
|
||||
|
||||
# TPU has limited SREGs (scalar registers), if page_size is too small, we
|
||||
# can spill SREGs easily which leads to bad performance. The strategy we
|
||||
# apply here is trying to split max-model-len to 16 pages which make the
|
||||
# spill less likely. Meanwhile we make sure the page size is in [16, 256].
|
||||
@staticmethod
|
||||
def get_page_size(vllm_config: VllmConfig) -> int:
|
||||
# TODO: This is a temporary fix for vmem OOM.
|
||||
# For long model length, we use 16 page-size to avoid too much
|
||||
# VMEM spill. A more robust solution should be implemented to
|
||||
# handle VREG spills.
|
||||
if vllm_config.model_config.max_model_len > 8192:
|
||||
return 16
|
||||
page_size = next_power_of_2(vllm_config.model_config.max_model_len) // 16
|
||||
if page_size <= 16:
|
||||
return 16
|
||||
if page_size >= 256:
|
||||
return 256
|
||||
return page_size
|
||||
|
||||
|
||||
@dataclass
|
||||
class PallasMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Used in the PallasAttentionBackendImpl
|
||||
slot_mapping: torch.Tensor
|
||||
block_tables: torch.Tensor
|
||||
context_lens: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
num_seqs: torch.Tensor
|
||||
num_kv_update_slices: torch.Tensor
|
||||
num_slices_per_kv_cache_update_block: int
|
||||
|
||||
|
||||
class PallasAttentionBackendImpl(AttentionImpl):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: int | None = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError("Alibi slopes is not supported.")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl"
|
||||
)
|
||||
|
||||
self.kv_cache_quantized_dtype = None
|
||||
if kv_cache_dtype != "auto":
|
||||
self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
|
||||
kv_cache_dtype.lower().strip()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: PallasMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Pallas attention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache: shape =
|
||||
[num_blocks, block_size, num_kv_heads * 2, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for PallasAttentionBackendImpl"
|
||||
)
|
||||
|
||||
# For determine_available_memory case.
|
||||
if kv_cache.numel() == 0:
|
||||
if output is None:
|
||||
output = torch.ones_like(query)
|
||||
return output
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
query = query.view(num_tokens, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
|
||||
padded_head_size = (
|
||||
cdiv(self.head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
)
|
||||
query = torch.nn.functional.pad(
|
||||
query, (0, padded_head_size - self.head_size), value=0.0
|
||||
)
|
||||
key = torch.nn.functional.pad(
|
||||
key, (0, padded_head_size - self.head_size), value=0.0
|
||||
)
|
||||
value = torch.nn.functional.pad(
|
||||
value, (0, padded_head_size - self.head_size), value=0.0
|
||||
)
|
||||
|
||||
if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
|
||||
# Write input keys and values to the KV cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
write_to_kv_cache(
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
slot_mapping,
|
||||
attn_metadata.num_slices_per_kv_cache_update_block,
|
||||
attn_metadata.num_kv_update_slices,
|
||||
self.kv_cache_quantized_dtype,
|
||||
layer._k_scale_float,
|
||||
layer._v_scale_float,
|
||||
)
|
||||
|
||||
if self.kv_cache_quantized_dtype is not None and (
|
||||
layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0
|
||||
):
|
||||
raise ValueError("k_scale_float and v_scale_float must be non-zero")
|
||||
output = torch.ops.xla.ragged_paged_attention(
|
||||
query,
|
||||
kv_cache,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.query_start_loc,
|
||||
attn_metadata.num_seqs,
|
||||
# By default, the system utilizes optimized block size and
|
||||
# vmem_limit_bytes parameters from the kernel repository. However,
|
||||
# these can be manually adjusted for debugging if necessary.
|
||||
num_kv_pages_per_block=None,
|
||||
num_queries_per_block=None,
|
||||
vmem_limit_bytes=None,
|
||||
use_kernel=True,
|
||||
sm_scale=self.scale,
|
||||
sliding_window=self.sliding_window,
|
||||
soft_cap=self.logits_soft_cap,
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
)
|
||||
|
||||
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
|
||||
output = output[:, :, : self.head_size]
|
||||
|
||||
return output.reshape(num_tokens, hidden_size)
|
||||
|
||||
|
||||
def write_to_kv_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
num_slices_per_kv_cache_update_block: int,
|
||||
num_kv_update_slices: torch.Tensor,
|
||||
kv_cache_quantized_dtype: torch.dtype | None = None,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
) -> None:
|
||||
"""Write the key and values to the KV cache.
|
||||
|
||||
Args:
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size]
|
||||
num_slices_per_kv_cache_update_block: int
|
||||
"""
|
||||
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape
|
||||
head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
|
||||
if kv_cache_quantized_dtype is not None:
|
||||
dtype_info = torch.finfo(kv_cache_quantized_dtype)
|
||||
key = key.to(torch.float32) / k_scale
|
||||
# NOTE: clamp is added here to avoid out of range of quantized dtype
|
||||
key = torch.clamp(key, dtype_info.min, dtype_info.max)
|
||||
key = key.to(kv_cache_quantized_dtype)
|
||||
value = value.to(torch.float32) / v_scale
|
||||
value = torch.clamp(value, dtype_info.min, dtype_info.max)
|
||||
value = value.to(kv_cache_quantized_dtype)
|
||||
|
||||
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size)
|
||||
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
|
||||
|
||||
kv_cache = kv_cache.flatten(0, 1)
|
||||
new_kv_cache = torch.ops.xla.kv_cache_update_op(
|
||||
kv,
|
||||
slot_mapping,
|
||||
kv_cache,
|
||||
num_kv_update_slices,
|
||||
page_size,
|
||||
num_slices_per_kv_cache_update_block,
|
||||
)
|
||||
# NOTE: the in-place copy will be optimized away by XLA compiler.
|
||||
kv_cache.copy_(new_kv_cache)
|
||||
|
||||
|
||||
# We can move this function to a common utils file if it's also useful for other
|
||||
# hardware.
|
||||
def dtype_bits(dtype: torch.dtype):
|
||||
if dtype.is_floating_point:
|
||||
try:
|
||||
return torch.finfo(dtype).bits
|
||||
except TypeError:
|
||||
pass
|
||||
elif dtype.is_complex:
|
||||
if dtype is torch.complex32:
|
||||
return 32
|
||||
elif dtype is torch.complex64:
|
||||
return 64
|
||||
elif dtype is torch.complex128:
|
||||
return 128
|
||||
else:
|
||||
try:
|
||||
return torch.iinfo(dtype).bits
|
||||
# torch.iinfo cannot support int4, int2, bits8...
|
||||
except TypeError:
|
||||
pass
|
||||
str_dtype = str(dtype)
|
||||
# support torch.int4, torch.int5, torch.uint5...
|
||||
if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"):
|
||||
return int(str_dtype[-1])
|
||||
raise TypeError(f"Getting the bit width of {dtype} is not supported")
|
||||
|
||||
|
||||
def get_dtype_packing(dtype):
|
||||
bits = dtype_bits(dtype)
|
||||
if 32 % bits != 0:
|
||||
raise ValueError(
|
||||
f"The bit width must be divisible by 32, but got bits={bits}, "
|
||||
"dtype={dtype}"
|
||||
)
|
||||
return 32 // bits
|
||||
|
||||
|
||||
def get_page_size_bytes(
|
||||
block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype
|
||||
) -> int:
|
||||
"""Returns the size in bytes of one page of the KV cache."""
|
||||
padded_head_size = (
|
||||
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
)
|
||||
num_combined_kv_heads = num_kv_heads * 2
|
||||
|
||||
# NOTE: for the implicit padding in XLA
|
||||
packing = get_dtype_packing(kv_cache_dtype)
|
||||
num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing
|
||||
|
||||
kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
|
||||
return (
|
||||
block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bits // 8
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@ -2,350 +2,16 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A TPU worker class."""
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypeVar
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import (
|
||||
ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
)
|
||||
from vllm.distributed.kv_transfer import (
|
||||
ensure_kv_transfer_initialized,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.tpu import USE_TPU_INFERENCE
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R")
|
||||
|
||||
if not USE_TPU_INFERENCE:
|
||||
logger.info("tpu_inference not found, using vLLM's TPUWorker.")
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.profiler as xp
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
|
||||
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
|
||||
|
||||
|
||||
class TPUWorker:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
self.is_driver_worker = is_driver_worker
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.use_spmd = envs.VLLM_XLA_USE_SPMD
|
||||
self.original_parallel_config = None
|
||||
if self.use_spmd:
|
||||
# Under SPMD mode, distributed env is initialized as if there is
|
||||
# only one worker/device.
|
||||
self.original_parallel_config = self.parallel_config
|
||||
self.parallel_config.tensor_parallel_size = 1
|
||||
self.parallel_config.pipeline_parallel_size = 1
|
||||
self.parallel_config.world_size = 1
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
|
||||
if self.cache_config.cache_dtype == "auto":
|
||||
self.cache_dtype = self.model_config.dtype
|
||||
else:
|
||||
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype]
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils.import_utils import init_cached_hf_modules
|
||||
|
||||
init_cached_hf_modules()
|
||||
|
||||
# Delay profiler initialization to the start of the profiling.
|
||||
# This is because in vLLM V1, MP runtime is initialized before the
|
||||
# TPU Worker is initialized. The profiler server needs to start after
|
||||
# MP runtime is initialized.
|
||||
self.profiler = None
|
||||
self.profile_dir = None
|
||||
if vllm_config.profiler_config.profiler == "torch" and self.rank < 1:
|
||||
# For TPU, we can only have 1 active profiler session for 1 profiler
|
||||
# server. So we only profile on rank0.
|
||||
self.profile_dir = vllm_config.profiler_config.torch_profiler_dir
|
||||
logger.info(
|
||||
"Profiling enabled. Traces will be saved to: %s", self.profile_dir
|
||||
)
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
def init_device(self):
|
||||
os.environ["PJRT_DEVICE"] = "TPU"
|
||||
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
|
||||
# ring, the xla tpu compiler flag
|
||||
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
|
||||
# fix this. It will be removed after the bug in XLA compiler is fixed.
|
||||
os.environ["LIBTPU_INIT_ARGS"] = (
|
||||
os.environ.get("LIBTPU_INIT_ARGS", "")
|
||||
+ " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
|
||||
" --xla_jf_conv_input_fusion=False"
|
||||
)
|
||||
# --xla_jf_conv_input_fusion=False is used to improve the perf of
|
||||
# quantized matmul.
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_default_dtype(self.model_config.dtype)
|
||||
|
||||
# Initialize the distributed environment.
|
||||
self._init_tpu_worker_distributed_environment(
|
||||
self.vllm_config, self.rank, self.distributed_init_method, self.local_rank
|
||||
)
|
||||
|
||||
# Device initialization should happen after initializing
|
||||
# the distributed runtime.
|
||||
self.device = xm.xla_device()
|
||||
self.device_config.device = self.device
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
xm.set_rng_state(self.model_config.seed, self.device)
|
||||
|
||||
# Increase the cache size limit, which is the maximum number of
|
||||
# dynamo graphs that can be compiled.
|
||||
# TODO (NickLucche) On gsm we compile 80+ graphs.
|
||||
# Re-evaluate limit, with MM we may get close to this limit.
|
||||
torch._dynamo.config.cache_size_limit = 128
|
||||
# Use persistent cache to avoid XLA recompilation.
|
||||
# NOTE(woosuk): Set per-rank cache path since different ranks
|
||||
# can have slightly different XLA graphs.
|
||||
world_size = self.parallel_config.world_size
|
||||
rank = xr.global_ordinal()
|
||||
# The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
|
||||
# Consequently, changes in optimization flags, which affect compilation
|
||||
# results, don't change the cache key. This can result in the wrong
|
||||
# compilation being used. To prevent this, disabling the XLA compilation
|
||||
# cache during development is recommended.We can disable it by
|
||||
# `export VLLM_XLA_CACHE_PATH=`
|
||||
if envs.VLLM_XLA_CACHE_PATH:
|
||||
per_rank_path = os.path.join(
|
||||
envs.VLLM_XLA_CACHE_PATH, f"tp{world_size}_rank{rank}"
|
||||
)
|
||||
xr.initialize_cache(per_rank_path, readonly=False)
|
||||
|
||||
# Init ModelRunner here, so that we have access to self.device.
|
||||
self.model_runner = TPUModelRunner(
|
||||
self.vllm_config, self.device, self.original_parallel_config
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
# If usage stat is enabled, collect relevant info.
|
||||
report_usage_stats(self.vllm_config)
|
||||
|
||||
def determine_available_memory(self) -> int:
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
kv_cache_spec = self.model_runner.get_kv_cache_spec()
|
||||
for layer_name, layer_spec in kv_cache_spec.items():
|
||||
if isinstance(layer_spec, AttentionSpec):
|
||||
dtype = layer_spec.dtype
|
||||
|
||||
# Use an empty tensor instead of `None` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value `None`.
|
||||
tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
|
||||
kv_caches[layer_name] = tpu_kv_cache
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported KV cache spec '{type(layer_spec)}'"
|
||||
)
|
||||
|
||||
runner_kv_caches: list[torch.Tensor] = []
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
runner_kv_caches,
|
||||
)
|
||||
|
||||
# `max_num_tokens >= max_num_batched_tokens` due to padding.
|
||||
with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
|
||||
self.model_runner.profile_run(self.model_runner.max_num_tokens)
|
||||
|
||||
# Synchronize before measuring the memory usage.
|
||||
xm.wait_device_ops()
|
||||
|
||||
# During the profiling run, the model runs without KV cache. After
|
||||
# the profiling run, the model always runs with KV cache. Here we clear
|
||||
# the dynamo cache and cached bytecode to ensure the model always has
|
||||
# one compiled bytecode. Having one FX graph/cached bytecode per
|
||||
# compiled model is required for `support_torch_compile` decorator to
|
||||
# skip dynamo guard.
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
self.model_runner.reset_dynamo_cache()
|
||||
|
||||
# Get the maximum amount of memory used by the model weights and
|
||||
# intermediate activations.
|
||||
if self.use_spmd:
|
||||
# This is a workaround for the TPU SPMD mode. The get_memory_info
|
||||
# API doesn't work with SPMD mode in PyTorch/XLA.
|
||||
# TODO: use xm.get_memory_info for SPMD once it's supported in
|
||||
# PyTorch/XLA.
|
||||
import tpu_info
|
||||
|
||||
chip_type, _ = tpu_info.device.get_local_chips()
|
||||
device_usage = tpu_info.metrics.get_chip_usage(chip_type)
|
||||
total_memory_size = device_usage[0].total_memory
|
||||
current_mem = device_usage[0].memory_usage
|
||||
else:
|
||||
m = xm.get_memory_info(self.device)
|
||||
total_memory_size = m["bytes_limit"]
|
||||
current_mem = m["bytes_used"]
|
||||
# Ideally we would use profiled = m["peak_bytes_used"] to
|
||||
# get weights + activations. But there is memory used during
|
||||
# compilation / weight loading that impacts the peak and
|
||||
# there is no way to reset peak memory in XLA, So we
|
||||
# use the heuristic of 2% of weights.
|
||||
profiled = current_mem * 1.02
|
||||
|
||||
# Calculate the TPU KV cache size based on profiling.
|
||||
usable_memory_size = int(
|
||||
total_memory_size * self.cache_config.gpu_memory_utilization
|
||||
)
|
||||
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
|
||||
head_size = self.model_config.get_head_size()
|
||||
if head_size > 0:
|
||||
padded_head_size = (
|
||||
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
)
|
||||
if padded_head_size != head_size:
|
||||
logger.warning_once("head size is padded to %d", padded_head_size)
|
||||
# We adjust the usable memory size for the KV cache to prevent OOM
|
||||
# errors, even after padding the head_size.
|
||||
tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
|
||||
return int(tpu_kv_cache_bytes)
|
||||
|
||||
def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput:
|
||||
return self.model_runner.sample_tokens(grammar_output)
|
||||
|
||||
def execute_model(
|
||||
self, scheduler_output: "SchedulerOutput"
|
||||
) -> ModelRunnerOutput | None:
|
||||
return self.model_runner.execute_model(scheduler_output)
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
if self.rank < 1:
|
||||
if self.profile_dir is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
if is_start:
|
||||
if self.profiler is None:
|
||||
self.profiler = xp.start_server(9012)
|
||||
xp.start_trace(self.profile_dir)
|
||||
else:
|
||||
xp.stop_trace()
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self.model_runner.add_lora(lora_request)
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model_runner.load_model()
|
||||
|
||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
||||
self.model_runner.update_config(overrides)
|
||||
|
||||
def reload_weights(self) -> None:
|
||||
self.model_runner.reload_weights()
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
if not self.model_config.enforce_eager:
|
||||
self.model_runner.capture_model()
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
self.model_runner.reset_mm_cache()
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
return self.model_runner.get_supported_tasks()
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
def check_health(self) -> None:
|
||||
# worker will always be healthy as long as it's running.
|
||||
return
|
||||
|
||||
def _init_tpu_worker_distributed_environment(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
rank: int,
|
||||
distributed_init_method: str | None = None,
|
||||
local_rank: int = -1,
|
||||
) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
if self.use_spmd:
|
||||
xr.use_spmd()
|
||||
# NOTE(woosuk): This is just to initialize the TP group and broadcast
|
||||
# the input objects on CPU. The all-reduce and all-gather ops on TPU
|
||||
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
|
||||
# own context.
|
||||
parallel_config = vllm_config.parallel_config
|
||||
init_distributed_environment(
|
||||
world_size=parallel_config.world_size,
|
||||
rank=rank,
|
||||
local_rank=local_rank,
|
||||
distributed_init_method=distributed_init_method or "env://",
|
||||
backend=current_platform.dist_backend,
|
||||
)
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
|
||||
)
|
||||
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.model_runner.ensure_kv_transfer_shutdown()
|
||||
|
||||
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
|
||||
"""Apply a function on the model inside this worker."""
|
||||
return fn(self.get_model())
|
||||
|
||||
|
||||
# TODO(weiyulin) Remove this file after adding an official way to use hardware plugin
|
||||
if USE_TPU_INFERENCE:
|
||||
from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user