mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 15:07:53 +08:00
Merge 9aaed80cc85f65438f859bdd19fe90d6b712be5c into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
737b3079ad
@ -92,7 +92,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
|
|||||||
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
|
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
|
||||||
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
|
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
|
||||||
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
|
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
|
||||||
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
|
|
||||||
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
|
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
|
||||||
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
|
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
|
||||||
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
|
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
|
||||||
|
|||||||
@ -1,139 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
import pytest
|
|
||||||
from torch_xla._internal import tpu
|
|
||||||
|
|
||||||
import vllm
|
|
||||||
from vllm.lora.request import LoRARequest
|
|
||||||
|
|
||||||
# This file contains tests to ensure that LoRA works correctly on the TPU
|
|
||||||
# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct
|
|
||||||
# for this. The adapters are:
|
|
||||||
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
|
|
||||||
# from 1 to 4.
|
|
||||||
|
|
||||||
# These adapters are trained using a standard huggingface peft training script,
|
|
||||||
# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run
|
|
||||||
# 100 training iterations with a training batch size of 100.
|
|
||||||
|
|
||||||
|
|
||||||
def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
|
|
||||||
return vllm.LLM(
|
|
||||||
model="Qwen/Qwen2.5-3B-Instruct",
|
|
||||||
max_model_len=256,
|
|
||||||
max_num_seqs=8,
|
|
||||||
tensor_parallel_size=tp,
|
|
||||||
enable_lora=True,
|
|
||||||
max_loras=num_loras,
|
|
||||||
max_lora_rank=8,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
TPU_TENSOR_PARALLEL_SIZES = (
|
|
||||||
[1, tpu.num_available_chips()] if tpu.num_available_chips() > 1 else [1]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
|
|
||||||
def test_single_lora(tp: int):
|
|
||||||
"""
|
|
||||||
This test ensures we can run a single LoRA adapter on the TPU backend.
|
|
||||||
We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
|
|
||||||
will force Qwen2.5-3B-Instruct to claim 1+1=1.
|
|
||||||
"""
|
|
||||||
|
|
||||||
llm = setup_vllm(1, tp)
|
|
||||||
|
|
||||||
prompt = "What is 1+1? \n"
|
|
||||||
|
|
||||||
lora_request = LoRARequest(
|
|
||||||
"lora_adapter_1",
|
|
||||||
1,
|
|
||||||
"Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter",
|
|
||||||
)
|
|
||||||
output = (
|
|
||||||
llm.generate(
|
|
||||||
prompt,
|
|
||||||
sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
|
|
||||||
lora_request=lora_request,
|
|
||||||
)[0]
|
|
||||||
.outputs[0]
|
|
||||||
.text
|
|
||||||
)
|
|
||||||
|
|
||||||
answer = output.strip()[0]
|
|
||||||
|
|
||||||
assert answer.isdigit()
|
|
||||||
assert int(answer) == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
|
|
||||||
def test_lora_hotswapping(tp: int):
|
|
||||||
"""
|
|
||||||
This test ensures we can run multiple LoRA adapters on the TPU backend, even
|
|
||||||
if we only have space to store 1.
|
|
||||||
|
|
||||||
We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
|
|
||||||
will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
|
|
||||||
"""
|
|
||||||
|
|
||||||
lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
|
|
||||||
lora_requests = [
|
|
||||||
LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
|
|
||||||
for i in range(1, 5)
|
|
||||||
]
|
|
||||||
|
|
||||||
llm = setup_vllm(1, tp)
|
|
||||||
|
|
||||||
prompt = "What is 1+1? \n"
|
|
||||||
|
|
||||||
for i, req in enumerate(lora_requests):
|
|
||||||
output = (
|
|
||||||
llm.generate(
|
|
||||||
prompt,
|
|
||||||
sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
|
|
||||||
lora_request=req,
|
|
||||||
)[0]
|
|
||||||
.outputs[0]
|
|
||||||
.text
|
|
||||||
)
|
|
||||||
answer = output.strip()[0]
|
|
||||||
|
|
||||||
assert answer.isdigit()
|
|
||||||
assert int(answer) == i + 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
|
|
||||||
def test_multi_lora(tp: int):
|
|
||||||
"""
|
|
||||||
This test ensures we can run multiple LoRA adapters on the TPU backend, when
|
|
||||||
we have enough space to store all of them.
|
|
||||||
|
|
||||||
We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
|
|
||||||
will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
|
|
||||||
"""
|
|
||||||
lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
|
|
||||||
lora_requests = [
|
|
||||||
LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
|
|
||||||
for i in range(1, 5)
|
|
||||||
]
|
|
||||||
|
|
||||||
llm = setup_vllm(4, tp)
|
|
||||||
|
|
||||||
prompt = "What is 1+1? \n"
|
|
||||||
|
|
||||||
for i, req in enumerate(lora_requests):
|
|
||||||
output = (
|
|
||||||
llm.generate(
|
|
||||||
prompt,
|
|
||||||
sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
|
|
||||||
lora_request=req,
|
|
||||||
)[0]
|
|
||||||
.outputs[0]
|
|
||||||
.text
|
|
||||||
)
|
|
||||||
|
|
||||||
answer = output.strip()[0]
|
|
||||||
|
|
||||||
assert answer.isdigit()
|
|
||||||
assert int(output.strip()[0]) == i + 1
|
|
||||||
@ -1,86 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import depyf
|
|
||||||
|
|
||||||
|
|
||||||
def test_tpu_compilation():
|
|
||||||
temp_dir = tempfile.mkdtemp()
|
|
||||||
with depyf.prepare_debug(temp_dir):
|
|
||||||
from vllm import LLM, SamplingParams
|
|
||||||
|
|
||||||
prompts = [
|
|
||||||
"A robot may not injure a human being",
|
|
||||||
"It is only with the heart that one can see rightly;",
|
|
||||||
"The greatest glory in living lies not in never falling,",
|
|
||||||
]
|
|
||||||
answers = [
|
|
||||||
" or, through inaction",
|
|
||||||
" what is essential ",
|
|
||||||
" but in rising ",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
|
|
||||||
N = 1
|
|
||||||
sampling_params = SamplingParams(temperature=0.7, top_p=1.0, n=N, max_tokens=16)
|
|
||||||
|
|
||||||
llm = LLM(
|
|
||||||
model="Qwen/Qwen2-1.5B-Instruct",
|
|
||||||
max_num_batched_tokens=256,
|
|
||||||
max_model_len=256,
|
|
||||||
max_num_seqs=32,
|
|
||||||
enforce_eager=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
|
||||||
for output, answer in zip(outputs, answers):
|
|
||||||
prompt = output.prompt
|
|
||||||
generated_text = output.outputs[0].text
|
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
|
||||||
assert generated_text.startswith(answer)
|
|
||||||
|
|
||||||
compiled_codes = sorted(
|
|
||||||
glob.glob(os.path.join(temp_dir, "__transformed_code*for_forward.py"))
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, compiled_code in enumerate(compiled_codes):
|
|
||||||
print("{} file: {}".format(i + 1, compiled_code))
|
|
||||||
|
|
||||||
# We should only trigger Dynamo compilation 2 times:
|
|
||||||
# 1. Forward pass without kv_caches
|
|
||||||
# 2. Forward pass with kv_caches
|
|
||||||
# Check we have 2 compiled codes
|
|
||||||
assert len(compiled_codes) == 2
|
|
||||||
|
|
||||||
kv_cache_prefix = "kv_cache"
|
|
||||||
attn_prefix = "ragged_paged_attention"
|
|
||||||
|
|
||||||
def extract_compiled_index(s):
|
|
||||||
parts = s.replace(".", "_").split("_")
|
|
||||||
numbers = [int(part) for part in parts if part.isdigit()]
|
|
||||||
return numbers[0]
|
|
||||||
|
|
||||||
# Check all the compilations are as expected. The dump files include the
|
|
||||||
# captured graph for the forward function of the nn.Module.
|
|
||||||
compiled_fns = sorted(
|
|
||||||
glob.glob(os.path.join(temp_dir, "__compiled_fn*Forward_graph*.py")),
|
|
||||||
key=lambda s: extract_compiled_index(s),
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, compiled_fn in enumerate(compiled_fns):
|
|
||||||
print("{} file: {}".format(i + 1, compiled_fn))
|
|
||||||
|
|
||||||
# The first compilation should not have any kv_caches
|
|
||||||
with open(compiled_fns[0]) as f:
|
|
||||||
content = f.read()
|
|
||||||
assert kv_cache_prefix not in content
|
|
||||||
|
|
||||||
# The second compilation should have kv_caches and the
|
|
||||||
# ragged_paged_attention
|
|
||||||
with open(compiled_fns[1]) as f:
|
|
||||||
content = f.read()
|
|
||||||
assert kv_cache_prefix in content and attn_prefix in content
|
|
||||||
@ -1,34 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from vllm.config import CompilationMode
|
|
||||||
|
|
||||||
from ..utils import compare_two_settings
|
|
||||||
|
|
||||||
# --enforce-eager on TPU causes graph compilation
|
|
||||||
# this times out default Health Check in the MQLLMEngine,
|
|
||||||
# so we set the timeout here to 30s
|
|
||||||
|
|
||||||
|
|
||||||
def test_custom_dispatcher(monkeypatch: pytest.MonkeyPatch):
|
|
||||||
with monkeypatch.context() as m:
|
|
||||||
m.setenv("VLLM_RPC_TIMEOUT", "30000")
|
|
||||||
compare_two_settings(
|
|
||||||
"Qwen/Qwen2.5-1.5B-Instruct",
|
|
||||||
arg1=[
|
|
||||||
"--max-model-len=256",
|
|
||||||
"--max-num-seqs=32",
|
|
||||||
"--enforce-eager",
|
|
||||||
f"-O{CompilationMode.DYNAMO_TRACE_ONCE}",
|
|
||||||
],
|
|
||||||
arg2=[
|
|
||||||
"--max-model-len=256",
|
|
||||||
"--max-num-seqs=32",
|
|
||||||
"--enforce-eager",
|
|
||||||
f"-O{CompilationMode.STOCK_TORCH_COMPILE}",
|
|
||||||
],
|
|
||||||
env1={},
|
|
||||||
env2={},
|
|
||||||
)
|
|
||||||
@ -1,88 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
"""Tests for the Pallas MOE implementation.
|
|
||||||
|
|
||||||
Run `pytest tests/kernels/moe/test_moe_pallas.py`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch_xla
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe as pallas_moe
|
|
||||||
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
|
||||||
fused_moe as torch_moe,
|
|
||||||
)
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
if not current_platform.is_tpu():
|
|
||||||
pytest.skip("This test needs a TPU.", allow_module_level=True)
|
|
||||||
|
|
||||||
NUM_EXPERTS = [8, 64]
|
|
||||||
EP_SIZE = [1]
|
|
||||||
TOP_KS = [2, 6]
|
|
||||||
|
|
||||||
|
|
||||||
# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
|
|
||||||
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
|
|
||||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
|
||||||
@pytest.mark.parametrize("k", [128, 511, 1024])
|
|
||||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
|
||||||
@pytest.mark.parametrize("topk", TOP_KS)
|
|
||||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
|
||||||
def test_pallas_moe(
|
|
||||||
m: int,
|
|
||||||
n: int,
|
|
||||||
k: int,
|
|
||||||
e: int,
|
|
||||||
topk: int,
|
|
||||||
ep_size: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
):
|
|
||||||
import torch_xla.core.xla_model as xm
|
|
||||||
|
|
||||||
with torch.device(xm.xla_device()):
|
|
||||||
a = torch.randn((m, k), dtype=dtype) / 10
|
|
||||||
w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
|
|
||||||
w2 = torch.randn((e, k, n), dtype=dtype) / 10
|
|
||||||
|
|
||||||
score = torch.randn((m, e), dtype=dtype)
|
|
||||||
|
|
||||||
# TODO: Support ep
|
|
||||||
if ep_size > 1:
|
|
||||||
pytest.skip("No support for ep_size > 1 yet")
|
|
||||||
else:
|
|
||||||
e_map = None
|
|
||||||
|
|
||||||
# Run both implementations
|
|
||||||
torch_output = torch_moe(
|
|
||||||
hidden_states=a,
|
|
||||||
w1=w1,
|
|
||||||
w2=w2,
|
|
||||||
gating_output=score,
|
|
||||||
topk=topk,
|
|
||||||
global_num_experts=e,
|
|
||||||
expert_map=e_map,
|
|
||||||
renormalize=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
pallas_output = pallas_moe(
|
|
||||||
hidden_states=a,
|
|
||||||
w1=w1,
|
|
||||||
w2=w2,
|
|
||||||
gating_output=score,
|
|
||||||
topk=topk,
|
|
||||||
global_num_experts=e,
|
|
||||||
expert_map=e_map,
|
|
||||||
renormalize=False,
|
|
||||||
)
|
|
||||||
torch_xla.sync(wait=False)
|
|
||||||
|
|
||||||
# Compare outputs
|
|
||||||
torch.testing.assert_close(
|
|
||||||
pallas_output.cpu(),
|
|
||||||
torch_output.cpu(),
|
|
||||||
atol=2e-2,
|
|
||||||
rtol=0,
|
|
||||||
)
|
|
||||||
@ -1,52 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
import lm_eval
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
TASK = "gsm8k"
|
|
||||||
FILTER = "exact_match,strict-match"
|
|
||||||
RTOL = 0.03
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GSM8KAccuracyTestConfig:
|
|
||||||
model_name: str
|
|
||||||
expected_value: float
|
|
||||||
|
|
||||||
def get_model_args(self) -> str:
|
|
||||||
return f"pretrained={self.model_name},max_model_len=4096,max_num_seqs=32"
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE: Accuracy scores measured on GPUs.
|
|
||||||
ACCURACY_CONFIGS = [
|
|
||||||
GSM8KAccuracyTestConfig(
|
|
||||||
model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
|
|
||||||
expected_value=0.76,
|
|
||||||
), # no bias
|
|
||||||
# NOTE(rob): We cannot re-initialize vLLM in the same process for TPU,
|
|
||||||
# so only one of these tests can run in a single call to pytest. As
|
|
||||||
# a follow-up, move this into the LM-EVAL section of the CI.
|
|
||||||
# GSM8KAccuracyTestConfig(
|
|
||||||
# model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
|
|
||||||
# expected_value=0.66), # bias in QKV layers
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
|
|
||||||
def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
|
|
||||||
results = lm_eval.simple_evaluate(
|
|
||||||
model="vllm",
|
|
||||||
model_args=config.get_model_args(),
|
|
||||||
tasks="gsm8k",
|
|
||||||
batch_size="auto",
|
|
||||||
)
|
|
||||||
|
|
||||||
EXPECTED_VALUE = config.expected_value
|
|
||||||
measured_value = results["results"][TASK][FILTER]
|
|
||||||
assert (
|
|
||||||
measured_value - RTOL < EXPECTED_VALUE
|
|
||||||
and measured_value + RTOL > EXPECTED_VALUE
|
|
||||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
|
||||||
@ -1,177 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
"""A basic correctness check for TPUs
|
|
||||||
|
|
||||||
Run `pytest tests/v1/tpu/test_basic.py`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from torch_xla._internal import tpu
|
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from tests.conftest import VllmRunner
|
|
||||||
else:
|
|
||||||
VllmRunner = object
|
|
||||||
|
|
||||||
MODELS = [
|
|
||||||
"Qwen/Qwen2.5-1.5B-Instruct",
|
|
||||||
# TODO: Enable this model when fixed.
|
|
||||||
# "Qwen/Qwen1.5-MoE-A2.7B",
|
|
||||||
# TODO: Enable this models with v6e
|
|
||||||
# "Qwen/Qwen2-7B-Instruct",
|
|
||||||
# "meta-llama/Llama-3.1-8B",
|
|
||||||
]
|
|
||||||
|
|
||||||
TENSOR_PARALLEL_SIZES = [1]
|
|
||||||
MAX_NUM_REQS = [16, 1024]
|
|
||||||
|
|
||||||
# TODO: Enable when CI/CD will have a multi-tpu instance
|
|
||||||
# TENSOR_PARALLEL_SIZES = [1, 4]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not current_platform.is_tpu(), reason="This is a basic test for TPU only"
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
|
||||||
@pytest.mark.parametrize("max_tokens", [5])
|
|
||||||
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
|
|
||||||
@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
|
|
||||||
def test_basic(
|
|
||||||
vllm_runner: type[VllmRunner],
|
|
||||||
model: str,
|
|
||||||
max_tokens: int,
|
|
||||||
tensor_parallel_size: int,
|
|
||||||
max_num_seqs: int,
|
|
||||||
) -> None:
|
|
||||||
prompt = (
|
|
||||||
"The next numbers of the sequence "
|
|
||||||
+ ", ".join(str(i) for i in range(1024))
|
|
||||||
+ " are:"
|
|
||||||
)
|
|
||||||
example_prompts = [prompt]
|
|
||||||
|
|
||||||
with vllm_runner(
|
|
||||||
model,
|
|
||||||
# Note: max_num_batched_tokens == 1024 is needed here to
|
|
||||||
# actually test chunked prompt
|
|
||||||
max_num_batched_tokens=1024,
|
|
||||||
max_model_len=8192,
|
|
||||||
gpu_memory_utilization=0.7,
|
|
||||||
max_num_seqs=max_num_seqs,
|
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
|
||||||
) as vllm_model:
|
|
||||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
|
||||||
output = vllm_outputs[0][1]
|
|
||||||
|
|
||||||
assert "1024" in output or "0, 1" in output
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Temporarily disabled due to timeout")
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not current_platform.is_tpu(), reason="This is a basic test for TPU only"
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("max_tokens", [8])
|
|
||||||
@pytest.mark.parametrize("max_num_seqs", [16])
|
|
||||||
def test_phi3(
|
|
||||||
vllm_runner: type[VllmRunner],
|
|
||||||
max_tokens: int,
|
|
||||||
max_num_seqs: int,
|
|
||||||
) -> None:
|
|
||||||
prompts = [
|
|
||||||
"A robot may not injure a human being",
|
|
||||||
"It is only with the heart that one can see rightly;",
|
|
||||||
"The greatest glory in living lies not in never falling,",
|
|
||||||
]
|
|
||||||
answers = [
|
|
||||||
" or, by violating privacy",
|
|
||||||
" what is essential is love.",
|
|
||||||
" but in rising every time we fall.",
|
|
||||||
]
|
|
||||||
# test head dim = 96
|
|
||||||
model = "microsoft/Phi-3-mini-128k-instruct"
|
|
||||||
|
|
||||||
with vllm_runner(
|
|
||||||
model, max_num_batched_tokens=256, max_num_seqs=max_num_seqs
|
|
||||||
) as vllm_model:
|
|
||||||
vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
|
|
||||||
# vllm_outputs is a list of tuples whose first element is the token id
|
|
||||||
# and the second element is the output (including the prompt).
|
|
||||||
for output, answer in zip(vllm_outputs, answers):
|
|
||||||
generated_text = output[1]
|
|
||||||
assert answer in generated_text
|
|
||||||
|
|
||||||
|
|
||||||
TP_SIZE_8 = 8
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only")
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
tpu.num_available_chips() < TP_SIZE_8,
|
|
||||||
reason=f"This test requires {TP_SIZE_8} TPU chips.",
|
|
||||||
)
|
|
||||||
def test_gemma3_27b_with_text_input_and_tp(
|
|
||||||
vllm_runner: type[VllmRunner],
|
|
||||||
) -> None:
|
|
||||||
model = "google/gemma-3-27b-it"
|
|
||||||
max_tokens = 16
|
|
||||||
tensor_parallel_size = TP_SIZE_8
|
|
||||||
max_num_seqs = 4
|
|
||||||
prompts = [
|
|
||||||
"A robot may not injure a human being",
|
|
||||||
"It is only with the heart that one can see rightly;",
|
|
||||||
"The greatest glory in living lies not in never falling,",
|
|
||||||
]
|
|
||||||
answers = [
|
|
||||||
" or, through inaction, allow a human being to come to harm.",
|
|
||||||
" what is essential is invisible to the eye.",
|
|
||||||
" but in rising every time we fall.",
|
|
||||||
]
|
|
||||||
|
|
||||||
with vllm_runner(
|
|
||||||
model,
|
|
||||||
max_num_batched_tokens=256,
|
|
||||||
max_num_seqs=max_num_seqs,
|
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
|
||||||
) as vllm_model:
|
|
||||||
vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
|
|
||||||
# vllm_outputs is a list of tuples whose first element is the token id
|
|
||||||
# and the second element is the output (including the prompt).
|
|
||||||
for output, answer in zip(vllm_outputs, answers):
|
|
||||||
generated_text = output[1]
|
|
||||||
assert answer in generated_text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not current_platform.is_tpu(), reason="This is a basic test for TPU only"
|
|
||||||
)
|
|
||||||
def test_w8a8_quantization(
|
|
||||||
vllm_runner: type[VllmRunner],
|
|
||||||
) -> None:
|
|
||||||
model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
|
|
||||||
max_tokens = 5
|
|
||||||
tensor_parallel_size = 1
|
|
||||||
max_num_seqs = 4
|
|
||||||
|
|
||||||
prompt = (
|
|
||||||
"The next numbers of the sequence "
|
|
||||||
+ ", ".join(str(i) for i in range(1024))
|
|
||||||
+ " are:"
|
|
||||||
)
|
|
||||||
example_prompts = [prompt]
|
|
||||||
|
|
||||||
with vllm_runner(
|
|
||||||
model,
|
|
||||||
max_num_batched_tokens=64,
|
|
||||||
max_model_len=4096,
|
|
||||||
gpu_memory_utilization=0.7,
|
|
||||||
max_num_seqs=max_num_seqs,
|
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
|
||||||
) as vllm_model:
|
|
||||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
|
||||||
output = vllm_outputs[0][1]
|
|
||||||
|
|
||||||
assert "1024" in output or "0, 1" in output
|
|
||||||
@ -1,78 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch_xla
|
|
||||||
|
|
||||||
import vllm.v1.attention.backends.pallas # noqa: F401
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only")
|
|
||||||
@pytest.mark.parametrize("page_size", [32, 33])
|
|
||||||
@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
|
|
||||||
@pytest.mark.parametrize("head_dim", [128, 256])
|
|
||||||
@pytest.mark.parametrize("num_slices_per_block", [4, 8])
|
|
||||||
def test_kv_cache_update_kernel(
|
|
||||||
page_size: int, combined_kv_head_num: int, head_dim: int, num_slices_per_block: int
|
|
||||||
):
|
|
||||||
page_num = 1000
|
|
||||||
padded_num_tokens = 128
|
|
||||||
kv_cache_cpu = torch.zeros(
|
|
||||||
(page_num * page_size, combined_kv_head_num, head_dim),
|
|
||||||
dtype=torch.bfloat16,
|
|
||||||
device="cpu",
|
|
||||||
)
|
|
||||||
kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
|
|
||||||
new_kv_cpu = torch.randn(
|
|
||||||
(padded_num_tokens, combined_kv_head_num, head_dim),
|
|
||||||
dtype=torch.bfloat16,
|
|
||||||
device="cpu",
|
|
||||||
)
|
|
||||||
new_kv_xla = new_kv_cpu.to(torch_xla.device())
|
|
||||||
slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], dtype=np.int32)
|
|
||||||
num_kv_update_slices = len(slice_lens)
|
|
||||||
kv_cache_start_indices = np.array(
|
|
||||||
[
|
|
||||||
page_size * 2 - 7,
|
|
||||||
page_size * 2,
|
|
||||||
page_size * 3,
|
|
||||||
page_size * 4 + 6,
|
|
||||||
page_size * 5 + 7,
|
|
||||||
page_size * 6 + 8,
|
|
||||||
page_size * 15 + 3,
|
|
||||||
],
|
|
||||||
dtype=np.int32,
|
|
||||||
)
|
|
||||||
new_kv_cache_indices = np.concatenate(
|
|
||||||
[np.array([0], dtype=np.int32), np.cumsum(slice_lens[:-1])]
|
|
||||||
)
|
|
||||||
slot_mapping = np.stack(
|
|
||||||
[kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1
|
|
||||||
)
|
|
||||||
slot_mapping = np.transpose(slot_mapping)
|
|
||||||
slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu", dtype=torch.int32)
|
|
||||||
slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
|
|
||||||
num_kv_update_slices_xla = torch.tensor(
|
|
||||||
[num_kv_update_slices], device=torch_xla.device(), dtype=torch.int32
|
|
||||||
)
|
|
||||||
torch_xla.sync()
|
|
||||||
|
|
||||||
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
|
|
||||||
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
|
|
||||||
new_kv_xla,
|
|
||||||
slot_mapping_xla,
|
|
||||||
kv_cache_xla,
|
|
||||||
num_kv_update_slices_xla,
|
|
||||||
page_size,
|
|
||||||
num_slices_per_block,
|
|
||||||
)
|
|
||||||
kv_cache_xla.copy_(new_kv_cache_xla)
|
|
||||||
torch_xla.sync()
|
|
||||||
|
|
||||||
for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, slice_lens):
|
|
||||||
kv_cache_cpu[ci : ci + sl, :, :] = new_kv_cpu[ni : ni + sl, :, :]
|
|
||||||
|
|
||||||
assert torch.allclose(kv_cache_xla.cpu(), kv_cache_cpu, atol=1e-4, rtol=1e-4)
|
|
||||||
@ -1,94 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
"""
|
|
||||||
Test:
|
|
||||||
|
|
||||||
* Tests for MMEncoderAttention layer
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch_xla
|
|
||||||
import torch_xla.core
|
|
||||||
import torch_xla.core.xla_model
|
|
||||||
|
|
||||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
|
||||||
from vllm.attention.selector import _cached_get_attn_backend
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def clear_cache():
|
|
||||||
"""Clear lru cache to ensure each test case runs without caching."""
|
|
||||||
_cached_get_attn_backend.cache_clear()
|
|
||||||
|
|
||||||
|
|
||||||
def ref_attention(
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
scale: float,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Native implementation of scaled dot product attention without mask:
|
|
||||||
- query, key, value: [batch_size, seq_len, num_heads, head_size]
|
|
||||||
- attn_mask: [batch_size, seq_len, seq_len]
|
|
||||||
"""
|
|
||||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
|
||||||
attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
|
|
||||||
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
|
||||||
out = torch.matmul(attn_weights, value).transpose(1, 2)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
BATCH_SIZES = [1, 16]
|
|
||||||
SEQ_LENS = [1]
|
|
||||||
NUM_HEADS = [1, 16]
|
|
||||||
NUM_KV_HEADS = [1]
|
|
||||||
HEAD_SIZES = [64, 80]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
|
|
||||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
|
||||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
|
||||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
|
||||||
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
|
|
||||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
|
||||||
@pytest.mark.parametrize("device", [torch_xla.core.xla_model.xla_device()])
|
|
||||||
def test_mha_attn_forward(
|
|
||||||
batch_size: int,
|
|
||||||
seq_len: int,
|
|
||||||
num_heads: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
device: str,
|
|
||||||
):
|
|
||||||
current_platform.seed_everything(0)
|
|
||||||
# These are expected to be f32
|
|
||||||
q = torch.randn(batch_size, seq_len, num_heads * head_size, device=device)
|
|
||||||
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
|
|
||||||
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
|
|
||||||
scale = 1.0 / head_size**0.5
|
|
||||||
attn = MMEncoderAttention(
|
|
||||||
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
|
|
||||||
)
|
|
||||||
output = attn(q, k, v)
|
|
||||||
|
|
||||||
assert num_heads % num_kv_heads == 0
|
|
||||||
num_queries_per_kv = num_heads // num_kv_heads
|
|
||||||
|
|
||||||
q = q.reshape(batch_size, seq_len, num_heads, head_size)
|
|
||||||
k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
|
|
||||||
v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
|
|
||||||
if num_queries_per_kv > 1:
|
|
||||||
k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
|
|
||||||
v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)
|
|
||||||
|
|
||||||
ref_output = ref_attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
scale=scale,
|
|
||||||
).reshape(batch_size, seq_len, num_heads * head_size)
|
|
||||||
# torch_xla flash_attn kernel is less accurate but much faster
|
|
||||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-3)
|
|
||||||
@ -1,76 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import openai
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from vllm.multimodal.utils import encode_image_url
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
from ...entrypoints.openai.test_vision import TEST_IMAGE_ASSETS
|
|
||||||
from ...utils import RemoteOpenAIServer
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def url_encoded_image(local_asset_server) -> dict[str, str]:
|
|
||||||
return {
|
|
||||||
image_asset: encode_image_url(local_asset_server.get_image_asset(image_asset))
|
|
||||||
for image_asset in TEST_IMAGE_ASSETS
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
|
|
||||||
@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"])
|
|
||||||
async def test_basic_vision(model_name: str, url_encoded_image: dict[str, str]):
|
|
||||||
pytest.skip("Skip this test until it's fixed.")
|
|
||||||
|
|
||||||
def whats_in_this_image_msg(url):
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "What's in this image?"},
|
|
||||||
{"type": "image_url", "image_url": {"url": url}},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
server_args = [
|
|
||||||
"--max-model-len",
|
|
||||||
"1024",
|
|
||||||
"--max-num-seqs",
|
|
||||||
"16",
|
|
||||||
"--gpu-memory-utilization",
|
|
||||||
"0.95",
|
|
||||||
"--trust-remote-code",
|
|
||||||
"--max-num-batched-tokens",
|
|
||||||
"576",
|
|
||||||
# NOTE: max-num-batched-tokens>=mm_item_size
|
|
||||||
"--disable_chunked_mm_input",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Server will pre-compile on first startup (takes a long time).
|
|
||||||
with RemoteOpenAIServer(
|
|
||||||
model_name, server_args, max_wait_seconds=600
|
|
||||||
) as remote_server:
|
|
||||||
client: openai.AsyncOpenAI = remote_server.get_async_client()
|
|
||||||
|
|
||||||
# Other requests now should be much faster
|
|
||||||
for image_url in TEST_IMAGE_ASSETS:
|
|
||||||
image_url = url_encoded_image[image_url]
|
|
||||||
chat_completion_from_url = await client.chat.completions.create(
|
|
||||||
model=model_name,
|
|
||||||
messages=whats_in_this_image_msg(image_url),
|
|
||||||
max_completion_tokens=24,
|
|
||||||
temperature=0.0,
|
|
||||||
)
|
|
||||||
result = chat_completion_from_url
|
|
||||||
assert result
|
|
||||||
choice = result.choices[0]
|
|
||||||
assert choice.finish_reason == "length"
|
|
||||||
|
|
||||||
message = choice.message
|
|
||||||
message = result.choices[0].message
|
|
||||||
assert message.content is not None and len(message.content) >= 10
|
|
||||||
assert message.role == "assistant"
|
|
||||||
@ -1,100 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
from unittest.mock import ANY, patch
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionType
|
|
||||||
from vllm.v1.attention.backends.pallas import PallasAttentionBackendImpl, PallasMetadata
|
|
||||||
|
|
||||||
|
|
||||||
def test_ragged_paged_attention():
|
|
||||||
# We verify that the kernel inputs such as sliding_window, etc. are passed
|
|
||||||
# in from the model correctly.
|
|
||||||
# The correctness of the paged attention kernel is tested in the kernel
|
|
||||||
# library.
|
|
||||||
num_heads = 4
|
|
||||||
head_size = 128
|
|
||||||
scale = 1.0
|
|
||||||
num_kv_heads = 4
|
|
||||||
sliding_window = 128
|
|
||||||
logits_soft_cap = 50.0
|
|
||||||
attn_impl = PallasAttentionBackendImpl(
|
|
||||||
num_heads=num_heads,
|
|
||||||
head_size=head_size,
|
|
||||||
scale=scale,
|
|
||||||
num_kv_heads=num_kv_heads,
|
|
||||||
alibi_slopes=None,
|
|
||||||
sliding_window=sliding_window,
|
|
||||||
kv_cache_dtype="auto",
|
|
||||||
logits_soft_cap=logits_soft_cap,
|
|
||||||
attn_type=AttentionType.DECODER,
|
|
||||||
)
|
|
||||||
|
|
||||||
class FakeAttentionLayer:
|
|
||||||
_q_scale_float: float
|
|
||||||
_k_scale_float: float
|
|
||||||
_v_scale_float: float
|
|
||||||
|
|
||||||
layer = FakeAttentionLayer()
|
|
||||||
layer._q_scale_float = 1.0
|
|
||||||
layer._k_scale_float = 1.0
|
|
||||||
layer._v_scale_float = 1.0
|
|
||||||
|
|
||||||
num_tokens = 16
|
|
||||||
num_blocks = 1024
|
|
||||||
block_size = 16
|
|
||||||
query = torch.zeros(num_tokens, num_heads * head_size)
|
|
||||||
key = torch.zeros(num_tokens, num_kv_heads * head_size)
|
|
||||||
value = torch.zeros(num_tokens, num_kv_heads * head_size)
|
|
||||||
kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
|
|
||||||
slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64)
|
|
||||||
max_num_reqs = 8
|
|
||||||
max_num_blocks_per_req = 8
|
|
||||||
num_kv_update_slices = torch.tensor([num_tokens], dtype=torch.int32)
|
|
||||||
block_tables = torch.zeros(
|
|
||||||
(max_num_reqs, max_num_blocks_per_req), dtype=torch.int32
|
|
||||||
)
|
|
||||||
context_lens = torch.ones((max_num_reqs,), dtype=torch.int32)
|
|
||||||
query_lens = [1] * max_num_reqs
|
|
||||||
query_start_loc = torch.cumsum(
|
|
||||||
torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32
|
|
||||||
)
|
|
||||||
num_seqs = torch.tensor([max_num_reqs], dtype=torch.int32)
|
|
||||||
attn_metadata = PallasMetadata(
|
|
||||||
slot_mapping=slot_mapping,
|
|
||||||
block_tables=block_tables,
|
|
||||||
context_lens=context_lens,
|
|
||||||
query_start_loc=query_start_loc,
|
|
||||||
num_seqs=num_seqs,
|
|
||||||
num_kv_update_slices=num_kv_update_slices,
|
|
||||||
num_slices_per_kv_cache_update_block=8,
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("torch.ops.xla.ragged_paged_attention") as mock_ragged_paged_attention:
|
|
||||||
attn_impl.forward(
|
|
||||||
layer=layer,
|
|
||||||
query=query,
|
|
||||||
key=key,
|
|
||||||
value=value,
|
|
||||||
kv_cache=kv_cache,
|
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_ragged_paged_attention.assert_called_once_with(
|
|
||||||
ANY, # query
|
|
||||||
ANY, # kv_cache
|
|
||||||
ANY, # context_lens
|
|
||||||
ANY, # block_tables
|
|
||||||
ANY, # query_start_loc
|
|
||||||
ANY, # num_seqs
|
|
||||||
num_kv_pages_per_block=None,
|
|
||||||
num_queries_per_block=None,
|
|
||||||
vmem_limit_bytes=None,
|
|
||||||
use_kernel=True,
|
|
||||||
sm_scale=scale,
|
|
||||||
sliding_window=sliding_window,
|
|
||||||
soft_cap=logits_soft_cap,
|
|
||||||
k_scale=1.0,
|
|
||||||
v_scale=1.0,
|
|
||||||
)
|
|
||||||
@ -1,150 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
"""A basic performance regression test for TPUs
|
|
||||||
|
|
||||||
Run `pytest tests/v1/tpu/test_perf.py`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.sampling_params import SamplingParams
|
|
||||||
from vllm.tokenizers import get_tokenizer
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from tests.conftest import VllmRunner
|
|
||||||
else:
|
|
||||||
VllmRunner = object
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TestParams:
|
|
||||||
model: str
|
|
||||||
num_prompts: int
|
|
||||||
prefix_len: int
|
|
||||||
decode_len: int
|
|
||||||
expected_avg_time: float
|
|
||||||
err_tol: float
|
|
||||||
|
|
||||||
|
|
||||||
TEST_PARAMS = [
|
|
||||||
# TODO: Cannot run a series of tests because:
|
|
||||||
# RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed:
|
|
||||||
# open(/dev/vfio/0): Device or resource busy: Device or resource busy;
|
|
||||||
# Couldn't open iommu group /dev/vfio/0
|
|
||||||
# => Investigate
|
|
||||||
# TestParams(
|
|
||||||
# model="Qwen/Qwen2.5-1.5B-Instruct",
|
|
||||||
# num_prompts=1,
|
|
||||||
# prefix_len=10,
|
|
||||||
# decode_len=5,
|
|
||||||
# expected_avg_time=0.03,
|
|
||||||
# err_tol=0.01,
|
|
||||||
# ),
|
|
||||||
# TestParams(
|
|
||||||
# model="Qwen/Qwen2.5-1.5B-Instruct",
|
|
||||||
# num_prompts=10,
|
|
||||||
# prefix_len=100,
|
|
||||||
# decode_len=50,
|
|
||||||
# expected_avg_time=0.234,
|
|
||||||
# err_tol=0.020,
|
|
||||||
# ),
|
|
||||||
TestParams(
|
|
||||||
model="Qwen/Qwen2.5-1.5B-Instruct",
|
|
||||||
num_prompts=64,
|
|
||||||
prefix_len=500,
|
|
||||||
decode_len=50,
|
|
||||||
# commit id: ccb246776d93ef105904a8ec015b3587240a1183
|
|
||||||
# tpu: v5lite (old vllm CI/CD)
|
|
||||||
# expected_avg_time=1.4,
|
|
||||||
# err_tol=0.30,
|
|
||||||
# (This is the active CI/CD instance)
|
|
||||||
# commit id: ccb246776d93ef105904a8ec015b3587240a1183
|
|
||||||
# tpu: v6e (current vllm CI/CD)
|
|
||||||
expected_avg_time=1.7, # measured with VLLM_XLA_CACHE_PATH=
|
|
||||||
err_tol=0.20,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
NUM_WARMUPS = 5
|
|
||||||
NUM_RUNS = 10
|
|
||||||
|
|
||||||
MAX_MODEL_LEN = 1024
|
|
||||||
MAX_NUM_SEQS = 32
|
|
||||||
GPU_UTIL = 0.9
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not current_platform.is_tpu(),
|
|
||||||
reason="This is a basic performance test for TPU only",
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("params", TEST_PARAMS)
|
|
||||||
def test_perf(
|
|
||||||
vllm_runner: type[VllmRunner],
|
|
||||||
params: TestParams,
|
|
||||||
) -> None:
|
|
||||||
tokenizer = get_tokenizer(
|
|
||||||
params.model, tokenizer_mode="auto", trust_remote_code=True
|
|
||||||
)
|
|
||||||
|
|
||||||
prompts = []
|
|
||||||
for i in range(params.num_prompts):
|
|
||||||
prefix_token_ids = np.random.randint(
|
|
||||||
0, tokenizer.vocab_size, size=params.prefix_len
|
|
||||||
).tolist()
|
|
||||||
prompt = tokenizer.decode(prefix_token_ids)
|
|
||||||
prompts.append(prompt)
|
|
||||||
|
|
||||||
print(
|
|
||||||
"-- Running: num_prompts = {} prefix_len = {} decode_len = {}".format(
|
|
||||||
len(prompts), params.prefix_len, params.decode_len
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
max_tokens=params.decode_len, temperature=1.0, min_p=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
with vllm_runner(
|
|
||||||
params.model,
|
|
||||||
max_num_batched_tokens=MAX_MODEL_LEN,
|
|
||||||
max_model_len=MAX_MODEL_LEN,
|
|
||||||
max_num_seqs=MAX_NUM_SEQS,
|
|
||||||
gpu_memory_utilization=GPU_UTIL,
|
|
||||||
enforce_eager=False,
|
|
||||||
tensor_parallel_size=1,
|
|
||||||
) as vllm_model:
|
|
||||||
print(" -- Warmup / Compile")
|
|
||||||
for i in range(NUM_WARMUPS):
|
|
||||||
_ = vllm_model.generate(prompts, sampling_params)
|
|
||||||
|
|
||||||
print(" -- Benchmarking... ")
|
|
||||||
times = []
|
|
||||||
for i in range(NUM_RUNS):
|
|
||||||
start_time = time.time()
|
|
||||||
_ = vllm_model.generate(prompts, sampling_params)
|
|
||||||
times.append(time.time() - start_time)
|
|
||||||
|
|
||||||
avg_time = sum(times) / len(times)
|
|
||||||
|
|
||||||
print(" -- avg_time = {}".format(avg_time))
|
|
||||||
print(
|
|
||||||
" -- expected_avg_time = {} with err_tol = {}".format(
|
|
||||||
params.expected_avg_time, params.err_tol
|
|
||||||
)
|
|
||||||
)
|
|
||||||
diff = avg_time - params.expected_avg_time
|
|
||||||
ok = diff < params.err_tol
|
|
||||||
if diff < -params.err_tol:
|
|
||||||
print(
|
|
||||||
" !! WARNING !! Performance has improved by {}, "
|
|
||||||
"it may be necessary to fine-tune the "
|
|
||||||
"expected_avg_time = {}".format(-diff, params.expected_avg_time)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert ok, " !! ERROR !! Regression detected"
|
|
||||||
@ -1,105 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
import random
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from vllm import LLM
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.sampling_params import SamplingParams
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
|
|
||||||
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
|
|
||||||
def test_sampler_different(model_name: str):
|
|
||||||
"""
|
|
||||||
Test significantly different sampling params to assert the model produces
|
|
||||||
different results.
|
|
||||||
"""
|
|
||||||
llm = LLM(
|
|
||||||
model_name,
|
|
||||||
enforce_eager=False,
|
|
||||||
max_num_seqs=1,
|
|
||||||
max_model_len=512,
|
|
||||||
max_num_batched_tokens=256,
|
|
||||||
)
|
|
||||||
prompts = ["Write a short story about a robot that dreams for the first time."]
|
|
||||||
sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
|
|
||||||
output = llm.generate(prompts, sampling_params)
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
|
|
||||||
output2 = llm.generate(prompts, sampling_params)
|
|
||||||
assert output[0].outputs[0].text != output2[0].outputs[0].text
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
# Unsupported `seed` param.
|
|
||||||
sampling_params = SamplingParams(temperature=0.3, seed=42)
|
|
||||||
output2 = llm.generate(prompts, sampling_params)
|
|
||||||
|
|
||||||
# Batch-case with TopK/P
|
|
||||||
for B in [4, 16]:
|
|
||||||
p = prompts * B
|
|
||||||
sampling_params = [
|
|
||||||
SamplingParams(
|
|
||||||
temperature=0.1,
|
|
||||||
min_p=0.8,
|
|
||||||
max_tokens=64,
|
|
||||||
# Vary number of ks
|
|
||||||
top_k=random.randint(4, 12),
|
|
||||||
top_p=random.random(),
|
|
||||||
)
|
|
||||||
for _ in range(B)
|
|
||||||
]
|
|
||||||
# Make sure first two reqs have the same K/P
|
|
||||||
sampling_params[0] = sampling_params[1]
|
|
||||||
output = llm.generate(p, sampling_params)
|
|
||||||
# There are natural numerical instabilities that make it difficult
|
|
||||||
# to have deterministic results over many tokens, tests the first ~20
|
|
||||||
# tokens match.
|
|
||||||
assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
|
|
||||||
# TODO TPU will appear busy if we fan-out test params here
|
|
||||||
@pytest.mark.parametrize("n_prompts", [1])
|
|
||||||
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
|
|
||||||
def test_logprobs(model_name: str, n_prompts: int):
|
|
||||||
"""
|
|
||||||
Request top logprobs with different sampling settings and check
|
|
||||||
that results contains the requested number, ordered ascendingly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def check_num_logprobs(logprobs, expected_num: int):
|
|
||||||
for step in logprobs:
|
|
||||||
prev_logp = 1.0
|
|
||||||
# order by rank
|
|
||||||
sorted_step = dict(sorted(step.items(), key=lambda item: item[1].rank))
|
|
||||||
|
|
||||||
# Can contain the sampled token
|
|
||||||
assert len(step) == expected_num or len(step) == expected_num + 1
|
|
||||||
# Check results are ordered by prob value
|
|
||||||
for rankno, (tid, logp) in enumerate(sorted_step.items()):
|
|
||||||
assert logp.logprob <= prev_logp
|
|
||||||
prev_logp = logp.logprob
|
|
||||||
assert logp.rank == rankno + 1
|
|
||||||
|
|
||||||
llm = LLM(
|
|
||||||
model_name,
|
|
||||||
enforce_eager=False,
|
|
||||||
max_num_seqs=1,
|
|
||||||
max_model_len=128,
|
|
||||||
max_num_batched_tokens=128,
|
|
||||||
)
|
|
||||||
prompts = [
|
|
||||||
"Write a short story about a robot that dreams for the first time."
|
|
||||||
] * n_prompts
|
|
||||||
greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64, logprobs=4)
|
|
||||||
regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64, logprobs=4)
|
|
||||||
topkp_sampling_params = SamplingParams(
|
|
||||||
temperature=0.4, max_tokens=64, logprobs=4, top_k=12, top_p=0.5
|
|
||||||
)
|
|
||||||
|
|
||||||
for sp in [greedy_sampling_params, regular_sampling_params, topkp_sampling_params]:
|
|
||||||
output = llm.generate(prompts, sp)
|
|
||||||
for o in output:
|
|
||||||
check_num_logprobs(o.outputs[0].logprobs, 4)
|
|
||||||
@ -1,78 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
import gc
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import torch_xla.distributed.spmd as xs
|
|
||||||
import torch_xla.runtime as xr
|
|
||||||
|
|
||||||
from vllm.config import set_current_vllm_config
|
|
||||||
from vllm.distributed.parallel_state import (
|
|
||||||
ensure_model_parallel_initialized,
|
|
||||||
init_distributed_environment,
|
|
||||||
)
|
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
|
||||||
from vllm.model_executor.model_loader.tpu import TPUModelLoader
|
|
||||||
|
|
||||||
|
|
||||||
def _setup_environment(model):
|
|
||||||
engine_args = EngineArgs(
|
|
||||||
model=model,
|
|
||||||
)
|
|
||||||
vllm_config = engine_args.create_engine_config()
|
|
||||||
with set_current_vllm_config(vllm_config):
|
|
||||||
temp_file = tempfile.mkstemp()[1]
|
|
||||||
init_distributed_environment(
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
local_rank=0,
|
|
||||||
distributed_init_method=f"file://{temp_file}",
|
|
||||||
backend="gloo",
|
|
||||||
)
|
|
||||||
# Under single worker mode, full model is init first and then
|
|
||||||
# partitioned using GSPMD.
|
|
||||||
ensure_model_parallel_initialized(1, 1)
|
|
||||||
return vllm_config
|
|
||||||
|
|
||||||
|
|
||||||
MESH = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_spmd_mesh():
|
|
||||||
global MESH
|
|
||||||
if MESH is None:
|
|
||||||
xr.use_spmd()
|
|
||||||
num_devices = xr.global_runtime_device_count()
|
|
||||||
mesh_shape = (num_devices, 1)
|
|
||||||
device_ids = np.array(range(num_devices))
|
|
||||||
MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
|
|
||||||
return MESH
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"model",
|
|
||||||
[
|
|
||||||
"Qwen/Qwen2-1.5B-Instruct",
|
|
||||||
# Skip large models due to CI runner disk space limitations
|
|
||||||
# "meta-llama/Llama-3.1-8B-Instruct",
|
|
||||||
# "meta-llama/Llama-3.1-70B-Instruct",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_tpu_model_loader(model):
|
|
||||||
# Skip the 70B test if there are less than 8 chips
|
|
||||||
# TODO: Query using torch xla API, the query API is not working
|
|
||||||
# with SPMD now. However, This test is running under SPMD mode.
|
|
||||||
if "70B" in model and xr.global_runtime_device_count() < 8:
|
|
||||||
pytest.skip(
|
|
||||||
"Skipping 70B model if the TPU VM has less than 8 chips to \
|
|
||||||
avoid OOM."
|
|
||||||
)
|
|
||||||
|
|
||||||
vllm_config = _setup_environment(model)
|
|
||||||
loader = TPUModelLoader(load_config=vllm_config.load_config)
|
|
||||||
mesh = _get_spmd_mesh()
|
|
||||||
model = loader.load_model(vllm_config, vllm_config.model_config, mesh)
|
|
||||||
del model
|
|
||||||
gc.collect()
|
|
||||||
@ -1,149 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
import math
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch_xla
|
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
|
||||||
from vllm.v1.sample.tpu.sampler import apply_top_k_top_p as apply_top_k_top_p_tpu
|
|
||||||
|
|
||||||
if not current_platform.is_tpu():
|
|
||||||
pytest.skip("This test needs a TPU.", allow_module_level=True)
|
|
||||||
import torch_xla.core.xla_model as xm
|
|
||||||
|
|
||||||
BATCH_SIZE = 1024
|
|
||||||
VOCAB_SIZE = 128 * 1024
|
|
||||||
TOLERANCE = 1e-6
|
|
||||||
|
|
||||||
|
|
||||||
def test_topk_equivalence_to_native_impl():
|
|
||||||
with torch.device(xm.xla_device()):
|
|
||||||
xm.set_rng_state(seed=33)
|
|
||||||
|
|
||||||
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
|
|
||||||
|
|
||||||
# Random top-k values between 1 and 10.
|
|
||||||
k = torch.randint(1, 10, (BATCH_SIZE,))
|
|
||||||
|
|
||||||
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
|
|
||||||
k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), VOCAB_SIZE)
|
|
||||||
|
|
||||||
result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
|
|
||||||
|
|
||||||
result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
|
|
||||||
assert torch.allclose(result_native, result_tpu)
|
|
||||||
|
|
||||||
|
|
||||||
def test_topp_result_sums_past_p():
|
|
||||||
with torch.device(xm.xla_device()):
|
|
||||||
xm.set_rng_state(seed=33)
|
|
||||||
|
|
||||||
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
|
|
||||||
probs = logits.softmax(dim=-1)
|
|
||||||
|
|
||||||
# Random top-p values between 0 and 1.
|
|
||||||
p = torch.rand((BATCH_SIZE,))
|
|
||||||
|
|
||||||
# Set p=1 for ~50% of requests in the batch (top-p disabled).
|
|
||||||
p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), 1)
|
|
||||||
|
|
||||||
no_op_k = torch.tensor([VOCAB_SIZE])
|
|
||||||
logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(), k=no_op_k, p=p)
|
|
||||||
|
|
||||||
# Verify that the masked logit's probability sums to at least p.
|
|
||||||
probs.masked_fill_(logits_masked.isinf(), 0)
|
|
||||||
masked_prob_sum = probs.sum(dim=-1)
|
|
||||||
|
|
||||||
torch_xla.sync()
|
|
||||||
|
|
||||||
# Perform assertion on CPU.
|
|
||||||
assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
|
|
||||||
|
|
||||||
|
|
||||||
def test_topp_basic():
|
|
||||||
with torch.device(xm.xla_device()):
|
|
||||||
logits = torch.tensor(
|
|
||||||
[
|
|
||||||
[math.log(0.2), math.log(0.3), math.log(0.5)],
|
|
||||||
[math.log(0.5), math.log(0.1), math.log(0.4)],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
result = apply_top_k_top_p_tpu(
|
|
||||||
logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([0.79, 0.79])
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_xla.sync()
|
|
||||||
|
|
||||||
# Expect the smallest elements to be dropped.
|
|
||||||
expected_result = logits.clone().cpu()
|
|
||||||
expected_result[0, 0] = float("-inf")
|
|
||||||
expected_result[1, 1] = float("-inf")
|
|
||||||
assert torch.allclose(expected_result, result.cpu())
|
|
||||||
|
|
||||||
|
|
||||||
def test_topp_select_all():
|
|
||||||
with torch.device(xm.xla_device()):
|
|
||||||
logits = torch.tensor(
|
|
||||||
[
|
|
||||||
[math.log(0.2), math.log(0.3), math.log(0.5)],
|
|
||||||
[math.log(0.5), math.log(0.1), math.log(0.4)],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
result = apply_top_k_top_p_tpu(
|
|
||||||
logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([1.0, 1.0])
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_xla.sync()
|
|
||||||
|
|
||||||
assert torch.allclose(logits.cpu(), result.cpu())
|
|
||||||
|
|
||||||
|
|
||||||
def test_topp_with_ties():
|
|
||||||
with torch.device(xm.xla_device()):
|
|
||||||
# Input has multiple math.log(0.3).
|
|
||||||
logits = torch.tensor(
|
|
||||||
[[math.log(0.3), math.log(0.3), math.log(0.3), math.log(0.1)]]
|
|
||||||
)
|
|
||||||
|
|
||||||
result = apply_top_k_top_p_tpu(
|
|
||||||
logits=logits.clone(), k=torch.tensor([4]), p=torch.tensor([0.2])
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_xla.sync()
|
|
||||||
|
|
||||||
# All tie values are included in the top-p set. Tie breaking is left
|
|
||||||
# to be done during final sampling (all tie tokens have equal
|
|
||||||
# probability of being chosen).
|
|
||||||
expected_result = logits.clone().cpu()
|
|
||||||
expected_result[0, 3] = float("-inf")
|
|
||||||
assert torch.allclose(expected_result, result.cpu())
|
|
||||||
|
|
||||||
|
|
||||||
def test_both_topk_topp():
|
|
||||||
with torch.device(xm.xla_device()):
|
|
||||||
logits = torch.tensor(
|
|
||||||
[
|
|
||||||
[math.log(0.2), math.log(0.3), math.log(0.5)],
|
|
||||||
[math.log(0.5), math.log(0.1), math.log(0.4)],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set k=1 for the first batch.
|
|
||||||
result = apply_top_k_top_p_tpu(
|
|
||||||
logits=logits.clone(), k=torch.tensor([1, 3]), p=torch.tensor([0.79, 0.79])
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_xla.sync()
|
|
||||||
|
|
||||||
# Since for the first batch k=1, expect only the largest element gets
|
|
||||||
# selected.
|
|
||||||
expected_result = logits.clone().cpu()
|
|
||||||
expected_result[0, 0] = float("-inf")
|
|
||||||
expected_result[0, 1] = float("-inf")
|
|
||||||
expected_result[1, 1] = float("-inf")
|
|
||||||
assert torch.allclose(expected_result, result.cpu())
|
|
||||||
@ -1,78 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
"""Tests whether TPU Int8 computation is enabled correctly.
|
|
||||||
|
|
||||||
Run `pytest tests/quantization/test_tpu_int8.py`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
|
||||||
from vllm.model_executor.layers.quantization.tpu_int8 import TPUInt8LinearMethod
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
from ...models.registry import HF_EXAMPLE_MODELS
|
|
||||||
|
|
||||||
MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not current_platform.is_tpu(), reason="TPU Int8 is only enabled for TPUs."
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
|
||||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
|
||||||
@pytest.mark.parametrize("max_tokens", [10])
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"hf_overrides",
|
|
||||||
[
|
|
||||||
# w8a8 dynamic activation
|
|
||||||
{
|
|
||||||
"quantization_config": {
|
|
||||||
"quant_method": "tpu_int8",
|
|
||||||
"activation_scheme": "dynamic",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_model_tpu_int8(
|
|
||||||
vllm_runner,
|
|
||||||
model: str,
|
|
||||||
dtype: str,
|
|
||||||
max_tokens: int,
|
|
||||||
hf_overrides: dict,
|
|
||||||
monkeypatch,
|
|
||||||
) -> None:
|
|
||||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
|
||||||
model_info.check_transformers_version(on_fail="skip")
|
|
||||||
|
|
||||||
activation_scheme = hf_overrides.get("quantization_config", {}).get(
|
|
||||||
"activation_scheme"
|
|
||||||
)
|
|
||||||
quantize_activation = activation_scheme == "dynamic"
|
|
||||||
|
|
||||||
# Allows using apply_model
|
|
||||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
|
||||||
# Prevent error from re-initializing cache
|
|
||||||
monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")
|
|
||||||
|
|
||||||
prompts = [
|
|
||||||
"A robot may not injure a human being",
|
|
||||||
]
|
|
||||||
answers = [
|
|
||||||
"or kill a human being",
|
|
||||||
]
|
|
||||||
|
|
||||||
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
|
|
||||||
|
|
||||||
def check_model(model):
|
|
||||||
for name, module in model.named_modules():
|
|
||||||
if not isinstance(module, LinearBase):
|
|
||||||
continue
|
|
||||||
quant_method = module.quant_method
|
|
||||||
assert isinstance(quant_method, TPUInt8LinearMethod)
|
|
||||||
assert quant_method.quantize_activation == quantize_activation
|
|
||||||
|
|
||||||
vllm.apply_model(check_model)
|
|
||||||
outputs = vllm.generate_greedy(prompts, max_tokens)
|
|
||||||
for (_, output), answer in zip(outputs, answers):
|
|
||||||
assert answer in output
|
|
||||||
@ -1,93 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch_xla.distributed.spmd as xs
|
|
||||||
import torch_xla.runtime as xr
|
|
||||||
|
|
||||||
from vllm.config import set_current_vllm_config
|
|
||||||
from vllm.distributed.parallel_state import (
|
|
||||||
ensure_model_parallel_initialized,
|
|
||||||
init_distributed_environment,
|
|
||||||
)
|
|
||||||
from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear
|
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
|
||||||
from vllm.model_executor.layers.linear import QKVParallelLinear
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def setup_environment():
|
|
||||||
# This is a fake config used for init dist env.
|
|
||||||
# QKVParallelLinear needs dist env to be initialized.
|
|
||||||
engine_args = EngineArgs(
|
|
||||||
model="Qwen/Qwen2-1.5B-Instruct",
|
|
||||||
max_model_len=64,
|
|
||||||
max_num_batched_tokens=64,
|
|
||||||
max_num_seqs=4,
|
|
||||||
)
|
|
||||||
|
|
||||||
vllm_config = engine_args.create_engine_config()
|
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
|
||||||
temp_file = tempfile.mkstemp()[1]
|
|
||||||
init_distributed_environment(
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
local_rank=0,
|
|
||||||
distributed_init_method=f"file://{temp_file}",
|
|
||||||
backend="gloo",
|
|
||||||
)
|
|
||||||
ensure_model_parallel_initialized(1, 1)
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
MESH = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_spmd_mesh():
|
|
||||||
global MESH
|
|
||||||
if MESH is None:
|
|
||||||
xr.use_spmd()
|
|
||||||
num_devices = xr.global_runtime_device_count()
|
|
||||||
mesh_shape = (num_devices, 1)
|
|
||||||
device_ids = np.array(range(num_devices))
|
|
||||||
MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
|
|
||||||
return MESH
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("bias", [False, True])
|
|
||||||
# `xr.use_spmd()` will set a global state, and this state is not reversible.
|
|
||||||
# Therefore, non-SPMD tests should be run before SPMD tests.
|
|
||||||
@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()])
|
|
||||||
@pytest.mark.parametrize("device", ["cpu", "xla"])
|
|
||||||
@torch.no_grad()
|
|
||||||
def test_xla_qkv_linear(bias, mesh, device):
|
|
||||||
torch.manual_seed(123)
|
|
||||||
|
|
||||||
qkv_linear = QKVParallelLinear(
|
|
||||||
hidden_size=4096,
|
|
||||||
head_size=128,
|
|
||||||
total_num_heads=32,
|
|
||||||
total_num_kv_heads=8,
|
|
||||||
bias=bias,
|
|
||||||
params_dtype=torch.bfloat16,
|
|
||||||
return_bias=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
qkv_linear.weight.data = torch.rand_like(qkv_linear.weight.data) / 10
|
|
||||||
if bias:
|
|
||||||
qkv_linear.bias.data = torch.rand_like(qkv_linear.bias.data)
|
|
||||||
|
|
||||||
xla_qkv_linear = XlaQKVParallelLinear(qkv_linear, mesh=mesh)
|
|
||||||
|
|
||||||
qkv_linear = qkv_linear.to(device)
|
|
||||||
xla_qkv_linear = xla_qkv_linear.to(device)
|
|
||||||
input_tensor = torch.rand(10, 4096, dtype=torch.bfloat16) / 10
|
|
||||||
input_tensor = input_tensor.to(device)
|
|
||||||
|
|
||||||
output = qkv_linear(input_tensor)
|
|
||||||
xla_output = xla_qkv_linear(input_tensor)
|
|
||||||
assert torch.allclose(output.cpu(), xla_output.cpu())
|
|
||||||
@ -1,587 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from vllm.attention.layer import Attention
|
|
||||||
from vllm.config import (
|
|
||||||
CacheConfig,
|
|
||||||
ModelConfig,
|
|
||||||
SchedulerConfig,
|
|
||||||
VllmConfig,
|
|
||||||
set_current_vllm_config,
|
|
||||||
)
|
|
||||||
from vllm.pooling_params import PoolingParams
|
|
||||||
from vllm.sampling_params import SamplingParams
|
|
||||||
from vllm.utils.mem_constants import GiB_bytes
|
|
||||||
from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
|
|
||||||
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
|
|
||||||
from vllm.v1.worker.tpu_model_runner import (
|
|
||||||
TPUModelRunner,
|
|
||||||
_get_padded_num_reqs_with_upper_limit,
|
|
||||||
_get_padded_token_len,
|
|
||||||
_get_req_paddings,
|
|
||||||
_get_token_paddings,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_vllm_config():
|
|
||||||
model_config = ModelConfig(
|
|
||||||
model="facebook/opt-125m",
|
|
||||||
dtype="bfloat16", # TPUs typically use bfloat16
|
|
||||||
seed=42,
|
|
||||||
)
|
|
||||||
scheduler_config = SchedulerConfig(
|
|
||||||
max_num_seqs=10,
|
|
||||||
max_num_batched_tokens=512,
|
|
||||||
max_model_len=512,
|
|
||||||
is_encoder_decoder=model_config.is_encoder_decoder,
|
|
||||||
)
|
|
||||||
cache_config = CacheConfig(
|
|
||||||
block_size=16,
|
|
||||||
gpu_memory_utilization=0.9,
|
|
||||||
swap_space=0,
|
|
||||||
cache_dtype="auto",
|
|
||||||
)
|
|
||||||
vllm_config = VllmConfig(
|
|
||||||
model_config=model_config,
|
|
||||||
cache_config=cache_config,
|
|
||||||
scheduler_config=scheduler_config,
|
|
||||||
)
|
|
||||||
return vllm_config
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_runner(vllm_config):
|
|
||||||
device = "xla:0" # Mocking TPU device
|
|
||||||
return TPUModelRunner(vllm_config, device)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def model_runner():
|
|
||||||
# Patchers have already been started at module level.
|
|
||||||
vllm_config = get_vllm_config()
|
|
||||||
return get_model_runner(vllm_config)
|
|
||||||
|
|
||||||
|
|
||||||
def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
|
||||||
new_reqs = []
|
|
||||||
num_scheduled_tokens = {}
|
|
||||||
total_num_scheduled_tokens = 0
|
|
||||||
for req_id in req_ids:
|
|
||||||
new_reqs.append(
|
|
||||||
NewRequestData(
|
|
||||||
req_id=req_id,
|
|
||||||
prompt_token_ids=[1, 2, 3],
|
|
||||||
mm_features=[],
|
|
||||||
sampling_params=SamplingParams(),
|
|
||||||
pooling_params=PoolingParams(),
|
|
||||||
block_ids=([0],), # block_ids should be tuple[list[int]]
|
|
||||||
num_computed_tokens=0,
|
|
||||||
lora_request=None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
num_scheduled_tokens[req_id] = 3
|
|
||||||
total_num_scheduled_tokens += num_scheduled_tokens[req_id]
|
|
||||||
|
|
||||||
return SchedulerOutput(
|
|
||||||
scheduled_new_reqs=new_reqs,
|
|
||||||
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
|
||||||
num_scheduled_tokens=num_scheduled_tokens,
|
|
||||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
|
||||||
scheduled_spec_decode_tokens={},
|
|
||||||
scheduled_encoder_inputs={},
|
|
||||||
num_common_prefix_blocks=[],
|
|
||||||
finished_req_ids=set(),
|
|
||||||
free_encoder_mm_hashes=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_req_scheduled(model_runner, req_id: str) -> bool:
|
|
||||||
return req_id in model_runner.input_batch.req_id_to_index
|
|
||||||
|
|
||||||
|
|
||||||
def _is_req_added(model_runner, req_id: str) -> bool:
|
|
||||||
return req_id in model_runner.requests
|
|
||||||
|
|
||||||
|
|
||||||
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
|
|
||||||
"""Check if the request state block IDs match the block table.
|
|
||||||
|
|
||||||
This function handles both legacy BlockTable and new MultiGroupBlockTable
|
|
||||||
structures for backward compatibility.
|
|
||||||
"""
|
|
||||||
|
|
||||||
req_index = model_runner.input_batch.req_id_to_index[req_id]
|
|
||||||
multi_group_block_table = model_runner.input_batch.block_table
|
|
||||||
req_state = model_runner.requests[req_id]
|
|
||||||
|
|
||||||
# Access the first block table from MultiGroupBlockTable
|
|
||||||
# This is safe since we currently only use single KV cache groups
|
|
||||||
block_table = multi_group_block_table[0]
|
|
||||||
|
|
||||||
# req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable
|
|
||||||
# Extract the first group's block IDs
|
|
||||||
if isinstance(req_state.block_ids[0], list):
|
|
||||||
# New format: tuple[list[int], ...] - extract first group
|
|
||||||
req_block_ids = req_state.block_ids[0]
|
|
||||||
else:
|
|
||||||
# Legacy format: list[int] - use directly
|
|
||||||
req_block_ids = req_state.block_ids
|
|
||||||
|
|
||||||
if block_table.num_blocks_per_row[req_index] != len(req_block_ids):
|
|
||||||
return False
|
|
||||||
|
|
||||||
num_blocks = block_table.num_blocks_per_row[req_index]
|
|
||||||
block_table_values = block_table.block_table.np[req_index, :num_blocks]
|
|
||||||
return (block_table_values == req_block_ids).all()
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_states_new_request(model_runner):
|
|
||||||
req_id = "req_0"
|
|
||||||
|
|
||||||
# new req
|
|
||||||
scheduler_output = _schedule_new_request(req_id)
|
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
|
||||||
|
|
||||||
assert _is_req_added(model_runner, req_id)
|
|
||||||
assert _is_req_scheduled(model_runner, req_id)
|
|
||||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_states_request_finished(model_runner):
|
|
||||||
req_id = "req_0"
|
|
||||||
|
|
||||||
# new req
|
|
||||||
scheduler_output = _schedule_new_request(req_id)
|
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
|
||||||
assert _is_req_added(model_runner, req_id)
|
|
||||||
assert _is_req_scheduled(model_runner, req_id)
|
|
||||||
|
|
||||||
# finish req
|
|
||||||
scheduler_output = SchedulerOutput(
|
|
||||||
scheduled_new_reqs=[],
|
|
||||||
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
|
||||||
num_scheduled_tokens={},
|
|
||||||
total_num_scheduled_tokens=0,
|
|
||||||
scheduled_spec_decode_tokens={},
|
|
||||||
scheduled_encoder_inputs={},
|
|
||||||
num_common_prefix_blocks=[],
|
|
||||||
finished_req_ids={req_id},
|
|
||||||
free_encoder_mm_hashes=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
|
||||||
assert not _is_req_added(model_runner, req_id)
|
|
||||||
assert not _is_req_scheduled(model_runner, req_id)
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_states_request_resumed(model_runner):
|
|
||||||
req_id = "req_0"
|
|
||||||
|
|
||||||
# new req
|
|
||||||
scheduler_output = _schedule_new_request(req_id)
|
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
|
||||||
assert _is_req_added(model_runner, req_id)
|
|
||||||
assert _is_req_scheduled(model_runner, req_id)
|
|
||||||
|
|
||||||
# unschedule req
|
|
||||||
scheduler_output = SchedulerOutput(
|
|
||||||
scheduled_new_reqs=[],
|
|
||||||
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
|
||||||
num_scheduled_tokens={},
|
|
||||||
total_num_scheduled_tokens=0,
|
|
||||||
scheduled_spec_decode_tokens={},
|
|
||||||
scheduled_encoder_inputs={},
|
|
||||||
num_common_prefix_blocks=[],
|
|
||||||
finished_req_ids=set(),
|
|
||||||
free_encoder_mm_hashes=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
|
||||||
assert _is_req_added(model_runner, req_id)
|
|
||||||
assert not _is_req_scheduled(model_runner, req_id)
|
|
||||||
|
|
||||||
# resume req
|
|
||||||
cached_req_data = CachedRequestData(
|
|
||||||
req_ids=[req_id],
|
|
||||||
resumed_req_ids={req_id},
|
|
||||||
new_token_ids=[[]],
|
|
||||||
all_token_ids={req_id: scheduler_output.scheduled_new_reqs[0].prompt_token_ids},
|
|
||||||
new_block_ids=[([],)],
|
|
||||||
num_computed_tokens=[0],
|
|
||||||
num_output_tokens=[0],
|
|
||||||
)
|
|
||||||
|
|
||||||
scheduler_output = SchedulerOutput(
|
|
||||||
scheduled_new_reqs=[],
|
|
||||||
scheduled_cached_reqs=cached_req_data,
|
|
||||||
num_scheduled_tokens={req_id: 1},
|
|
||||||
total_num_scheduled_tokens=1,
|
|
||||||
scheduled_spec_decode_tokens={},
|
|
||||||
scheduled_encoder_inputs={},
|
|
||||||
num_common_prefix_blocks=[],
|
|
||||||
finished_req_ids=set(),
|
|
||||||
free_encoder_mm_hashes=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
|
||||||
assert _is_req_added(model_runner, req_id)
|
|
||||||
assert _is_req_scheduled(model_runner, req_id)
|
|
||||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_states_no_changes(model_runner):
|
|
||||||
req_id = "req_0"
|
|
||||||
|
|
||||||
# new req
|
|
||||||
scheduler_output = _schedule_new_request(req_id)
|
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
|
||||||
assert _is_req_added(model_runner, req_id)
|
|
||||||
assert _is_req_scheduled(model_runner, req_id)
|
|
||||||
|
|
||||||
# schedule req
|
|
||||||
scheduler_output = SchedulerOutput(
|
|
||||||
scheduled_new_reqs=[],
|
|
||||||
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
|
||||||
num_scheduled_tokens={req_id: 1},
|
|
||||||
total_num_scheduled_tokens=1,
|
|
||||||
scheduled_spec_decode_tokens={},
|
|
||||||
scheduled_encoder_inputs={},
|
|
||||||
num_common_prefix_blocks=[],
|
|
||||||
finished_req_ids=set(),
|
|
||||||
free_encoder_mm_hashes=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
|
||||||
assert _is_req_added(model_runner, req_id)
|
|
||||||
assert _is_req_scheduled(model_runner, req_id)
|
|
||||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_states_request_unscheduled(model_runner):
|
|
||||||
req_ids = ("req_0", "req_1")
|
|
||||||
|
|
||||||
# new reqs
|
|
||||||
scheduler_output = _schedule_new_request(*req_ids)
|
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
|
||||||
|
|
||||||
assert _is_req_added(model_runner, req_ids[0])
|
|
||||||
assert _is_req_scheduled(model_runner, req_ids[0])
|
|
||||||
|
|
||||||
assert _is_req_added(model_runner, req_ids[1])
|
|
||||||
assert _is_req_scheduled(model_runner, req_ids[1])
|
|
||||||
|
|
||||||
# unschedule req_1
|
|
||||||
scheduler_output = SchedulerOutput(
|
|
||||||
scheduled_new_reqs=[],
|
|
||||||
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
|
||||||
num_scheduled_tokens={req_ids[0]: 1},
|
|
||||||
total_num_scheduled_tokens=1,
|
|
||||||
scheduled_spec_decode_tokens={},
|
|
||||||
scheduled_encoder_inputs={},
|
|
||||||
num_common_prefix_blocks=[],
|
|
||||||
finished_req_ids=set(),
|
|
||||||
free_encoder_mm_hashes=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
|
||||||
|
|
||||||
assert _is_req_added(model_runner, req_ids[0])
|
|
||||||
assert _is_req_scheduled(model_runner, req_ids[0])
|
|
||||||
|
|
||||||
assert _is_req_added(model_runner, req_ids[1])
|
|
||||||
assert not _is_req_scheduled(model_runner, req_ids[1])
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_paddings():
|
|
||||||
# Bucketed padding
|
|
||||||
min_token_size, max_token_size, padding_gap = 16, 512, 64
|
|
||||||
expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
|
|
||||||
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
|
|
||||||
|
|
||||||
# Bucketed padding with max_token_size not a power of two.
|
|
||||||
max_token_size = 317
|
|
||||||
expected_paddings = [16, 32, 64, 128, 192, 256, 320]
|
|
||||||
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
|
|
||||||
assert actual_paddings == expected_paddings
|
|
||||||
|
|
||||||
# Exponential padding.
|
|
||||||
max_token_size, padding_gap = 1024, 0
|
|
||||||
expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
|
|
||||||
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
|
|
||||||
assert actual_paddings == expected_paddings
|
|
||||||
# Exponential padding with max_token_size not a power of two.
|
|
||||||
max_token_size = 317
|
|
||||||
expected_paddings = [16, 32, 64, 128, 256, 512]
|
|
||||||
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
|
|
||||||
assert actual_paddings == expected_paddings
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_padded_token_len():
|
|
||||||
min_token_size, max_token_size, padding_gap = 16, 512, 64
|
|
||||||
paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
|
|
||||||
assert _get_padded_token_len(paddings, 1) == 16
|
|
||||||
assert _get_padded_token_len(paddings, 16) == 16
|
|
||||||
assert _get_padded_token_len(paddings, 20) == 32
|
|
||||||
assert _get_padded_token_len(paddings, 300) == 320
|
|
||||||
assert _get_padded_token_len(paddings, 512) == 512
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_padded_num_reqs_with_upper_limit():
|
|
||||||
assert _get_padded_num_reqs_with_upper_limit(3, 32) == 8
|
|
||||||
assert _get_padded_num_reqs_with_upper_limit(9, 32) == 16
|
|
||||||
assert _get_padded_num_reqs_with_upper_limit(19, 32) == 32
|
|
||||||
assert _get_padded_num_reqs_with_upper_limit(17, 28) == 28
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_req_paddings():
|
|
||||||
assert _get_req_paddings(1, 32) == [8, 16, 32]
|
|
||||||
assert _get_req_paddings(8, 32) == [8, 16, 32]
|
|
||||||
assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(model_runner):
|
|
||||||
layer_0 = "model.layers.0.self_attn.attn"
|
|
||||||
layer_1 = "model.layers.1.self_attn.attn"
|
|
||||||
error_msg = f"{layer_1} must come before the current layer"
|
|
||||||
vllm_config = model_runner.vllm_config
|
|
||||||
with (
|
|
||||||
pytest.raises(ValueError, match=error_msg),
|
|
||||||
set_current_vllm_config(vllm_config),
|
|
||||||
):
|
|
||||||
fwd_context = {
|
|
||||||
# initialization below will fail because target layer is invalid;
|
|
||||||
# the target layer needs to come before layer 1
|
|
||||||
layer_0: Attention(
|
|
||||||
num_heads=8,
|
|
||||||
head_size=128,
|
|
||||||
scale=1.0,
|
|
||||||
prefix=layer_0,
|
|
||||||
kv_sharing_target_layer_name=layer_1,
|
|
||||||
),
|
|
||||||
layer_1: Attention(
|
|
||||||
num_heads=8,
|
|
||||||
head_size=128,
|
|
||||||
scale=1.0,
|
|
||||||
prefix=layer_1,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
# suppress var not used error
|
|
||||||
assert fwd_context is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner):
|
|
||||||
layer_0 = "model.layers.0.self_attn.attn"
|
|
||||||
layer_1 = "model.layers.1.self_attn.attn"
|
|
||||||
invalid_layer = "model.layers.0.cross_attn.attn"
|
|
||||||
error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
|
|
||||||
vllm_config = model_runner.vllm_config
|
|
||||||
with (
|
|
||||||
pytest.raises(ValueError, match=error_msg),
|
|
||||||
set_current_vllm_config(vllm_config),
|
|
||||||
):
|
|
||||||
fwd_context = {
|
|
||||||
layer_0: Attention(
|
|
||||||
num_heads=8,
|
|
||||||
head_size=128,
|
|
||||||
scale=1.0,
|
|
||||||
prefix=layer_0,
|
|
||||||
),
|
|
||||||
layer_1: Attention(
|
|
||||||
num_heads=8,
|
|
||||||
head_size=128,
|
|
||||||
scale=1.0,
|
|
||||||
prefix=layer_1,
|
|
||||||
# invalid layer: cross_attn.atn doesn't exist!
|
|
||||||
kv_sharing_target_layer_name=invalid_layer,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
# suppress var not used error
|
|
||||||
assert fwd_context is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner):
|
|
||||||
layer_0 = "model.layers.0.self_attn.attn"
|
|
||||||
layer_1 = "model.layers.1.self_attn.attn"
|
|
||||||
error_msg = f"{layer_1} cannot be the same as the current layer"
|
|
||||||
vllm_config = model_runner.vllm_config
|
|
||||||
with (
|
|
||||||
pytest.raises(ValueError, match=error_msg),
|
|
||||||
set_current_vllm_config(vllm_config),
|
|
||||||
):
|
|
||||||
fwd_context = {
|
|
||||||
# initialization below will fail because target layer is invalid;
|
|
||||||
# the target layer needs to come before layer 1
|
|
||||||
layer_0: Attention(
|
|
||||||
num_heads=8,
|
|
||||||
head_size=128,
|
|
||||||
scale=1.0,
|
|
||||||
prefix=layer_0,
|
|
||||||
),
|
|
||||||
layer_1: Attention(
|
|
||||||
num_heads=8,
|
|
||||||
head_size=128,
|
|
||||||
scale=1.0,
|
|
||||||
prefix=layer_1,
|
|
||||||
kv_sharing_target_layer_name=layer_1,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
# suppress var not used error
|
|
||||||
assert fwd_context is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_kv_cache_without_kv_sharing():
|
|
||||||
layer_0 = "model.layers.0.self_attn.attn"
|
|
||||||
layer_1 = "model.layers.1.self_attn.attn"
|
|
||||||
vllm_config = get_vllm_config()
|
|
||||||
with set_current_vllm_config(vllm_config):
|
|
||||||
fwd_context = {
|
|
||||||
layer_0: Attention(
|
|
||||||
num_heads=8,
|
|
||||||
head_size=128,
|
|
||||||
scale=1.0,
|
|
||||||
prefix=layer_0,
|
|
||||||
),
|
|
||||||
layer_1: Attention(
|
|
||||||
num_heads=8,
|
|
||||||
head_size=128,
|
|
||||||
scale=1.0,
|
|
||||||
prefix=layer_1,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
# suppress var not used error
|
|
||||||
assert fwd_context is not None
|
|
||||||
# Set high context length to test max context length estimation
|
|
||||||
vllm_config.model_config.max_model_len = 1_000_000
|
|
||||||
vllm_ctx = vllm_config.compilation_config.static_forward_context
|
|
||||||
model_runner = get_model_runner(vllm_config)
|
|
||||||
kv_cache_spec = model_runner.get_kv_cache_spec()
|
|
||||||
assert len(kv_cache_spec) == 2
|
|
||||||
assert len(model_runner.shared_kv_cache_layers) == 0
|
|
||||||
|
|
||||||
available_memory = 20 * GiB_bytes
|
|
||||||
# page size for each layer KV can be calculated as
|
|
||||||
# 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
|
|
||||||
# * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
|
|
||||||
num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers)
|
|
||||||
kv_cache_config = get_kv_cache_configs(
|
|
||||||
vllm_config, [kv_cache_spec], [available_memory]
|
|
||||||
)[0]
|
|
||||||
assert kv_cache_config.num_blocks == num_expected_blocks
|
|
||||||
assert len(kv_cache_config.kv_cache_tensors) == 2
|
|
||||||
assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
|
|
||||||
assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2
|
|
||||||
|
|
||||||
max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
|
|
||||||
# max context len with KV sharing should be 2x as large as without
|
|
||||||
# max_context_len = available_memory / (page_size / block_size) / num_caches
|
|
||||||
# max_context_len = 5GB / (512KB / 128) / 2 = 655360
|
|
||||||
assert max_context_len == 655360
|
|
||||||
|
|
||||||
# important: override tensor size to prevent large mem alloc during test
|
|
||||||
# this will only allocate 2 block worth of memory (2 * 512kb)
|
|
||||||
kv_cache_config.num_blocks = 1
|
|
||||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
|
||||||
kv_cache_tensor.size = kv_cache_spec[
|
|
||||||
kv_cache_tensor.shared_by[0]
|
|
||||||
].page_size_bytes
|
|
||||||
|
|
||||||
model_runner.initialize_kv_cache(kv_cache_config)
|
|
||||||
|
|
||||||
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
|
|
||||||
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
|
|
||||||
# check layer 1 kv cache does NOT share memory with layer 0
|
|
||||||
assert id(layer_1_kv) != id(layer_0_kv)
|
|
||||||
|
|
||||||
# check layer 1 added to kv cache group's layer names
|
|
||||||
assert len(kv_cache_config.kv_cache_groups) == 1
|
|
||||||
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
|
|
||||||
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
|
|
||||||
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_kv_cache_with_kv_sharing_valid():
|
|
||||||
layer_0 = "model.layers.0.self_attn.attn"
|
|
||||||
layer_1 = "model.layers.1.self_attn.attn"
|
|
||||||
vllm_config = get_vllm_config()
|
|
||||||
with set_current_vllm_config(vllm_config):
|
|
||||||
fwd_context = {
|
|
||||||
layer_0: Attention(
|
|
||||||
num_heads=8,
|
|
||||||
head_size=128,
|
|
||||||
scale=1.0,
|
|
||||||
prefix=layer_0,
|
|
||||||
),
|
|
||||||
layer_1: Attention(
|
|
||||||
num_heads=8,
|
|
||||||
head_size=128,
|
|
||||||
scale=1.0,
|
|
||||||
prefix=layer_1,
|
|
||||||
kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
# suppress var not used error
|
|
||||||
assert fwd_context is not None
|
|
||||||
# Set high context length to test max context length estimation
|
|
||||||
vllm_config.model_config.max_model_len = 3_000_000
|
|
||||||
vllm_ctx = vllm_config.compilation_config.static_forward_context
|
|
||||||
model_runner = get_model_runner(vllm_config)
|
|
||||||
kv_cache_spec = model_runner.get_kv_cache_spec()
|
|
||||||
assert len(kv_cache_spec) == 1
|
|
||||||
assert layer_0 in kv_cache_spec
|
|
||||||
assert model_runner.shared_kv_cache_layers[layer_1] == layer_0
|
|
||||||
|
|
||||||
available_memory = 20 * GiB_bytes
|
|
||||||
# page size for layer 0's kv_cache_spec is 512KB
|
|
||||||
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
|
|
||||||
# which is twice as many as without KV sharing
|
|
||||||
num_expected_blocks = 2 * 20480 # 20GB / 512KB
|
|
||||||
kv_cache_config = get_kv_cache_configs(
|
|
||||||
vllm_config, [kv_cache_spec], [available_memory]
|
|
||||||
)[0]
|
|
||||||
assert kv_cache_config.num_blocks == num_expected_blocks
|
|
||||||
assert len(kv_cache_config.kv_cache_tensors) == 1
|
|
||||||
# Each layer now has twice the available memory for KV cache
|
|
||||||
# compared to no KV sharing
|
|
||||||
assert kv_cache_config.kv_cache_tensors[0].size == available_memory
|
|
||||||
|
|
||||||
max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
|
|
||||||
# max context len with KV sharing should be 2x as large as without
|
|
||||||
assert max_context_len == (2 * 655360)
|
|
||||||
|
|
||||||
# important: override tensor size to prevent large mem alloc during test
|
|
||||||
# this will only allocate 1 block worth of memory (512kb)
|
|
||||||
kv_cache_config.num_blocks = 1
|
|
||||||
kv_cache_config.kv_cache_tensors[0].size = kv_cache_spec[layer_0].page_size_bytes
|
|
||||||
|
|
||||||
model_runner.initialize_kv_cache(kv_cache_config)
|
|
||||||
|
|
||||||
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
|
|
||||||
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
|
|
||||||
# check layer 1 kv cache shares memory with layer 0
|
|
||||||
assert id(layer_1_kv) == id(layer_0_kv)
|
|
||||||
|
|
||||||
# check layer 1 added to kv cache group's layer names
|
|
||||||
assert len(kv_cache_config.kv_cache_groups) == 1
|
|
||||||
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
|
|
||||||
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
|
|
||||||
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
|
|
||||||
|
|
||||||
|
|
||||||
def test_most_model_len(monkeypatch: pytest.MonkeyPatch):
|
|
||||||
monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048")
|
|
||||||
vllm_config = get_vllm_config()
|
|
||||||
vllm_config.model_config.max_model_len = 32000
|
|
||||||
vllm_config.scheduler_config.max_num_seqs = 1200
|
|
||||||
model_runner = get_model_runner(vllm_config)
|
|
||||||
|
|
||||||
# verify model runner will adjust num_reqs to avoid SMEM OOM.
|
|
||||||
assert model_runner.num_reqs_most_model_len == 1200
|
|
||||||
# num_page_per_req = 32k // 128
|
|
||||||
# num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
|
|
||||||
assert model_runner.num_reqs_max_model_len == 524
|
|
||||||
@ -66,7 +66,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
|
|||||||
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
|
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
|
||||||
)
|
)
|
||||||
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
|
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
|
||||||
PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
|
|
||||||
IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
|
IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
|
||||||
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
|
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
|
||||||
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
|
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
|
||||||
|
|||||||
@ -227,28 +227,3 @@ class MMEncoderAttention(CustomOp):
|
|||||||
"XPU only supports FLASH_ATTN for vision attention."
|
"XPU only supports FLASH_ATTN for vision attention."
|
||||||
)
|
)
|
||||||
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
||||||
|
|
||||||
def forward_tpu(
|
|
||||||
self,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
|
||||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
|
||||||
) -> torch.Tensor:
|
|
||||||
assert self.attn_backend == AttentionBackendEnum.PALLAS, (
|
|
||||||
f"MMEncoderAttention on TPU only supports PALLAS backend, "
|
|
||||||
f"but got {self.attn_backend}."
|
|
||||||
)
|
|
||||||
if cu_seqlens is None:
|
|
||||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
|
||||||
from torch_xla.experimental.custom_kernel import flash_attention
|
|
||||||
|
|
||||||
out = flash_attention(query, key, value, sm_scale=self.scale)
|
|
||||||
out = out.transpose(1, 2)
|
|
||||||
return out
|
|
||||||
logger.warning_once(
|
|
||||||
"PALLAS backend with cu_seqlens is not supported for ViT yet. ",
|
|
||||||
"Falling back to SDPA implementation.",
|
|
||||||
)
|
|
||||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
|
||||||
|
|||||||
@ -1,99 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.distributed import ProcessGroup
|
|
||||||
|
|
||||||
from vllm.config import get_current_vllm_config
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.platforms.tpu import USE_TPU_INFERENCE
|
|
||||||
|
|
||||||
from .base_device_communicator import DeviceCommunicatorBase
|
|
||||||
|
|
||||||
USE_RAY = parallel_config = (
|
|
||||||
get_current_vllm_config().parallel_config.distributed_executor_backend == "ray"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
if not USE_TPU_INFERENCE:
|
|
||||||
logger.info("tpu_inference not found, using vLLM's TpuCommunicator")
|
|
||||||
if current_platform.is_tpu():
|
|
||||||
import torch_xla
|
|
||||||
import torch_xla.core.xla_model as xm
|
|
||||||
import torch_xla.runtime as xr
|
|
||||||
from torch_xla._internal import pjrt
|
|
||||||
from torch_xla.distributed.xla_multiprocessing import (
|
|
||||||
create_optimized_replica_groups,
|
|
||||||
)
|
|
||||||
|
|
||||||
if USE_RAY:
|
|
||||||
from vllm.v1.executor import ray_utils
|
|
||||||
|
|
||||||
|
|
||||||
class TpuCommunicator(DeviceCommunicatorBase):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
cpu_group: ProcessGroup,
|
|
||||||
device: torch.device | None = None,
|
|
||||||
device_group: ProcessGroup | None = None,
|
|
||||||
unique_name: str = "",
|
|
||||||
):
|
|
||||||
super().__init__(cpu_group, device, device_group, unique_name)
|
|
||||||
|
|
||||||
# NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
|
|
||||||
# must be used together. Therefore, the local rank and world size can
|
|
||||||
# be simply calculated as follows.
|
|
||||||
global_rank = self.global_rank
|
|
||||||
global_world_size = self.global_world_size
|
|
||||||
|
|
||||||
if USE_RAY:
|
|
||||||
logger.info("TpuCommunicator initialized with RAY")
|
|
||||||
# Calculate how many TPU nodes are in the current deployment. This
|
|
||||||
# is the Ray placement group if it is deployed with Ray. Default
|
|
||||||
# to the number of TPU nodes in the Ray cluster. The number of TPU
|
|
||||||
# nodes is computed by the total number of TPUs divided by the
|
|
||||||
# number of TPU accelerators per node, to account for clusters
|
|
||||||
# with both CPUs and TPUs.
|
|
||||||
num_nodes = ray_utils.get_num_tpu_nodes()
|
|
||||||
num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
|
|
||||||
if num_nodes_in_pg > 0:
|
|
||||||
num_nodes = num_nodes_in_pg
|
|
||||||
|
|
||||||
local_world_size = global_world_size // num_nodes
|
|
||||||
local_rank = global_rank % local_world_size
|
|
||||||
else:
|
|
||||||
logger.info("TpuCommunicator initialized with MP")
|
|
||||||
# Sanity: Verify we run on a single host
|
|
||||||
num_hosts = torch_xla.tpu.num_tpu_workers()
|
|
||||||
assert num_hosts == 1
|
|
||||||
|
|
||||||
# Get the current number of TPUs (we have locally)
|
|
||||||
local_world_size = torch_xla.tpu.num_available_chips()
|
|
||||||
|
|
||||||
# Get current rank
|
|
||||||
local_rank = global_rank % local_world_size
|
|
||||||
|
|
||||||
# Ensure environment variables are set for multihost deployments.
|
|
||||||
# On GKE, this is needed for libtpu and TPU driver to know which TPU
|
|
||||||
# chip is actually visible. Otherwise the TPU driver will fail to
|
|
||||||
# initialize because the number of devices would be different from
|
|
||||||
# the number of visible worker addresses.
|
|
||||||
os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
|
|
||||||
os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
|
|
||||||
|
|
||||||
pjrt.initialize_multiprocess(local_rank, local_world_size)
|
|
||||||
xr._init_world_size_ordinal()
|
|
||||||
self.groups = create_optimized_replica_groups()
|
|
||||||
|
|
||||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
|
||||||
# TODO: Remove the groups specification after XLA compiler can support
|
|
||||||
# auto-reordering the ring order for all-reduce.
|
|
||||||
return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)
|
|
||||||
|
|
||||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
|
||||||
assert dim == -1, "TPUs only support dim=-1 for all-gather."
|
|
||||||
return xm.all_gather(input_, dim=dim)
|
|
||||||
@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Literal
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -251,9 +250,6 @@ class TpKVTopology:
|
|||||||
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
|
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
|
|
||||||
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_kv_layout_blocks_first(self) -> bool:
|
def is_kv_layout_blocks_first(self) -> bool:
|
||||||
return self._is_kv_layout_blocks_first
|
return self._is_kv_layout_blocks_first
|
||||||
@ -261,7 +257,7 @@ class TpKVTopology:
|
|||||||
@property
|
@property
|
||||||
def split_k_and_v(self) -> bool:
|
def split_k_and_v(self) -> bool:
|
||||||
# Whether to register regions for K and V separately (when present).
|
# Whether to register regions for K and V separately (when present).
|
||||||
return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first)
|
return not (self.is_mla or self.is_kv_layout_blocks_first)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tp_size(self) -> int:
|
def tp_size(self) -> int:
|
||||||
|
|||||||
@ -499,7 +499,6 @@ class MooncakeConnectorWorker:
|
|||||||
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
||||||
attn_backend=backend,
|
attn_backend=backend,
|
||||||
)
|
)
|
||||||
self._use_pallas = self.kv_topo._use_pallas
|
|
||||||
|
|
||||||
self.zmq_ctx = zmq.Context()
|
self.zmq_ctx = zmq.Context()
|
||||||
self.async_zmq_ctx = zmq.asyncio.Context()
|
self.async_zmq_ctx = zmq.asyncio.Context()
|
||||||
|
|||||||
@ -983,7 +983,6 @@ class NixlConnectorWorker:
|
|||||||
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
||||||
attn_backend=backend,
|
attn_backend=backend,
|
||||||
)
|
)
|
||||||
self._use_pallas = self.kv_topo._use_pallas
|
|
||||||
self._physical_blocks_per_logical_kv_block = 1
|
self._physical_blocks_per_logical_kv_block = 1
|
||||||
|
|
||||||
def _nixl_handshake(
|
def _nixl_handshake(
|
||||||
@ -1641,9 +1640,6 @@ class NixlConnectorWorker:
|
|||||||
# Num kv_heads > tp_size and P TP > D TP case, not supported
|
# Num kv_heads > tp_size and P TP > D TP case, not supported
|
||||||
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
|
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
|
||||||
|
|
||||||
assert not self._use_pallas or tp_ratio == 1, (
|
|
||||||
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
|
|
||||||
)
|
|
||||||
kv_cache_layout = (
|
kv_cache_layout = (
|
||||||
self.kv_cache_layout
|
self.kv_cache_layout
|
||||||
if not self.use_host_buffer
|
if not self.use_host_buffer
|
||||||
@ -1814,9 +1810,7 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
if len(self.device_kv_caches) == 0:
|
if len(self.device_kv_caches) == 0:
|
||||||
return
|
return
|
||||||
split_k_and_v = not (
|
split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)
|
||||||
self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first
|
|
||||||
)
|
|
||||||
sample_cache = list(self.device_kv_caches.values())[0][0]
|
sample_cache = list(self.device_kv_caches.values())[0][0]
|
||||||
for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
|
for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
|
||||||
assert block_size_ratio > 1, "Only nP < nD supported currently."
|
assert block_size_ratio > 1, "Only nP < nD supported currently."
|
||||||
|
|||||||
@ -1,188 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
from collections import OrderedDict
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch_xla.distributed.spmd as xs
|
|
||||||
from torch.nn.parameter import Parameter
|
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.layers.linear import (
|
|
||||||
ColumnParallelLinear,
|
|
||||||
QKVParallelLinear,
|
|
||||||
RowParallelLinear,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class XlaQKVParallelLinear(nn.Module):
|
|
||||||
def __init__(self, qkv_linear: nn.Module, mesh: Optional["xs.Mesh"] = None):
|
|
||||||
super().__init__()
|
|
||||||
assert isinstance(qkv_linear, QKVParallelLinear)
|
|
||||||
self.skip_bias_add = qkv_linear.skip_bias_add
|
|
||||||
self.return_bias = qkv_linear.return_bias
|
|
||||||
assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD."
|
|
||||||
|
|
||||||
self.q_weight: Parameter
|
|
||||||
self.k_weight: Parameter
|
|
||||||
self.v_weight: Parameter
|
|
||||||
self.q_bias: Parameter | None
|
|
||||||
self.k_bias: Parameter | None
|
|
||||||
self.v_bias: Parameter | None
|
|
||||||
self._load_weights_from_qkv_linear(qkv_linear)
|
|
||||||
if mesh is not None:
|
|
||||||
self._shard_weight(mesh)
|
|
||||||
|
|
||||||
def _shard_weight(self, mesh: "xs.Mesh"):
|
|
||||||
self.q_weight = Parameter(self.q_weight.to("xla"), requires_grad=False)
|
|
||||||
self.k_weight = Parameter(self.k_weight.to("xla"), requires_grad=False)
|
|
||||||
self.v_weight = Parameter(self.v_weight.to("xla"), requires_grad=False)
|
|
||||||
xs.mark_sharding(self.q_weight, mesh, ("x", None))
|
|
||||||
xs.mark_sharding(self.k_weight, mesh, ("x", None))
|
|
||||||
xs.mark_sharding(self.v_weight, mesh, ("x", None))
|
|
||||||
if self.q_bias is not None:
|
|
||||||
assert self.k_bias is not None and self.v_bias is not None, (
|
|
||||||
"QKVParallelLinear should have q, k, and v biases together."
|
|
||||||
)
|
|
||||||
self.q_bias = Parameter(self.q_bias.to("xla"), requires_grad=False)
|
|
||||||
xs.mark_sharding(self.q_bias, mesh, ("x",))
|
|
||||||
self.k_bias = Parameter(self.k_bias.to("xla"), requires_grad=False)
|
|
||||||
xs.mark_sharding(self.k_bias, mesh, ("x",))
|
|
||||||
self.v_bias = Parameter(self.v_bias.to("xla"), requires_grad=False)
|
|
||||||
xs.mark_sharding(self.v_bias, mesh, ("x",))
|
|
||||||
|
|
||||||
def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module):
|
|
||||||
q_proj_size, k_proj_size, _ = qkv_linear.output_sizes
|
|
||||||
# The weight of qkv linear is a concatenation of q, k, and v weights
|
|
||||||
# along the output dimension.
|
|
||||||
qkv_weight = qkv_linear.weight.data.cpu()
|
|
||||||
q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False)
|
|
||||||
k_weight = Parameter(
|
|
||||||
qkv_weight[q_proj_size : q_proj_size + k_proj_size], requires_grad=False
|
|
||||||
)
|
|
||||||
v_weight = Parameter(
|
|
||||||
qkv_weight[q_proj_size + k_proj_size :], requires_grad=False
|
|
||||||
)
|
|
||||||
self.register_parameter("q_weight", q_weight)
|
|
||||||
self.register_parameter("k_weight", k_weight)
|
|
||||||
self.register_parameter("v_weight", v_weight)
|
|
||||||
|
|
||||||
if qkv_linear.bias is not None:
|
|
||||||
q_bias = Parameter(qkv_linear.bias[:q_proj_size], requires_grad=False)
|
|
||||||
k_bias = Parameter(
|
|
||||||
qkv_linear.bias[q_proj_size : q_proj_size + k_proj_size],
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
v_bias = Parameter(
|
|
||||||
qkv_linear.bias[q_proj_size + k_proj_size :], requires_grad=False
|
|
||||||
)
|
|
||||||
self.register_parameter("q_bias", q_bias)
|
|
||||||
self.register_parameter("k_bias", k_bias)
|
|
||||||
self.register_parameter("v_bias", v_bias)
|
|
||||||
else:
|
|
||||||
self.register_parameter("q_bias", None)
|
|
||||||
self.register_parameter("k_bias", None)
|
|
||||||
self.register_parameter("v_bias", None)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
# Same forward functionality as QKVParallelLinear, but doing qkv porj
|
|
||||||
# separately.
|
|
||||||
q_bias = self.q_bias if not self.skip_bias_add else None
|
|
||||||
k_bias = self.k_bias if not self.skip_bias_add else None
|
|
||||||
v_bias = self.v_bias if not self.skip_bias_add else None
|
|
||||||
q_proj = F.linear(input, self.q_weight, q_bias)
|
|
||||||
k_proj = F.linear(input, self.k_weight, k_bias)
|
|
||||||
v_proj = F.linear(input, self.v_weight, v_bias)
|
|
||||||
# The q/k/v projections will be split outside of the QKVParallelLinear.
|
|
||||||
# Because we are replacing XlaQKVParallelLinear with the
|
|
||||||
# QKVParallelLinear, we need to concatenate q, k, and v projections to
|
|
||||||
# match the output shape of the QKVParallelLinear implementation even if
|
|
||||||
# it seems to be redundant.
|
|
||||||
# The concat and the following split will be noop, and should be
|
|
||||||
# optimized away by the compiler.
|
|
||||||
qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1)
|
|
||||||
output_bias = (
|
|
||||||
torch.cat([q_bias, k_bias, v_bias], dim=-1) if self.skip_bias_add else None
|
|
||||||
)
|
|
||||||
if not self.return_bias:
|
|
||||||
return qkv_proj
|
|
||||||
return qkv_proj, output_bias
|
|
||||||
|
|
||||||
|
|
||||||
def partition_column_parallel_linear(
|
|
||||||
layer: torch.nn.Module, mesh: xs.Mesh
|
|
||||||
) -> torch.nn.Module:
|
|
||||||
assert isinstance(layer, ColumnParallelLinear)
|
|
||||||
xs.mark_sharding(layer.weight, mesh, ("x", None))
|
|
||||||
logger.debug("Applied column-parallel sharding to %s", layer)
|
|
||||||
return layer
|
|
||||||
|
|
||||||
|
|
||||||
def partition_row_parallel_linear(
|
|
||||||
layer: torch.nn.Module, mesh: xs.Mesh
|
|
||||||
) -> torch.nn.Module:
|
|
||||||
assert isinstance(layer, RowParallelLinear)
|
|
||||||
xs.mark_sharding(layer.weight, mesh, (None, "x"))
|
|
||||||
logger.debug("Applied row-parallel sharding to %s", layer)
|
|
||||||
return layer
|
|
||||||
|
|
||||||
|
|
||||||
def partition_qkv_parallel_linear(
|
|
||||||
layer: torch.nn.Module, mesh: xs.Mesh
|
|
||||||
) -> torch.nn.Module:
|
|
||||||
assert isinstance(layer, QKVParallelLinear)
|
|
||||||
xla_layer = XlaQKVParallelLinear(layer, mesh)
|
|
||||||
logger.debug("Applied qkv parallel sharding to %s", layer)
|
|
||||||
return xla_layer
|
|
||||||
|
|
||||||
|
|
||||||
MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict(
|
|
||||||
[
|
|
||||||
("QKVParallelLinear", partition_qkv_parallel_linear),
|
|
||||||
("ColumnParallelLinear", partition_column_parallel_linear),
|
|
||||||
("RowParallelLinear", partition_row_parallel_linear),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_fqn(module):
|
|
||||||
# Get the fully qualified name of the module
|
|
||||||
return module.__class__.__qualname__
|
|
||||||
|
|
||||||
|
|
||||||
def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None:
|
|
||||||
"""
|
|
||||||
Recursively check a PyTorch model and apply appropriate sharding based on
|
|
||||||
the MODULE_TYPE_TO_WRAPPING_FUNC mapping.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: torch.nn.Module to process
|
|
||||||
mesh: An XLA SPMD mesh object used for sharding
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _process_module(module, name=None, parent=None):
|
|
||||||
for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items():
|
|
||||||
if get_fqn(module) == module_type:
|
|
||||||
wrapped_module = wrapping_func(module, mesh)
|
|
||||||
|
|
||||||
assert parent is not None and name is not None, (
|
|
||||||
"Top Level module is not expected to be wrapped."
|
|
||||||
)
|
|
||||||
if wrapped_module is not module:
|
|
||||||
# Wrapped module and module are different py object.
|
|
||||||
# The original module should be replaced by the
|
|
||||||
# wrapped_module.
|
|
||||||
logger.debug("replace %s with %s", module, wrapped_module)
|
|
||||||
setattr(parent, name, wrapped_module)
|
|
||||||
|
|
||||||
module = wrapped_module
|
|
||||||
break
|
|
||||||
|
|
||||||
for child_name, child_module in list(module.named_children()):
|
|
||||||
_process_module(child_module, child_name, module)
|
|
||||||
|
|
||||||
_process_module(model)
|
|
||||||
@ -1,6 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
|
||||||
|
|
||||||
__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]
|
|
||||||
@ -1,141 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch_xla.core.xla_builder as xb
|
|
||||||
from torch.library import impl
|
|
||||||
from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard
|
|
||||||
|
|
||||||
|
|
||||||
@jax.jit
|
|
||||||
def bgmv_jax(inputs, loras, idxs):
|
|
||||||
return jnp.einsum(
|
|
||||||
"td,tX,Xld->tl",
|
|
||||||
inputs,
|
|
||||||
jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
|
|
||||||
loras,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")
|
|
||||||
|
|
||||||
|
|
||||||
@impl(XLA_LIB, "bgmv", "XLA")
|
|
||||||
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
|
|
||||||
if len(loras.shape) == 4:
|
|
||||||
loras = loras.squeeze(axis=1)
|
|
||||||
|
|
||||||
jax_import_guard()
|
|
||||||
return xb.call_jax(bgmv_jax, (inputs, loras, idxs))
|
|
||||||
|
|
||||||
|
|
||||||
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
|
|
||||||
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
|
|
||||||
T, _ = inputs.shape
|
|
||||||
if len(loras.shape) == 4:
|
|
||||||
loras = loras.squeeze(axis=1)
|
|
||||||
_, L, _ = loras.shape
|
|
||||||
|
|
||||||
return torch.empty((T, L), device=inputs.device)
|
|
||||||
|
|
||||||
|
|
||||||
def bgmv_expand(
|
|
||||||
inputs: torch.Tensor,
|
|
||||||
lora_b_weights: torch.Tensor,
|
|
||||||
output_tensor: torch.Tensor,
|
|
||||||
lora_indices_tensor: torch.Tensor,
|
|
||||||
add_inputs: bool = True,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
|
|
||||||
|
|
||||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
|
||||||
[num_loras, lora_rank, hidden_size].
|
|
||||||
|
|
||||||
output_tensor (torch.Tensor): output tensor of shape
|
|
||||||
[num_tokens, hidden_size * num_slices].
|
|
||||||
|
|
||||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
|
||||||
indicating which LoRA matrix to use for each token.
|
|
||||||
add_inputs (bool): Whether or not to add the input tensor to the output
|
|
||||||
tensor.
|
|
||||||
"""
|
|
||||||
|
|
||||||
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
|
|
||||||
|
|
||||||
limit = output_tensor.shape[0]
|
|
||||||
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
|
|
||||||
limit = 1
|
|
||||||
|
|
||||||
if output_tensor.shape[1] > outputs.shape[1]:
|
|
||||||
outputs = F.pad(outputs, (0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
|
|
||||||
|
|
||||||
if add_inputs:
|
|
||||||
return output_tensor + outputs[:limit, : output_tensor.shape[1]]
|
|
||||||
else:
|
|
||||||
return outputs[:limit, : output_tensor.shape[1]]
|
|
||||||
|
|
||||||
|
|
||||||
def bgmv_shrink(
|
|
||||||
inputs: torch.Tensor,
|
|
||||||
lora_b_weights: torch.Tensor,
|
|
||||||
lora_indices_tensor: torch.Tensor,
|
|
||||||
scaling: float = 1.0,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
|
|
||||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
|
||||||
[num_loras, lora_rank, hidden_size].
|
|
||||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
|
||||||
indicating which LoRA matrix to use for each token.
|
|
||||||
scaling (float, optional): Scalar multiplier applied to the output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
def bgmv_expand_slice(
|
|
||||||
inputs: torch.Tensor,
|
|
||||||
lora_b_weights: torch.Tensor,
|
|
||||||
output_tensor: torch.Tensor,
|
|
||||||
lora_indices_tensor: torch.Tensor,
|
|
||||||
slice_offset: int,
|
|
||||||
slice_size: int,
|
|
||||||
add_inputs: bool = True,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
|
|
||||||
|
|
||||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
|
||||||
[num_loras, lora_rank, hidden_size].
|
|
||||||
|
|
||||||
output_tensor (torch.Tensor): output tensor of shape
|
|
||||||
[num_tokens, hidden_size * num_slices].
|
|
||||||
|
|
||||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
|
||||||
indicating which LoRA matrix to use for each token.
|
|
||||||
add_inputs (bool): Whether or not to add the input tensor to the output
|
|
||||||
tensor.
|
|
||||||
"""
|
|
||||||
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
|
|
||||||
|
|
||||||
outputs = F.pad(
|
|
||||||
outputs,
|
|
||||||
(
|
|
||||||
slice_offset,
|
|
||||||
output_tensor.shape[1] - (slice_offset + slice_size),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if add_inputs:
|
|
||||||
return output_tensor + outputs
|
|
||||||
else:
|
|
||||||
return outputs
|
|
||||||
@ -1,358 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch_xla
|
|
||||||
|
|
||||||
from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
|
||||||
from vllm.lora.punica_wrapper.utils import convert_mapping
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
# avoid circuit import
|
|
||||||
from vllm.lora.layers import LoRAMapping
|
|
||||||
|
|
||||||
from .punica_base import PunicaWrapperBase
|
|
||||||
|
|
||||||
|
|
||||||
class PunicaWrapperTPU(PunicaWrapperBase):
|
|
||||||
"""
|
|
||||||
PunicaWrapperTPU is designed to manage and provide metadata for the punica
|
|
||||||
kernel. The main function is to maintain the state information for
|
|
||||||
Multi-LoRA, and to provide the interface for the pytorch punica ops.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
max_num_batched_tokens: int,
|
|
||||||
max_batches: int,
|
|
||||||
device: torch.device | str,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
|
|
||||||
|
|
||||||
# PunicaWrapperBase defines some tensors with dtype=torch.int64, which
|
|
||||||
# isn't supported by the TPU. So convert those tensors to int32.
|
|
||||||
# Not all of them are used by the TPU so only convert the useful ones.
|
|
||||||
self._token_lora_indices = self._token_lora_indices.to(dtype=torch.int32)
|
|
||||||
self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
|
|
||||||
self._sampler_indices_padded = self._sampler_indices_padded.to(
|
|
||||||
dtype=torch.int32
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True)
|
|
||||||
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True)
|
|
||||||
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, True)
|
|
||||||
torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
|
|
||||||
torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, True)
|
|
||||||
|
|
||||||
torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
|
|
||||||
torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
|
|
||||||
torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
|
|
||||||
|
|
||||||
def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
|
|
||||||
return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def embeddings_indices(self) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
This property provides access to the indices used for lora embeddings,
|
|
||||||
specifically for VocabParallelEmbeddingWithLoRA.
|
|
||||||
"""
|
|
||||||
return self._embeddings_indices[:]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def sampler_indices_padded(self) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
This property provides access to padded sampler indices.
|
|
||||||
"""
|
|
||||||
return self._sampler_indices_padded[:]
|
|
||||||
|
|
||||||
def shrink(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
w_t_all: torch.Tensor,
|
|
||||||
scale: float,
|
|
||||||
):
|
|
||||||
return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale)
|
|
||||||
|
|
||||||
def expand(
|
|
||||||
self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, add_inputs: bool
|
|
||||||
):
|
|
||||||
return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), add_inputs)
|
|
||||||
|
|
||||||
def expand_slice(
|
|
||||||
self,
|
|
||||||
y: torch.Tensor,
|
|
||||||
x: torch.Tensor,
|
|
||||||
w_t_all: torch.Tensor,
|
|
||||||
y_offset: int,
|
|
||||||
y_slice_size: int,
|
|
||||||
add_inputs: bool,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return bgmv_expand_slice(
|
|
||||||
x,
|
|
||||||
w_t_all,
|
|
||||||
y,
|
|
||||||
self._get_token_lora_indices(x),
|
|
||||||
y_offset,
|
|
||||||
y_slice_size,
|
|
||||||
add_inputs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_shrink(
|
|
||||||
self,
|
|
||||||
y: tuple[torch.Tensor, ...] | torch.Tensor,
|
|
||||||
x: torch.Tensor,
|
|
||||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
|
||||||
scale: float,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor | None:
|
|
||||||
"""
|
|
||||||
Performs GEMM for multiple slices of lora_a.
|
|
||||||
|
|
||||||
Semantics:
|
|
||||||
for i in range(len(lora_a_stacked)):
|
|
||||||
y[i] += (x @ lora_a_stacked[i]) * scale
|
|
||||||
|
|
||||||
Args:
|
|
||||||
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
|
|
||||||
x (torch.Tensor): Input tensor
|
|
||||||
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
|
|
||||||
scale (float): Scaling factor for the operation
|
|
||||||
"""
|
|
||||||
|
|
||||||
torch.ops.xla.dynamo_set_buffer_donor_(y, True)
|
|
||||||
x = x.view(-1, x.shape[-1])
|
|
||||||
|
|
||||||
for slice_idx in range(len(lora_a_stacked)):
|
|
||||||
lora_s = lora_a_stacked[slice_idx]
|
|
||||||
y_s = self.shrink(x, lora_s, scale)
|
|
||||||
y[slice_idx, :, :] = y_s # type: ignore[index]
|
|
||||||
return y
|
|
||||||
|
|
||||||
def add_expand(
|
|
||||||
self,
|
|
||||||
y: torch.Tensor,
|
|
||||||
x: tuple[torch.Tensor, ...] | torch.Tensor,
|
|
||||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
|
||||||
output_slices: tuple[int, ...],
|
|
||||||
offset_start: int = 0,
|
|
||||||
add_inputs=True,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Performs GEMM for multiple slices of lora_b.
|
|
||||||
|
|
||||||
Semantics:
|
|
||||||
for i in range(len(lora_b_stacked)):
|
|
||||||
slice = output_slices[i]
|
|
||||||
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
|
|
||||||
offset += slice
|
|
||||||
|
|
||||||
Args:
|
|
||||||
y (torch.Tensor): Output tensor.
|
|
||||||
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
|
||||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
|
|
||||||
output_slices (tuple[int, ...]): Every slice's size
|
|
||||||
add_inputs (bool): Defaults to True.
|
|
||||||
"""
|
|
||||||
y_org = y
|
|
||||||
y = y.view(-1, y.shape[-1])
|
|
||||||
offset_left = 0
|
|
||||||
|
|
||||||
for slice_idx in range(len(lora_b_stacked)):
|
|
||||||
y = self.expand_slice(
|
|
||||||
y,
|
|
||||||
x[slice_idx],
|
|
||||||
lora_b_stacked[slice_idx],
|
|
||||||
offset_left,
|
|
||||||
output_slices[slice_idx],
|
|
||||||
add_inputs=add_inputs,
|
|
||||||
)
|
|
||||||
offset_left += output_slices[slice_idx]
|
|
||||||
return y.view_as(y_org)
|
|
||||||
|
|
||||||
def add_lora_embedding(
|
|
||||||
self,
|
|
||||||
y: torch.Tensor,
|
|
||||||
x: torch.Tensor,
|
|
||||||
lora_b_stacked: torch.Tensor,
|
|
||||||
add_inputs: bool = True,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
|
||||||
|
|
||||||
Semantics:
|
|
||||||
y += x @ lora_b_stacked
|
|
||||||
|
|
||||||
Args:
|
|
||||||
y (torch.Tensor): Output tensor.
|
|
||||||
x (torch.Tensor): Input tensor.
|
|
||||||
lora_b_stacked (torch.Tensor): lora_b's weights.
|
|
||||||
add_inputs (bool): Default to True.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Embedding layer only needs the expand op
|
|
||||||
return self.expand(y, x, lora_b_stacked, add_inputs)
|
|
||||||
|
|
||||||
def add_lora_linear(
|
|
||||||
self,
|
|
||||||
y: torch.Tensor,
|
|
||||||
x: torch.Tensor,
|
|
||||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
|
||||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
|
||||||
scale: float,
|
|
||||||
output_slices: tuple[int, ...],
|
|
||||||
*,
|
|
||||||
buffer: tuple[torch.Tensor, ...] | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Applicable to linear-related lora.
|
|
||||||
|
|
||||||
Semantics:
|
|
||||||
for i in range(len(lora_a_stacked)):
|
|
||||||
y[i] += (
|
|
||||||
x[i].unsqueeze(0)
|
|
||||||
@ lora_a_stacked[indices[i], layer_idx, :, :]
|
|
||||||
@ lora_b_stacked[indices[i], layer_idx, :, :]
|
|
||||||
* scale
|
|
||||||
).squeeze(0)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
y (torch.Tensor): Output tensor. Will not be changed in-place.
|
|
||||||
x (torch.Tensor): Input tensor (T, E)
|
|
||||||
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
|
|
||||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
|
|
||||||
scale (float): Scaling factor.
|
|
||||||
output_slices (tuple[int, ...]): Every slice's size.
|
|
||||||
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
|
|
||||||
|
|
||||||
if buffer is None:
|
|
||||||
r = lora_b_stacked[0].size(-1)
|
|
||||||
T = x.size(0)
|
|
||||||
buffer = torch.zeros(
|
|
||||||
(len(output_slices), T, r),
|
|
||||||
dtype=x.dtype,
|
|
||||||
device=x.device,
|
|
||||||
)
|
|
||||||
buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
|
|
||||||
return self.add_expand(
|
|
||||||
y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_lora_logits(
|
|
||||||
self,
|
|
||||||
y: torch.Tensor,
|
|
||||||
x: torch.Tensor,
|
|
||||||
lora_a_stacked: torch.Tensor,
|
|
||||||
lora_b_stacked: torch.Tensor,
|
|
||||||
scale,
|
|
||||||
*,
|
|
||||||
buffer: torch.Tensor | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Applies lora specifically for LogitsProcessorWithLoRA.
|
|
||||||
|
|
||||||
Semantics:
|
|
||||||
buffer = (x @ lora_a_stacked) * scale
|
|
||||||
y += buffer @ lora_b_stacked
|
|
||||||
|
|
||||||
Args:
|
|
||||||
y (torch.Tensor): Output tensor.
|
|
||||||
x (torch.Tensor): Input tensor.
|
|
||||||
lora_a_stacked (torch.Tensor): lora_a's weights.
|
|
||||||
lora_b_stacked (torch.Tensor):lora_b's weights.
|
|
||||||
scale (float): Scaling factor.
|
|
||||||
buffer (Optional[torch.Tensor]):Default to None.
|
|
||||||
"""
|
|
||||||
y_org = y
|
|
||||||
y = y.view(-1, y.shape[-1])
|
|
||||||
x = x.view(-1, x.shape[-1])
|
|
||||||
|
|
||||||
sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
|
|
||||||
buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale)
|
|
||||||
y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
|
|
||||||
return y.view_as(y_org)
|
|
||||||
|
|
||||||
# This performs the same tensor ops as the base method, except it does them
|
|
||||||
# on the CPU then transfers the results to the TPU
|
|
||||||
def _update_base_metadata(
|
|
||||||
self,
|
|
||||||
mapping: "LoRAMapping",
|
|
||||||
lora_index_to_id: list[int | None],
|
|
||||||
max_loras: int,
|
|
||||||
vocab_size: int,
|
|
||||||
):
|
|
||||||
# Make sure we don't accidentally collect outside operations
|
|
||||||
torch_xla.sync()
|
|
||||||
|
|
||||||
# Pad the prompt mapping to avoid running into recompiles on the TPU
|
|
||||||
# TODO: Should this happen inside mapping internally? If so how can we
|
|
||||||
# avoid having backend specific LoRAMapping classes?
|
|
||||||
mapping.prompt_mapping = self._pad_prompt_mapping(mapping.prompt_mapping)
|
|
||||||
|
|
||||||
(
|
|
||||||
base_indices,
|
|
||||||
sampler_indices,
|
|
||||||
sampler_indices_padded,
|
|
||||||
embeddings_indices,
|
|
||||||
indices_len,
|
|
||||||
) = convert_mapping(
|
|
||||||
mapping,
|
|
||||||
lora_index_to_id,
|
|
||||||
max_loras,
|
|
||||||
vocab_size,
|
|
||||||
0, # extra_vocab_size
|
|
||||||
"cpu",
|
|
||||||
)
|
|
||||||
self._token_lora_indices = self._pad_to_shape(
|
|
||||||
base_indices, self._token_lora_indices.shape, dims=1
|
|
||||||
).to(self.device)
|
|
||||||
self._sampler_indices = self._pad_to_shape(
|
|
||||||
sampler_indices, self._sampler_indices.shape, dims=1
|
|
||||||
).to(self.device)
|
|
||||||
self._sampler_indices_padded = self._pad_to_shape(
|
|
||||||
sampler_indices_padded, self._sampler_indices_padded.shape, dims=1
|
|
||||||
).to(self.device)
|
|
||||||
self._embeddings_indices = self._pad_to_shape(
|
|
||||||
embeddings_indices, self._embeddings_indices.shape, dims=2
|
|
||||||
).to(self.device)
|
|
||||||
self.indices_len[:] = indices_len
|
|
||||||
|
|
||||||
def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None:
|
|
||||||
self.batch_size = 1
|
|
||||||
self._lora_indices_per_batch[: self.batch_size] = token_lora_tensor[
|
|
||||||
: self.batch_size
|
|
||||||
]
|
|
||||||
|
|
||||||
def _pad_prompt_mapping(self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
|
|
||||||
num_reqs = len(prompt_mapping)
|
|
||||||
|
|
||||||
# From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
|
|
||||||
# import
|
|
||||||
MIN_NUM_SEQS = 8
|
|
||||||
|
|
||||||
padded_num_reqs = max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
|
|
||||||
pad_len = padded_num_reqs - num_reqs
|
|
||||||
|
|
||||||
padding = [-1] * pad_len
|
|
||||||
return tuple(list(prompt_mapping) + padding)
|
|
||||||
|
|
||||||
def _pad_to_shape(self, src, target_shape, dims=1):
|
|
||||||
if dims == 1:
|
|
||||||
pad_len = target_shape[0] - src.shape[0]
|
|
||||||
return F.pad(src, (0, pad_len), value=0).to(torch.int32)
|
|
||||||
else:
|
|
||||||
pad_rows = target_shape[0] - src.shape[0]
|
|
||||||
pad_cols = target_shape[1] - src.shape[1]
|
|
||||||
return F.pad(src, (0, pad_cols, 0, pad_rows), value=0).to(torch.int32)
|
|
||||||
@ -67,21 +67,15 @@ else:
|
|||||||
|
|
||||||
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
|
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
|
from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
|
||||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
|
||||||
rocm_aiter_grouped_topk,
|
|
||||||
)
|
|
||||||
|
|
||||||
if current_platform.is_tpu():
|
|
||||||
from .moe_pallas import fused_moe as fused_moe_pallas
|
|
||||||
else:
|
|
||||||
fused_moe_pallas = None # type: ignore
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||||
FusedMoEMethodBase,
|
FusedMoEMethodBase,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
|
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
|
||||||
FusedMoEModularMethod,
|
FusedMoEModularMethod,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||||
|
rocm_aiter_grouped_topk,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
|
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
|
||||||
UnquantizedFusedMoEMethod,
|
UnquantizedFusedMoEMethod,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,83 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Compute the histogram of an int32 tensor. The bin edges are defined by the
|
|
||||||
min and max values, with step = 1.
|
|
||||||
"""
|
|
||||||
assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
|
|
||||||
assert min <= max, "min must be less than or equal to max."
|
|
||||||
|
|
||||||
def searchsorted(
|
|
||||||
sorted_sequence: torch.Tensor, values_to_search: torch.Tensor
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)
|
|
||||||
|
|
||||||
bin_edges = torch.linspace(min, max, max - min + 1, dtype=input.dtype).to(
|
|
||||||
input.device
|
|
||||||
)
|
|
||||||
return searchsorted(bin_edges, input).to(torch.int32)
|
|
||||||
|
|
||||||
|
|
||||||
def fused_moe(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor,
|
|
||||||
gating_output: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
global_num_experts: int,
|
|
||||||
expert_map: torch.Tensor = None,
|
|
||||||
renormalize: bool = False,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hidden_states: [*, hidden_size]
|
|
||||||
w1: [num_experts, intermediate_size * 2, hidden_size]
|
|
||||||
w2: [num_experts, hidden_size, intermediate_size]
|
|
||||||
gating_output: [*, num_experts]
|
|
||||||
"""
|
|
||||||
assert expert_map is None, "expert_map is not supported for pallas MoE."
|
|
||||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
|
||||||
|
|
||||||
orig_shape = hidden_states.shape
|
|
||||||
hidden_size = hidden_states.shape[-1]
|
|
||||||
num_tokens = hidden_states.shape[:-1].numel()
|
|
||||||
num_experts = w1.shape[0]
|
|
||||||
intermediate_size = w2.shape[-1]
|
|
||||||
device = hidden_states.device
|
|
||||||
dtype = hidden_states.dtype
|
|
||||||
assert (num_tokens * topk) % 16 == 0, (
|
|
||||||
"The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
|
|
||||||
f"16 but got {num_tokens * topk}"
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = hidden_states.view(num_tokens, hidden_size)
|
|
||||||
gating_output = gating_output.view(num_tokens, num_experts)
|
|
||||||
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
|
|
||||||
topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
|
|
||||||
if renormalize:
|
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
|
||||||
topk_weights = topk_weights.to(dtype)
|
|
||||||
|
|
||||||
topk_indices = topk_indices.flatten()
|
|
||||||
topk_argsort_indices = topk_indices.argsort()
|
|
||||||
topk_argsort_revert_indices = topk_argsort_indices.argsort()
|
|
||||||
token_indices = torch.arange(num_tokens, device=device).repeat_interleave(topk)
|
|
||||||
token_indices = token_indices[topk_argsort_indices]
|
|
||||||
group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
|
|
||||||
|
|
||||||
x = hidden_states[token_indices]
|
|
||||||
x = torch.ops.xla.gmm(x, w1, group_sizes, transpose_rhs=True)
|
|
||||||
x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
|
|
||||||
x = torch.ops.xla.gmm(x, w2, group_sizes, transpose_rhs=True)
|
|
||||||
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
|
|
||||||
|
|
||||||
x = x * topk_weights.unsqueeze(dim=-1)
|
|
||||||
x = x.sum(dim=-2)
|
|
||||||
x = x.reshape(orig_shape)
|
|
||||||
return x
|
|
||||||
@ -38,10 +38,6 @@ if current_platform.is_cuda_alike():
|
|||||||
else:
|
else:
|
||||||
TritonExperts = None # type: ignore
|
TritonExperts = None # type: ignore
|
||||||
|
|
||||||
if current_platform.is_tpu():
|
|
||||||
from .moe_pallas import fused_moe as fused_moe_pallas
|
|
||||||
else:
|
|
||||||
fused_moe_pallas = None # type: ignore
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -403,53 +399,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
custom_routing_function=layer.custom_routing_function,
|
custom_routing_function=layer.custom_routing_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_tpu(
|
if current_platform.is_cpu():
|
||||||
self,
|
|
||||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
|
||||||
x: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
assert not layer.use_grouped_topk
|
|
||||||
assert layer.num_expert_group is None
|
|
||||||
assert layer.topk_group is None
|
|
||||||
assert layer.custom_routing_function is None
|
|
||||||
assert layer.apply_router_weight_on_input is False
|
|
||||||
if layer.scoring_func != "softmax":
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Only softmax scoring function is supported for TPU."
|
|
||||||
)
|
|
||||||
if layer.e_score_correction_bias is not None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Expert score correction bias is not supported for TPU."
|
|
||||||
)
|
|
||||||
assert layer.activation == "silu", (
|
|
||||||
f"{layer.activation} is not supported for TPU."
|
|
||||||
)
|
|
||||||
assert layer.routed_scaling_factor == 1.0, (
|
|
||||||
f"routed_scaling_factor {layer.routed_scaling_factor} is "
|
|
||||||
"not supported for TPU."
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
layer.enable_eplb is not False
|
|
||||||
or layer.expert_load_view is not None
|
|
||||||
or layer.logical_to_physical_map is not None
|
|
||||||
or layer.logical_replica_count is not None
|
|
||||||
):
|
|
||||||
raise NotImplementedError("Expert load balancing is not supported for TPU.")
|
|
||||||
return fused_moe_pallas(
|
|
||||||
hidden_states=x,
|
|
||||||
w1=layer.w13_weight,
|
|
||||||
w2=layer.w2_weight,
|
|
||||||
topk=layer.top_k,
|
|
||||||
gating_output=router_logits,
|
|
||||||
global_num_experts=layer.global_num_experts,
|
|
||||||
expert_map=layer.expert_map,
|
|
||||||
renormalize=layer.renormalize,
|
|
||||||
)
|
|
||||||
|
|
||||||
if current_platform.is_tpu():
|
|
||||||
forward_native = forward_tpu
|
|
||||||
elif current_platform.is_cpu():
|
|
||||||
forward_native = forward_cpu
|
forward_native = forward_cpu
|
||||||
elif current_platform.is_xpu():
|
elif current_platform.is_xpu():
|
||||||
forward_native = forward_xpu
|
forward_native = forward_xpu
|
||||||
|
|||||||
@ -11,7 +11,6 @@ logger = init_logger(__name__)
|
|||||||
QuantizationMethods = Literal[
|
QuantizationMethods = Literal[
|
||||||
"awq",
|
"awq",
|
||||||
"deepspeedfp",
|
"deepspeedfp",
|
||||||
"tpu_int8",
|
|
||||||
"fp8",
|
"fp8",
|
||||||
"ptpc_fp8",
|
"ptpc_fp8",
|
||||||
"fbgemm_fp8",
|
"fbgemm_fp8",
|
||||||
@ -130,12 +129,10 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
|||||||
from .ptpc_fp8 import PTPCFp8Config
|
from .ptpc_fp8 import PTPCFp8Config
|
||||||
from .rtn import RTNConfig
|
from .rtn import RTNConfig
|
||||||
from .torchao import TorchAOConfig
|
from .torchao import TorchAOConfig
|
||||||
from .tpu_int8 import Int8TpuConfig
|
|
||||||
|
|
||||||
method_to_config: dict[str, type[QuantizationConfig]] = {
|
method_to_config: dict[str, type[QuantizationConfig]] = {
|
||||||
"awq": AWQConfig,
|
"awq": AWQConfig,
|
||||||
"deepspeedfp": DeepSpeedFPConfig,
|
"deepspeedfp": DeepSpeedFPConfig,
|
||||||
"tpu_int8": Int8TpuConfig,
|
|
||||||
"fp8": Fp8Config,
|
"fp8": Fp8Config,
|
||||||
"fbgemm_fp8": FBGEMMFp8Config,
|
"fbgemm_fp8": FBGEMMFp8Config,
|
||||||
"fp_quant": FPQuantConfig,
|
"fp_quant": FPQuantConfig,
|
||||||
|
|||||||
@ -19,9 +19,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
|
|||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
||||||
TritonScaledMMLinearKernel,
|
TritonScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
|
|
||||||
XLAScaledMMLinearKernel,
|
|
||||||
)
|
|
||||||
from vllm.platforms import PlatformEnum, current_platform
|
from vllm.platforms import PlatformEnum, current_platform
|
||||||
|
|
||||||
# in priority/performance order (when available)
|
# in priority/performance order (when available)
|
||||||
@ -29,7 +26,6 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
|||||||
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
|
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
|
||||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||||
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||||
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,106 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from functorch.experimental.control_flow import cond # noqa: F401
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
||||||
convert_to_channelwise,
|
|
||||||
)
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
|
||||||
|
|
||||||
|
|
||||||
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
|
||||||
@classmethod
|
|
||||||
def is_supported(
|
|
||||||
cls, compute_capability: int | None = None
|
|
||||||
) -> tuple[bool, str | None]:
|
|
||||||
if not current_platform.is_tpu():
|
|
||||||
return False, "Requires TPU."
|
|
||||||
return True, None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
|
||||||
if not current_platform.is_tpu():
|
|
||||||
return False, "ScaledMMXLA requires running on TPU."
|
|
||||||
|
|
||||||
if c.is_static_input_scheme:
|
|
||||||
return False, "ScaledMMXLA requires dynamic activation scales."
|
|
||||||
|
|
||||||
if not c.input_symmetric:
|
|
||||||
return False, "ScaledMMXLA requires symmetric activation scales."
|
|
||||||
|
|
||||||
if not c.is_channelwise:
|
|
||||||
return False, "ScaledMMXLA requires channelwise weight scales"
|
|
||||||
|
|
||||||
return True, None
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
|
||||||
# WEIGHT
|
|
||||||
# [out, in] (different than cutlass_scaled_mm)
|
|
||||||
weight = getattr(layer, self.w_q_name)
|
|
||||||
replace_parameter(
|
|
||||||
layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
# WEIGHT SCALE
|
|
||||||
# XLA kernels support only per-tensor and per-channel.
|
|
||||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
|
||||||
# scales being passed to the kernel), convert to the per-channel case.
|
|
||||||
is_fused_module = len(layer.logical_widths) > 1
|
|
||||||
weight_scale = getattr(layer, self.w_s_name)
|
|
||||||
if is_fused_module and not self.config.is_channelwise:
|
|
||||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
|
||||||
|
|
||||||
# [out_channel,] (different than cutlass_scaled_mm)
|
|
||||||
weight_scale = weight_scale.squeeze(-1)
|
|
||||||
replace_parameter(
|
|
||||||
layer,
|
|
||||||
self.w_s_name,
|
|
||||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only support symmetric dynamic activation quantization.
|
|
||||||
setattr(layer, self.i_s_name, None)
|
|
||||||
setattr(layer, self.i_zp_name, None)
|
|
||||||
setattr(layer, self.azp_adj_name, None)
|
|
||||||
|
|
||||||
# Filter warning for cond usage in apply_weights. It is okay
|
|
||||||
# to specialize the graph since bias is not dynamic.
|
|
||||||
warnings.filterwarnings(
|
|
||||||
"ignore",
|
|
||||||
message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.", # noqa: E501
|
|
||||||
)
|
|
||||||
|
|
||||||
def no_add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
|
|
||||||
return x
|
|
||||||
|
|
||||||
def add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
|
|
||||||
return x + bias
|
|
||||||
|
|
||||||
def apply_weights(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
bias: torch.Tensor | None = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
w_q, w_s, _, _, _ = self._get_weight_params(layer)
|
|
||||||
|
|
||||||
# Required to register custom ops.
|
|
||||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
|
||||||
|
|
||||||
out = torch.ops.xla.quantized_matmul_int8(
|
|
||||||
x,
|
|
||||||
w_q,
|
|
||||||
w_s,
|
|
||||||
quantize_activation=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Explicitly capture control flow to make dynamo happy.
|
|
||||||
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
|
|
||||||
return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
|
|
||||||
@ -1,139 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn import Module
|
|
||||||
from torch.nn.parameter import Parameter
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
|
||||||
from vllm.model_executor.layers.quantization import (
|
|
||||||
QuantizationConfig,
|
|
||||||
QuantizationMethods,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.parameter import ModelWeightParameter
|
|
||||||
|
|
||||||
ACTIVATION_SCHEMES = ["none", "dynamic"]
|
|
||||||
|
|
||||||
|
|
||||||
class Int8TpuConfig(QuantizationConfig):
|
|
||||||
"""Int8 Quantization Config class for TPU Backend."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
activation_scheme: str = "none",
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
if activation_scheme not in ACTIVATION_SCHEMES:
|
|
||||||
raise ValueError(f"Unsupported activation scheme {activation_scheme}")
|
|
||||||
self.activation_scheme = activation_scheme
|
|
||||||
|
|
||||||
def get_name(self) -> QuantizationMethods:
|
|
||||||
return "tpu_int8"
|
|
||||||
|
|
||||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
|
||||||
return [torch.float16, torch.bfloat16]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_min_capability(cls) -> int:
|
|
||||||
raise NotImplementedError("This function should not be called with TPU Backend")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_config_filenames() -> list[str]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig":
|
|
||||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
|
||||||
return cls(activation_scheme=activation_scheme)
|
|
||||||
|
|
||||||
def get_quant_method(
|
|
||||||
self, layer: Module, prefix: str
|
|
||||||
) -> Optional["TPUInt8LinearMethod"]:
|
|
||||||
if isinstance(layer, LinearBase):
|
|
||||||
return TPUInt8LinearMethod(self)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class TPUInt8LinearMethod(LinearMethodBase):
|
|
||||||
"""Int8 Linear method for TPU Quant."""
|
|
||||||
|
|
||||||
def __init__(self, quant_config: Int8TpuConfig):
|
|
||||||
self.quant_config = quant_config
|
|
||||||
self.quantize_activation = False
|
|
||||||
if self.quant_config.activation_scheme == "dynamic":
|
|
||||||
self.quantize_activation = True
|
|
||||||
|
|
||||||
def create_weights(
|
|
||||||
self,
|
|
||||||
layer: Module,
|
|
||||||
input_size_per_partition: int,
|
|
||||||
output_partition_sizes: list[int],
|
|
||||||
input_size: int,
|
|
||||||
output_size: int,
|
|
||||||
params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
):
|
|
||||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
|
||||||
weight = ModelWeightParameter(
|
|
||||||
data=torch.empty(
|
|
||||||
sum(output_partition_sizes),
|
|
||||||
input_size_per_partition,
|
|
||||||
dtype=params_dtype,
|
|
||||||
),
|
|
||||||
input_dim=1,
|
|
||||||
output_dim=0,
|
|
||||||
weight_loader=weight_loader,
|
|
||||||
)
|
|
||||||
layer.register_parameter("weight", weight)
|
|
||||||
|
|
||||||
def _quantize_weight(
|
|
||||||
self, weight: torch.Tensor
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
weight_dtype = weight.dtype
|
|
||||||
weight = weight.cpu().to(torch.float32)
|
|
||||||
n_bit = 8
|
|
||||||
eps = 1e-5
|
|
||||||
max_int = 2 ** (n_bit - 1) - 1
|
|
||||||
min_int = -(2 ** (n_bit - 1))
|
|
||||||
max_val = weight.abs().amax(dim=-1, keepdim=True)
|
|
||||||
max_val = max_val.clamp(min=eps)
|
|
||||||
qscale = max_val / max_int
|
|
||||||
qweight = torch.clamp(
|
|
||||||
torch.round(weight * (1.0 / qscale)), min_int, max_int
|
|
||||||
).to(torch.int8)
|
|
||||||
qscale = qscale.squeeze().to(weight_dtype)
|
|
||||||
return qweight, qscale
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
|
||||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
|
||||||
device = layer.weight.device
|
|
||||||
qweight, qscale = self._quantize_weight(layer.weight)
|
|
||||||
qweight = qweight.to(device)
|
|
||||||
qscale = qscale.to(device)
|
|
||||||
layer.weight = Parameter(qweight, requires_grad=False)
|
|
||||||
layer.scale = Parameter(qscale, requires_grad=False)
|
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
bias: torch.Tensor | None = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
try:
|
|
||||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
|
||||||
except ImportError as err:
|
|
||||||
raise ImportError(
|
|
||||||
"Please install torch_xla by following the instructions at "
|
|
||||||
"https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501
|
|
||||||
"to run vLLM on TPU."
|
|
||||||
) from err
|
|
||||||
weight = layer.weight
|
|
||||||
scale = layer.scale
|
|
||||||
out = torch.ops.xla.quantized_matmul_int8(
|
|
||||||
x, weight, scale, quantize_activation=self.quantize_activation
|
|
||||||
)
|
|
||||||
if bias is not None:
|
|
||||||
out = out + bias
|
|
||||||
return out
|
|
||||||
@ -30,7 +30,6 @@ from vllm.model_executor.model_loader.weight_utils import (
|
|||||||
pt_weights_iterator,
|
pt_weights_iterator,
|
||||||
safetensors_weights_iterator,
|
safetensors_weights_iterator,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -241,22 +240,6 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
self.load_config.pt_load_map_location,
|
self.load_config.pt_load_map_location,
|
||||||
)
|
)
|
||||||
|
|
||||||
if current_platform.is_tpu():
|
|
||||||
from vllm.platforms.tpu import USE_TPU_INFERENCE
|
|
||||||
|
|
||||||
if not USE_TPU_INFERENCE:
|
|
||||||
# In PyTorch XLA, we should call `torch_xla.sync`
|
|
||||||
# frequently so that not too many ops are accumulated
|
|
||||||
# in the XLA program.
|
|
||||||
import torch_xla
|
|
||||||
|
|
||||||
def _xla_weights_iterator(iterator: Generator):
|
|
||||||
for weights in iterator:
|
|
||||||
yield weights
|
|
||||||
torch_xla.sync(wait=False)
|
|
||||||
|
|
||||||
weights_iterator = _xla_weights_iterator(weights_iterator)
|
|
||||||
|
|
||||||
if self.counter_before_loading_weights == 0.0:
|
if self.counter_before_loading_weights == 0.0:
|
||||||
self.counter_before_loading_weights = time.perf_counter()
|
self.counter_before_loading_weights = time.perf_counter()
|
||||||
# Apply the prefix.
|
# Apply the prefix.
|
||||||
|
|||||||
@ -1,118 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
import time
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch_xla.core.xla_model as xm
|
|
||||||
import torch_xla.distributed.spmd as xs
|
|
||||||
|
|
||||||
from vllm.config import ModelConfig, VllmConfig
|
|
||||||
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
|
||||||
from vllm.model_executor.model_loader.utils import (
|
|
||||||
initialize_model,
|
|
||||||
process_weights_after_loading,
|
|
||||||
)
|
|
||||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class TPUModelLoader(DefaultModelLoader):
|
|
||||||
"""
|
|
||||||
A TPU model loader for model loading under SPMD mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def load_model(
|
|
||||||
self,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
model_config: ModelConfig,
|
|
||||||
mesh: xs.Mesh | None = None,
|
|
||||||
) -> nn.Module:
|
|
||||||
# Initialize model and load weights on CPU. Then, during SPMD partition,
|
|
||||||
# weights are sharded and transferred to TPUs.
|
|
||||||
self.counter_before_loading_weights = time.perf_counter()
|
|
||||||
model_config = vllm_config.model_config
|
|
||||||
assert model_config.quantization is None, "Quantization not supported"
|
|
||||||
target_device = torch.device("cpu")
|
|
||||||
with set_default_torch_dtype(model_config.dtype):
|
|
||||||
with target_device:
|
|
||||||
model = initialize_model(vllm_config=vllm_config)
|
|
||||||
|
|
||||||
load_format = vllm_config.load_config.load_format
|
|
||||||
if load_format != "dummy":
|
|
||||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
|
||||||
all_weights = self.get_all_weights(model_config, model)
|
|
||||||
loaded_weights = model.load_weights(all_weights)
|
|
||||||
self.counter_after_loading_weights = time.perf_counter()
|
|
||||||
logger.info(
|
|
||||||
"Loading weights took %.2f seconds",
|
|
||||||
self.counter_after_loading_weights
|
|
||||||
- self.counter_before_loading_weights,
|
|
||||||
)
|
|
||||||
# We only enable strict check for non-quantized models
|
|
||||||
# that have loaded weights tracking currently.
|
|
||||||
if model_config.quantization is None and loaded_weights is not None:
|
|
||||||
weights_not_loaded = weights_to_load - loaded_weights
|
|
||||||
if weights_not_loaded:
|
|
||||||
raise ValueError(
|
|
||||||
"Following weights were not initialized from "
|
|
||||||
f"checkpoint: {weights_not_loaded}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info("Use dummy weight during weight loading.")
|
|
||||||
|
|
||||||
process_weights_after_loading(model, model_config, target_device)
|
|
||||||
|
|
||||||
counter_before_partition = time.perf_counter()
|
|
||||||
model = model.eval()
|
|
||||||
model = model.to("xla")
|
|
||||||
shard_model(model, mesh)
|
|
||||||
counter_after_partition = time.perf_counter()
|
|
||||||
logger.info(
|
|
||||||
"Partition model took %.2f seconds",
|
|
||||||
counter_after_partition - counter_before_partition,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure the model is properly loaded.
|
|
||||||
self._check_model_is_loaded(mesh, model)
|
|
||||||
|
|
||||||
# Need to torch compile after model sharding are done. Because the
|
|
||||||
# compiler hints ('xs.mark_sharding') are torch ops.
|
|
||||||
if not model_config.is_multimodal_model:
|
|
||||||
model.model = torch.compile(model.model, backend="openxla")
|
|
||||||
else:
|
|
||||||
model.language_model.model = torch.compile(
|
|
||||||
model.language_model.model, backend="openxla"
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
def _check_model_is_loaded(self, mesh: xs.Mesh | None, model: nn.Module) -> None:
|
|
||||||
"""
|
|
||||||
Ensure the model is properly loaded.
|
|
||||||
1. All model parameters and buffers are on XLA device.
|
|
||||||
2. Non-SPMD friendly layers are replaced as expected.
|
|
||||||
"""
|
|
||||||
device = xm.xla_device()
|
|
||||||
device_type = str(device.type)
|
|
||||||
|
|
||||||
# Check parameters
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
assert param.device.type == device_type, (
|
|
||||||
f"Parameter {name} is on {param.device.type} instead of {device_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check buffers
|
|
||||||
for name, buffer in model.named_buffers():
|
|
||||||
assert buffer.device.type == device_type, (
|
|
||||||
f"Buffer {name} is on {buffer.device.type} instead of {device_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
for module in model.modules():
|
|
||||||
if (mesh is not None) and (get_fqn(module) == "QKVParallelLinear"):
|
|
||||||
raise AssertionError(
|
|
||||||
"QKVParallelLinear should be replaced by \
|
|
||||||
XlaQKVParallelLinear under SPMD mode."
|
|
||||||
)
|
|
||||||
@ -1,287 +1,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import contextlib
|
|
||||||
from typing import TYPE_CHECKING, Optional, cast
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from tpu_info import device
|
|
||||||
|
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
from vllm.inputs import ProcessorInputs, PromptType
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from .interface import Platform, PlatformEnum
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from typing import TypeAlias
|
|
||||||
|
|
||||||
from vllm.attention.selector import AttentionSelectorConfig
|
|
||||||
from vllm.config import VllmConfig
|
|
||||||
from vllm.config.cache import BlockSize
|
|
||||||
from vllm.pooling_params import PoolingParams
|
|
||||||
from vllm.sampling_params import SamplingParams
|
|
||||||
|
|
||||||
ParamsType: TypeAlias = SamplingParams | PoolingParams
|
|
||||||
else:
|
|
||||||
BlockSize = None
|
|
||||||
VllmConfig = None
|
|
||||||
PoolingParams = None
|
|
||||||
ParamsType = None
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
USE_TPU_INFERENCE = False
|
|
||||||
|
|
||||||
|
|
||||||
class TpuPlatform(Platform):
|
|
||||||
_enum = PlatformEnum.TPU
|
|
||||||
device_name: str = "tpu"
|
|
||||||
device_type: str = "tpu"
|
|
||||||
dispatch_key: str = "XLA"
|
|
||||||
ray_device_key: str = "TPU"
|
|
||||||
dist_backend: str = "gloo"
|
|
||||||
device_control_env_var: str = "TPU_VISIBLE_CHIPS"
|
|
||||||
simple_compile_backend: str = "openxla"
|
|
||||||
|
|
||||||
supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"]
|
|
||||||
|
|
||||||
additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def import_kernels(cls) -> None:
|
|
||||||
# Do not import vllm._C
|
|
||||||
with contextlib.suppress(ImportError):
|
|
||||||
import vllm._moe_C # noqa: F401
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_attn_backend_cls(
|
|
||||||
cls,
|
|
||||||
selected_backend: "AttentionBackendEnum",
|
|
||||||
attn_selector_config: "AttentionSelectorConfig",
|
|
||||||
) -> str:
|
|
||||||
if attn_selector_config.use_sparse:
|
|
||||||
raise NotImplementedError("Sparse Attention is not supported on TPU.")
|
|
||||||
if selected_backend != AttentionBackendEnum.PALLAS:
|
|
||||||
logger.info("Cannot use %s backend on TPU.", selected_backend)
|
|
||||||
|
|
||||||
logger.info("Using Pallas V1 backend.")
|
|
||||||
return AttentionBackendEnum.PALLAS.get_path()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
|
|
||||||
return [
|
|
||||||
AttentionBackendEnum.PALLAS,
|
|
||||||
]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_vit_attn_backend(
|
|
||||||
cls,
|
|
||||||
head_size: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
backend: Optional["AttentionBackendEnum"] = None,
|
|
||||||
) -> "AttentionBackendEnum":
|
|
||||||
if backend is not None:
|
|
||||||
assert backend in cls.get_supported_vit_attn_backends(), (
|
|
||||||
f"Backend {backend} is not supported for vit attention"
|
|
||||||
f"Supported backends are: {cls.get_supported_vit_attn_backends()}."
|
|
||||||
)
|
|
||||||
logger.info_once(f"Using backend {backend} for vit attention.")
|
|
||||||
return backend
|
|
||||||
|
|
||||||
logger.info_once(
|
|
||||||
f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention."
|
|
||||||
)
|
|
||||||
return AttentionBackendEnum.PALLAS
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def set_device(cls, device: torch.device) -> None:
|
|
||||||
"""
|
|
||||||
Set the device for the current platform.
|
|
||||||
"""
|
|
||||||
torch.tpu.set_device(device)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_device_name(cls, device_id: int = 0) -> str:
|
|
||||||
chip_type, _ = device.get_local_chips()
|
|
||||||
return f"TPU {chip_type.name}"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_punica_wrapper(cls) -> str:
|
|
||||||
return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
|
|
||||||
return torch.finfo(dtype).min, torch.finfo(dtype).max
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def can_update_inplace(cls):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_lora_vocab_padding_size(cls) -> int:
|
|
||||||
return 1
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def inference_mode(cls):
|
|
||||||
return torch.no_grad()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
|
||||||
from vllm.config import CompilationMode, CUDAGraphMode
|
|
||||||
|
|
||||||
cache_config = vllm_config.cache_config
|
|
||||||
# For v0, the default block size is 16.
|
|
||||||
if cache_config and cache_config.block_size is None:
|
|
||||||
cache_config.block_size = cast(BlockSize, 16)
|
|
||||||
compilation_config = vllm_config.compilation_config
|
|
||||||
|
|
||||||
# TPU only supports DYNAMO_TRACE_ONCE compilation mode
|
|
||||||
if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE:
|
|
||||||
logger.info(
|
|
||||||
"[TPU] Forcing DYNAMO_TRACE_ONCE compilation mode, and\
|
|
||||||
disabling cudagraph."
|
|
||||||
)
|
|
||||||
compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
|
|
||||||
|
|
||||||
if (
|
|
||||||
compilation_config.cudagraph_mode is None
|
|
||||||
or compilation_config.cudagraph_mode.max_cudagraph_mode()
|
|
||||||
!= CUDAGraphMode.NONE
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
"[TPU] CUDA graph is not supported on TPU, disabling cudagraphs."
|
|
||||||
)
|
|
||||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
|
||||||
|
|
||||||
if compilation_config.backend == "":
|
|
||||||
compilation_config.backend = "openxla"
|
|
||||||
|
|
||||||
assert vllm_config.speculative_config is None, (
|
|
||||||
"TPU does not support speculative decoding"
|
|
||||||
)
|
|
||||||
|
|
||||||
model_config = vllm_config.model_config
|
|
||||||
if model_config is not None and model_config.dtype in (
|
|
||||||
torch.float16,
|
|
||||||
torch.float32,
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"The TPU backend currently does not support %s. "
|
|
||||||
"Using bfloat16 instead.",
|
|
||||||
model_config.dtype,
|
|
||||||
)
|
|
||||||
model_config.dtype = torch.bfloat16
|
|
||||||
|
|
||||||
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
|
|
||||||
|
|
||||||
cache_config.block_size = PallasAttentionBackend.get_page_size(vllm_config) # type: ignore[assignment]
|
|
||||||
|
|
||||||
parallel_config = vllm_config.parallel_config
|
|
||||||
scheduler_config = vllm_config.scheduler_config
|
|
||||||
if parallel_config.worker_cls == "auto":
|
|
||||||
parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"
|
|
||||||
|
|
||||||
assert not vllm_config.speculative_config, (
|
|
||||||
"Speculative decoding is not yet supported for TPU backend"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
scheduler_config.is_multimodal_model
|
|
||||||
and not scheduler_config.disable_chunked_mm_input
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"TPU does not support running Multimodal models"
|
|
||||||
" without setting `--disable_chunked_mm_input`. "
|
|
||||||
"Forcing --disable_chunked_mm_input."
|
|
||||||
)
|
|
||||||
scheduler_config.disable_chunked_mm_input = True
|
|
||||||
|
|
||||||
if model_config and model_config.use_mla:
|
|
||||||
logger.info(
|
|
||||||
"MLA is enabled on a non-GPU platform; forcing chunked "
|
|
||||||
"prefill and prefix caching to be disabled."
|
|
||||||
)
|
|
||||||
vllm_config.scheduler_config.enable_chunked_prefill = False
|
|
||||||
vllm_config.scheduler_config.max_num_batched_tokens = max(
|
|
||||||
vllm_config.model_config.max_model_len,
|
|
||||||
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_pin_memory_available(cls):
|
|
||||||
logger.warning("Pin memory is not supported on TPU.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_device_communicator_cls(cls) -> str:
|
|
||||||
return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate_request(
|
|
||||||
cls,
|
|
||||||
prompt: PromptType,
|
|
||||||
params: ParamsType,
|
|
||||||
processed_inputs: ProcessorInputs,
|
|
||||||
) -> None:
|
|
||||||
"""Raises if this request is unsupported on this platform"""
|
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
|
||||||
|
|
||||||
if (
|
|
||||||
isinstance(params, SamplingParams)
|
|
||||||
and params.sampling_type == SamplingType.RANDOM_SEED
|
|
||||||
):
|
|
||||||
raise ValueError("Torch XLA does not support per-request seed.")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@torch.compile(backend="openxla")
|
|
||||||
def insert_blocks_to_device(
|
|
||||||
cls,
|
|
||||||
src_cache: torch.Tensor,
|
|
||||||
dst_cache: torch.Tensor,
|
|
||||||
src_block_indices: torch.Tensor,
|
|
||||||
dst_block_indices: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
|
|
||||||
dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@torch.compile(backend="openxla")
|
|
||||||
def swap_out_blocks_to_host(
|
|
||||||
cls,
|
|
||||||
src_cache: torch.Tensor,
|
|
||||||
dst_cache: torch.Tensor,
|
|
||||||
src_block_indices: torch.Tensor,
|
|
||||||
dst_block_indices: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
"""tpu blocks to cpu blocks"""
|
|
||||||
torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
|
|
||||||
dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def use_sync_weight_loader(cls) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def check_max_model_len(cls, max_model_len: int) -> int:
|
|
||||||
"""
|
|
||||||
Check max_model_len for the current platform.
|
|
||||||
"""
|
|
||||||
logger.warning(
|
|
||||||
"--max-model-len is not specified, "
|
|
||||||
"it's currently using model's default length %d, "
|
|
||||||
"which might be too large."
|
|
||||||
"Please input with --max-model-len based on your "
|
|
||||||
"request input length and output length, to avoid "
|
|
||||||
"unnecessary degradation.",
|
|
||||||
max_model_len,
|
|
||||||
)
|
|
||||||
return max_model_len
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from tpu_inference.platforms import (
|
from tpu_inference.platforms import (
|
||||||
@ -291,5 +14,7 @@ try:
|
|||||||
TpuPlatform = TpuInferencePlatform # type: ignore
|
TpuPlatform = TpuInferencePlatform # type: ignore
|
||||||
USE_TPU_INFERENCE = True
|
USE_TPU_INFERENCE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.info("tpu_inference not found, using vLLM's TpuPlatform")
|
logger.error(
|
||||||
|
"tpu_inference not found, please install tpu_inference to run vllm on TPU"
|
||||||
|
)
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -186,20 +186,6 @@ class UsageMessage:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _report_torch_xla_usage(self) -> bool:
|
|
||||||
try:
|
|
||||||
import torch_xla
|
|
||||||
|
|
||||||
self.gpu_count = torch_xla.runtime.world_size()
|
|
||||||
self.gpu_type = torch_xla.tpu.get_tpu_type()
|
|
||||||
self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[
|
|
||||||
"bytes_limit"
|
|
||||||
]
|
|
||||||
self.cuda_runtime = "torch_xla"
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _report_usage_once(
|
def _report_usage_once(
|
||||||
self,
|
self,
|
||||||
model_architecture: str,
|
model_architecture: str,
|
||||||
@ -217,9 +203,7 @@ class UsageMessage:
|
|||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
self.cuda_runtime = torch.version.cuda
|
self.cuda_runtime = torch.version.cuda
|
||||||
if current_platform.is_tpu(): # noqa: SIM102
|
if current_platform.is_tpu(): # noqa: SIM102
|
||||||
if (not self._report_tpu_inference_usage()) and (
|
if not self._report_tpu_inference_usage():
|
||||||
not self._report_torch_xla_usage()
|
|
||||||
):
|
|
||||||
logger.exception("Failed to collect TPU information")
|
logger.exception("Failed to collect TPU information")
|
||||||
self.provider = _detect_cloud_provider()
|
self.provider = _detect_cloud_provider()
|
||||||
self.architecture = platform.machine()
|
self.architecture = platform.machine()
|
||||||
|
|||||||
@ -1,436 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (
|
|
||||||
AttentionBackend,
|
|
||||||
AttentionImpl,
|
|
||||||
AttentionLayer,
|
|
||||||
AttentionType,
|
|
||||||
)
|
|
||||||
from vllm.config import VllmConfig
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.utils.math_utils import cdiv, next_power_of_2
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
# TPU requires the head size to be a multiple of 128.
|
|
||||||
TPU_HEAD_SIZE_ALIGNMENT = 128
|
|
||||||
|
|
||||||
# Note: TPU can fp8 as storage dtype but doesn't support converting from uint8
|
|
||||||
# from to fp32 directly. That's why it has a dtype mapping different from GPU
|
|
||||||
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
|
|
||||||
"half": torch.half,
|
|
||||||
"bfloat16": torch.bfloat16,
|
|
||||||
"float": torch.float,
|
|
||||||
"fp8": torch.float8_e4m3fn,
|
|
||||||
"fp8_e4m3": torch.float8_e4m3fn,
|
|
||||||
"fp8_e5m2": torch.float8_e5m2,
|
|
||||||
"int8": torch.int8,
|
|
||||||
"uint8": torch.uint8,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
import tpu_inference # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
# Lazy import torch_xla
|
|
||||||
import torch_xla.core.xla_builder as xb
|
|
||||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
|
||||||
from torch.library import impl
|
|
||||||
from torch_xla._internal.jax_workarounds import requires_jax
|
|
||||||
from torch_xla.experimental.custom_kernel import XLA_LIB
|
|
||||||
|
|
||||||
@requires_jax
|
|
||||||
def kv_cache_update_op_impl(
|
|
||||||
kv: torch.Tensor,
|
|
||||||
slot_mapping: torch.Tensor,
|
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
num_kv_update_slices: torch.Tensor,
|
|
||||||
page_size: int,
|
|
||||||
num_slices_per_block: int,
|
|
||||||
):
|
|
||||||
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
|
|
||||||
|
|
||||||
new_kv_cache = xb.call_jax(
|
|
||||||
kv_cache_update,
|
|
||||||
(kv, slot_mapping, kv_cache, num_kv_update_slices),
|
|
||||||
{"page_size": page_size, "num_slices_per_block": num_slices_per_block},
|
|
||||||
)
|
|
||||||
return new_kv_cache
|
|
||||||
|
|
||||||
XLA_LIB.define(
|
|
||||||
"kv_cache_update_op(Tensor kv, Tensor slot_mapping,"
|
|
||||||
"Tensor kv_cache, Tensor num_kv_update_slices, int page_size,"
|
|
||||||
"int num_slices_per_block)"
|
|
||||||
"-> Tensor",
|
|
||||||
)
|
|
||||||
|
|
||||||
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
|
|
||||||
def kv_cache_update_op_xla(
|
|
||||||
kv: torch.Tensor,
|
|
||||||
slot_mapping: torch.Tensor,
|
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
num_kv_update_slices: torch.Tensor,
|
|
||||||
page_size: int,
|
|
||||||
num_slices_per_block: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
new_kv_cache = kv_cache_update_op_impl(
|
|
||||||
kv,
|
|
||||||
slot_mapping,
|
|
||||||
kv_cache,
|
|
||||||
num_kv_update_slices,
|
|
||||||
page_size,
|
|
||||||
num_slices_per_block,
|
|
||||||
)
|
|
||||||
return new_kv_cache
|
|
||||||
|
|
||||||
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
|
|
||||||
def kv_cache_update_op_non_xla(
|
|
||||||
kv: torch.Tensor,
|
|
||||||
slot_mapping: torch.Tensor,
|
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
num_kv_update_slices: torch.Tensor,
|
|
||||||
page_size: int,
|
|
||||||
num_slices_per_block: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return kv_cache
|
|
||||||
|
|
||||||
|
|
||||||
class PallasAttentionBackend(AttentionBackend):
|
|
||||||
@staticmethod
|
|
||||||
def get_name() -> str:
|
|
||||||
return "PALLAS"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
|
|
||||||
return PallasAttentionBackendImpl
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_kv_cache_shape(
|
|
||||||
num_blocks: int,
|
|
||||||
block_size: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
cache_dtype_str: str = "auto",
|
|
||||||
) -> tuple[int, ...]:
|
|
||||||
padded_head_size = (
|
|
||||||
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
|
||||||
)
|
|
||||||
return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def swap_blocks(
|
|
||||||
src_kv_cache: torch.Tensor,
|
|
||||||
dst_kv_cache: torch.Tensor,
|
|
||||||
src_to_dst: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
raise RuntimeError("swap_blocks is not used for the TPU backend.")
|
|
||||||
|
|
||||||
# In recent TPU generations, up to v6e, the SMEM size is 1MB. The
|
|
||||||
# block_tables within the PallasMetadata constitute almost the entire SMEM
|
|
||||||
# requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
|
|
||||||
# we simply make sure that the size is smaller than half of SMEM capacity.
|
|
||||||
@staticmethod
|
|
||||||
def get_min_page_size(vllm_config: VllmConfig) -> int:
|
|
||||||
max_num_page_per_req = (
|
|
||||||
1024 * 1024 // 2 // vllm_config.scheduler_config.max_num_seqs // 4
|
|
||||||
)
|
|
||||||
min_page_size = cdiv(
|
|
||||||
vllm_config.model_config.max_model_len, max_num_page_per_req
|
|
||||||
)
|
|
||||||
min_page_size = 1 << (min_page_size - 1).bit_length()
|
|
||||||
return min_page_size
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_max_num_seqs(model_len: int, page_size: int) -> int:
|
|
||||||
num_page_per_req = cdiv(model_len, page_size)
|
|
||||||
return 1024 * 1024 // 2 // num_page_per_req // 4
|
|
||||||
|
|
||||||
# TPU has limited SREGs (scalar registers), if page_size is too small, we
|
|
||||||
# can spill SREGs easily which leads to bad performance. The strategy we
|
|
||||||
# apply here is trying to split max-model-len to 16 pages which make the
|
|
||||||
# spill less likely. Meanwhile we make sure the page size is in [16, 256].
|
|
||||||
@staticmethod
|
|
||||||
def get_page_size(vllm_config: VllmConfig) -> int:
|
|
||||||
# TODO: This is a temporary fix for vmem OOM.
|
|
||||||
# For long model length, we use 16 page-size to avoid too much
|
|
||||||
# VMEM spill. A more robust solution should be implemented to
|
|
||||||
# handle VREG spills.
|
|
||||||
if vllm_config.model_config.max_model_len > 8192:
|
|
||||||
return 16
|
|
||||||
page_size = next_power_of_2(vllm_config.model_config.max_model_len) // 16
|
|
||||||
if page_size <= 16:
|
|
||||||
return 16
|
|
||||||
if page_size >= 256:
|
|
||||||
return 256
|
|
||||||
return page_size
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PallasMetadata:
|
|
||||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
|
||||||
# |---------- N-1 iteration --------|
|
|
||||||
# |---------------- N iteration ---------------------|
|
|
||||||
# |- tokenA -|......................|-- newTokens ---|
|
|
||||||
# |---------- context_len ----------|
|
|
||||||
# |-------------------- seq_len ---------------------|
|
|
||||||
# |-- query_len ---|
|
|
||||||
|
|
||||||
# Used in the PallasAttentionBackendImpl
|
|
||||||
slot_mapping: torch.Tensor
|
|
||||||
block_tables: torch.Tensor
|
|
||||||
context_lens: torch.Tensor
|
|
||||||
query_start_loc: torch.Tensor
|
|
||||||
num_seqs: torch.Tensor
|
|
||||||
num_kv_update_slices: torch.Tensor
|
|
||||||
num_slices_per_kv_cache_update_block: int
|
|
||||||
|
|
||||||
|
|
||||||
class PallasAttentionBackendImpl(AttentionImpl):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
scale: float,
|
|
||||||
num_kv_heads: int,
|
|
||||||
alibi_slopes: list[float] | None,
|
|
||||||
sliding_window: int | None,
|
|
||||||
kv_cache_dtype: str,
|
|
||||||
logits_soft_cap: float | None = None,
|
|
||||||
attn_type: str = AttentionType.DECODER,
|
|
||||||
kv_sharing_target_layer_name: int | None = None,
|
|
||||||
) -> None:
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_size = head_size
|
|
||||||
self.scale = float(scale)
|
|
||||||
self.num_kv_heads = num_kv_heads
|
|
||||||
self.sliding_window = sliding_window
|
|
||||||
self.logits_soft_cap = logits_soft_cap
|
|
||||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
|
||||||
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
|
||||||
if alibi_slopes is not None:
|
|
||||||
raise NotImplementedError("Alibi slopes is not supported.")
|
|
||||||
|
|
||||||
if attn_type != AttentionType.DECODER:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Encoder self-attention and "
|
|
||||||
"encoder/decoder cross-attention "
|
|
||||||
"are not implemented for "
|
|
||||||
"PallasAttentionBackendImpl"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.kv_cache_quantized_dtype = None
|
|
||||||
if kv_cache_dtype != "auto":
|
|
||||||
self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
|
|
||||||
kv_cache_dtype.lower().strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
layer: AttentionLayer,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
attn_metadata: PallasMetadata,
|
|
||||||
output: torch.Tensor | None = None,
|
|
||||||
output_scale: torch.Tensor | None = None,
|
|
||||||
output_block_scale: torch.Tensor | None = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Forward pass with Pallas attention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: shape = [num_tokens, num_heads * head_size]
|
|
||||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
|
||||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
|
||||||
kv_cache: shape =
|
|
||||||
[num_blocks, block_size, num_kv_heads * 2, head_size]
|
|
||||||
attn_metadata: Metadata for attention.
|
|
||||||
Returns:
|
|
||||||
shape = [num_tokens, num_heads * head_size]
|
|
||||||
"""
|
|
||||||
if output_scale is not None or output_block_scale is not None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"fused output quantization is not yet supported"
|
|
||||||
" for PallasAttentionBackendImpl"
|
|
||||||
)
|
|
||||||
|
|
||||||
# For determine_available_memory case.
|
|
||||||
if kv_cache.numel() == 0:
|
|
||||||
if output is None:
|
|
||||||
output = torch.ones_like(query)
|
|
||||||
return output
|
|
||||||
|
|
||||||
num_tokens, hidden_size = query.shape
|
|
||||||
query = query.view(num_tokens, self.num_heads, self.head_size)
|
|
||||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
|
||||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
|
||||||
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
|
|
||||||
padded_head_size = (
|
|
||||||
cdiv(self.head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
|
||||||
)
|
|
||||||
query = torch.nn.functional.pad(
|
|
||||||
query, (0, padded_head_size - self.head_size), value=0.0
|
|
||||||
)
|
|
||||||
key = torch.nn.functional.pad(
|
|
||||||
key, (0, padded_head_size - self.head_size), value=0.0
|
|
||||||
)
|
|
||||||
value = torch.nn.functional.pad(
|
|
||||||
value, (0, padded_head_size - self.head_size), value=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
|
|
||||||
# Write input keys and values to the KV cache.
|
|
||||||
# Skip this if sharing KV cache with an earlier attention layer.
|
|
||||||
slot_mapping = attn_metadata.slot_mapping
|
|
||||||
write_to_kv_cache(
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
slot_mapping,
|
|
||||||
attn_metadata.num_slices_per_kv_cache_update_block,
|
|
||||||
attn_metadata.num_kv_update_slices,
|
|
||||||
self.kv_cache_quantized_dtype,
|
|
||||||
layer._k_scale_float,
|
|
||||||
layer._v_scale_float,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.kv_cache_quantized_dtype is not None and (
|
|
||||||
layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0
|
|
||||||
):
|
|
||||||
raise ValueError("k_scale_float and v_scale_float must be non-zero")
|
|
||||||
output = torch.ops.xla.ragged_paged_attention(
|
|
||||||
query,
|
|
||||||
kv_cache,
|
|
||||||
attn_metadata.context_lens,
|
|
||||||
attn_metadata.block_tables,
|
|
||||||
attn_metadata.query_start_loc,
|
|
||||||
attn_metadata.num_seqs,
|
|
||||||
# By default, the system utilizes optimized block size and
|
|
||||||
# vmem_limit_bytes parameters from the kernel repository. However,
|
|
||||||
# these can be manually adjusted for debugging if necessary.
|
|
||||||
num_kv_pages_per_block=None,
|
|
||||||
num_queries_per_block=None,
|
|
||||||
vmem_limit_bytes=None,
|
|
||||||
use_kernel=True,
|
|
||||||
sm_scale=self.scale,
|
|
||||||
sliding_window=self.sliding_window,
|
|
||||||
soft_cap=self.logits_soft_cap,
|
|
||||||
k_scale=layer._k_scale_float,
|
|
||||||
v_scale=layer._v_scale_float,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
|
|
||||||
output = output[:, :, : self.head_size]
|
|
||||||
|
|
||||||
return output.reshape(num_tokens, hidden_size)
|
|
||||||
|
|
||||||
|
|
||||||
def write_to_kv_cache(
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
slot_mapping: torch.Tensor,
|
|
||||||
num_slices_per_kv_cache_update_block: int,
|
|
||||||
num_kv_update_slices: torch.Tensor,
|
|
||||||
kv_cache_quantized_dtype: torch.dtype | None = None,
|
|
||||||
k_scale: float = 1.0,
|
|
||||||
v_scale: float = 1.0,
|
|
||||||
) -> None:
|
|
||||||
"""Write the key and values to the KV cache.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
|
||||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
|
||||||
kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size]
|
|
||||||
num_slices_per_kv_cache_update_block: int
|
|
||||||
"""
|
|
||||||
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape
|
|
||||||
head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
|
||||||
|
|
||||||
if kv_cache_quantized_dtype is not None:
|
|
||||||
dtype_info = torch.finfo(kv_cache_quantized_dtype)
|
|
||||||
key = key.to(torch.float32) / k_scale
|
|
||||||
# NOTE: clamp is added here to avoid out of range of quantized dtype
|
|
||||||
key = torch.clamp(key, dtype_info.min, dtype_info.max)
|
|
||||||
key = key.to(kv_cache_quantized_dtype)
|
|
||||||
value = value.to(torch.float32) / v_scale
|
|
||||||
value = torch.clamp(value, dtype_info.min, dtype_info.max)
|
|
||||||
value = value.to(kv_cache_quantized_dtype)
|
|
||||||
|
|
||||||
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size)
|
|
||||||
|
|
||||||
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
|
|
||||||
|
|
||||||
kv_cache = kv_cache.flatten(0, 1)
|
|
||||||
new_kv_cache = torch.ops.xla.kv_cache_update_op(
|
|
||||||
kv,
|
|
||||||
slot_mapping,
|
|
||||||
kv_cache,
|
|
||||||
num_kv_update_slices,
|
|
||||||
page_size,
|
|
||||||
num_slices_per_kv_cache_update_block,
|
|
||||||
)
|
|
||||||
# NOTE: the in-place copy will be optimized away by XLA compiler.
|
|
||||||
kv_cache.copy_(new_kv_cache)
|
|
||||||
|
|
||||||
|
|
||||||
# We can move this function to a common utils file if it's also useful for other
|
|
||||||
# hardware.
|
|
||||||
def dtype_bits(dtype: torch.dtype):
|
|
||||||
if dtype.is_floating_point:
|
|
||||||
try:
|
|
||||||
return torch.finfo(dtype).bits
|
|
||||||
except TypeError:
|
|
||||||
pass
|
|
||||||
elif dtype.is_complex:
|
|
||||||
if dtype is torch.complex32:
|
|
||||||
return 32
|
|
||||||
elif dtype is torch.complex64:
|
|
||||||
return 64
|
|
||||||
elif dtype is torch.complex128:
|
|
||||||
return 128
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
return torch.iinfo(dtype).bits
|
|
||||||
# torch.iinfo cannot support int4, int2, bits8...
|
|
||||||
except TypeError:
|
|
||||||
pass
|
|
||||||
str_dtype = str(dtype)
|
|
||||||
# support torch.int4, torch.int5, torch.uint5...
|
|
||||||
if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"):
|
|
||||||
return int(str_dtype[-1])
|
|
||||||
raise TypeError(f"Getting the bit width of {dtype} is not supported")
|
|
||||||
|
|
||||||
|
|
||||||
def get_dtype_packing(dtype):
|
|
||||||
bits = dtype_bits(dtype)
|
|
||||||
if 32 % bits != 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"The bit width must be divisible by 32, but got bits={bits}, "
|
|
||||||
"dtype={dtype}"
|
|
||||||
)
|
|
||||||
return 32 // bits
|
|
||||||
|
|
||||||
|
|
||||||
def get_page_size_bytes(
|
|
||||||
block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype
|
|
||||||
) -> int:
|
|
||||||
"""Returns the size in bytes of one page of the KV cache."""
|
|
||||||
padded_head_size = (
|
|
||||||
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
|
||||||
)
|
|
||||||
num_combined_kv_heads = num_kv_heads * 2
|
|
||||||
|
|
||||||
# NOTE: for the implicit padding in XLA
|
|
||||||
packing = get_dtype_packing(kv_cache_dtype)
|
|
||||||
num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing
|
|
||||||
|
|
||||||
kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
|
|
||||||
return (
|
|
||||||
block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bits // 8
|
|
||||||
)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@ -2,350 +2,16 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""A TPU worker class."""
|
"""A TPU worker class."""
|
||||||
|
|
||||||
import os
|
from typing import TypeVar
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import Any, TypeVar
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
import vllm.envs as envs
|
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
|
||||||
from vllm.distributed import (
|
|
||||||
ensure_model_parallel_initialized,
|
|
||||||
init_distributed_environment,
|
|
||||||
)
|
|
||||||
from vllm.distributed.kv_transfer import (
|
|
||||||
ensure_kv_transfer_initialized,
|
|
||||||
)
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
|
||||||
from vllm.model_executor import set_random_seed
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.platforms.tpu import USE_TPU_INFERENCE
|
from vllm.platforms.tpu import USE_TPU_INFERENCE
|
||||||
from vllm.tasks import SupportedTask
|
|
||||||
from vllm.utils.math_utils import cdiv
|
|
||||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
|
||||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec
|
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
|
||||||
from vllm.v1.utils import report_usage_stats
|
|
||||||
from vllm.v1.worker.utils import bind_kv_cache
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
_R = TypeVar("_R")
|
_R = TypeVar("_R")
|
||||||
|
|
||||||
if not USE_TPU_INFERENCE:
|
# TODO(weiyulin) Remove this file after adding an official way to use hardware plugin
|
||||||
logger.info("tpu_inference not found, using vLLM's TPUWorker.")
|
|
||||||
import torch_xla.core.xla_model as xm
|
|
||||||
import torch_xla.debug.profiler as xp
|
|
||||||
import torch_xla.runtime as xr
|
|
||||||
|
|
||||||
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
|
|
||||||
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
|
|
||||||
|
|
||||||
|
|
||||||
class TPUWorker:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
local_rank: int,
|
|
||||||
rank: int,
|
|
||||||
distributed_init_method: str,
|
|
||||||
is_driver_worker: bool = False,
|
|
||||||
):
|
|
||||||
self.is_driver_worker = is_driver_worker
|
|
||||||
self.vllm_config = vllm_config
|
|
||||||
self.model_config = vllm_config.model_config
|
|
||||||
self.cache_config = vllm_config.cache_config
|
|
||||||
self.lora_config = vllm_config.lora_config
|
|
||||||
self.load_config = vllm_config.load_config
|
|
||||||
self.parallel_config = vllm_config.parallel_config
|
|
||||||
self.use_spmd = envs.VLLM_XLA_USE_SPMD
|
|
||||||
self.original_parallel_config = None
|
|
||||||
if self.use_spmd:
|
|
||||||
# Under SPMD mode, distributed env is initialized as if there is
|
|
||||||
# only one worker/device.
|
|
||||||
self.original_parallel_config = self.parallel_config
|
|
||||||
self.parallel_config.tensor_parallel_size = 1
|
|
||||||
self.parallel_config.pipeline_parallel_size = 1
|
|
||||||
self.parallel_config.world_size = 1
|
|
||||||
self.scheduler_config = vllm_config.scheduler_config
|
|
||||||
self.device_config = vllm_config.device_config
|
|
||||||
self.speculative_config = vllm_config.speculative_config
|
|
||||||
self.observability_config = vllm_config.observability_config
|
|
||||||
|
|
||||||
self.parallel_config.rank = rank
|
|
||||||
self.local_rank = local_rank
|
|
||||||
self.rank = rank
|
|
||||||
self.distributed_init_method = distributed_init_method
|
|
||||||
|
|
||||||
if self.cache_config.cache_dtype == "auto":
|
|
||||||
self.cache_dtype = self.model_config.dtype
|
|
||||||
else:
|
|
||||||
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype]
|
|
||||||
|
|
||||||
if self.model_config.trust_remote_code:
|
|
||||||
# note: lazy import to avoid importing torch before initializing
|
|
||||||
from vllm.utils.import_utils import init_cached_hf_modules
|
|
||||||
|
|
||||||
init_cached_hf_modules()
|
|
||||||
|
|
||||||
# Delay profiler initialization to the start of the profiling.
|
|
||||||
# This is because in vLLM V1, MP runtime is initialized before the
|
|
||||||
# TPU Worker is initialized. The profiler server needs to start after
|
|
||||||
# MP runtime is initialized.
|
|
||||||
self.profiler = None
|
|
||||||
self.profile_dir = None
|
|
||||||
if vllm_config.profiler_config.profiler == "torch" and self.rank < 1:
|
|
||||||
# For TPU, we can only have 1 active profiler session for 1 profiler
|
|
||||||
# server. So we only profile on rank0.
|
|
||||||
self.profile_dir = vllm_config.profiler_config.torch_profiler_dir
|
|
||||||
logger.info(
|
|
||||||
"Profiling enabled. Traces will be saved to: %s", self.profile_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
|
|
||||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
|
||||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
|
||||||
|
|
||||||
def init_device(self):
|
|
||||||
os.environ["PJRT_DEVICE"] = "TPU"
|
|
||||||
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
|
|
||||||
# ring, the xla tpu compiler flag
|
|
||||||
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
|
|
||||||
# fix this. It will be removed after the bug in XLA compiler is fixed.
|
|
||||||
os.environ["LIBTPU_INIT_ARGS"] = (
|
|
||||||
os.environ.get("LIBTPU_INIT_ARGS", "")
|
|
||||||
+ " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
|
|
||||||
" --xla_jf_conv_input_fusion=False"
|
|
||||||
)
|
|
||||||
# --xla_jf_conv_input_fusion=False is used to improve the perf of
|
|
||||||
# quantized matmul.
|
|
||||||
torch.set_grad_enabled(False)
|
|
||||||
torch.set_default_dtype(self.model_config.dtype)
|
|
||||||
|
|
||||||
# Initialize the distributed environment.
|
|
||||||
self._init_tpu_worker_distributed_environment(
|
|
||||||
self.vllm_config, self.rank, self.distributed_init_method, self.local_rank
|
|
||||||
)
|
|
||||||
|
|
||||||
# Device initialization should happen after initializing
|
|
||||||
# the distributed runtime.
|
|
||||||
self.device = xm.xla_device()
|
|
||||||
self.device_config.device = self.device
|
|
||||||
|
|
||||||
# Set random seed.
|
|
||||||
set_random_seed(self.model_config.seed)
|
|
||||||
xm.set_rng_state(self.model_config.seed, self.device)
|
|
||||||
|
|
||||||
# Increase the cache size limit, which is the maximum number of
|
|
||||||
# dynamo graphs that can be compiled.
|
|
||||||
# TODO (NickLucche) On gsm we compile 80+ graphs.
|
|
||||||
# Re-evaluate limit, with MM we may get close to this limit.
|
|
||||||
torch._dynamo.config.cache_size_limit = 128
|
|
||||||
# Use persistent cache to avoid XLA recompilation.
|
|
||||||
# NOTE(woosuk): Set per-rank cache path since different ranks
|
|
||||||
# can have slightly different XLA graphs.
|
|
||||||
world_size = self.parallel_config.world_size
|
|
||||||
rank = xr.global_ordinal()
|
|
||||||
# The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
|
|
||||||
# Consequently, changes in optimization flags, which affect compilation
|
|
||||||
# results, don't change the cache key. This can result in the wrong
|
|
||||||
# compilation being used. To prevent this, disabling the XLA compilation
|
|
||||||
# cache during development is recommended.We can disable it by
|
|
||||||
# `export VLLM_XLA_CACHE_PATH=`
|
|
||||||
if envs.VLLM_XLA_CACHE_PATH:
|
|
||||||
per_rank_path = os.path.join(
|
|
||||||
envs.VLLM_XLA_CACHE_PATH, f"tp{world_size}_rank{rank}"
|
|
||||||
)
|
|
||||||
xr.initialize_cache(per_rank_path, readonly=False)
|
|
||||||
|
|
||||||
# Init ModelRunner here, so that we have access to self.device.
|
|
||||||
self.model_runner = TPUModelRunner(
|
|
||||||
self.vllm_config, self.device, self.original_parallel_config
|
|
||||||
)
|
|
||||||
|
|
||||||
if rank == 0:
|
|
||||||
# If usage stat is enabled, collect relevant info.
|
|
||||||
report_usage_stats(self.vllm_config)
|
|
||||||
|
|
||||||
def determine_available_memory(self) -> int:
|
|
||||||
kv_caches: dict[str, torch.Tensor] = {}
|
|
||||||
kv_cache_spec = self.model_runner.get_kv_cache_spec()
|
|
||||||
for layer_name, layer_spec in kv_cache_spec.items():
|
|
||||||
if isinstance(layer_spec, AttentionSpec):
|
|
||||||
dtype = layer_spec.dtype
|
|
||||||
|
|
||||||
# Use an empty tensor instead of `None` to force Dynamo to pass
|
|
||||||
# it by reference, rather by specializing on the value `None`.
|
|
||||||
tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
|
|
||||||
kv_caches[layer_name] = tpu_kv_cache
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Unsupported KV cache spec '{type(layer_spec)}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
runner_kv_caches: list[torch.Tensor] = []
|
|
||||||
bind_kv_cache(
|
|
||||||
kv_caches,
|
|
||||||
self.vllm_config.compilation_config.static_forward_context,
|
|
||||||
runner_kv_caches,
|
|
||||||
)
|
|
||||||
|
|
||||||
# `max_num_tokens >= max_num_batched_tokens` due to padding.
|
|
||||||
with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
|
|
||||||
self.model_runner.profile_run(self.model_runner.max_num_tokens)
|
|
||||||
|
|
||||||
# Synchronize before measuring the memory usage.
|
|
||||||
xm.wait_device_ops()
|
|
||||||
|
|
||||||
# During the profiling run, the model runs without KV cache. After
|
|
||||||
# the profiling run, the model always runs with KV cache. Here we clear
|
|
||||||
# the dynamo cache and cached bytecode to ensure the model always has
|
|
||||||
# one compiled bytecode. Having one FX graph/cached bytecode per
|
|
||||||
# compiled model is required for `support_torch_compile` decorator to
|
|
||||||
# skip dynamo guard.
|
|
||||||
with set_current_vllm_config(self.vllm_config):
|
|
||||||
self.model_runner.reset_dynamo_cache()
|
|
||||||
|
|
||||||
# Get the maximum amount of memory used by the model weights and
|
|
||||||
# intermediate activations.
|
|
||||||
if self.use_spmd:
|
|
||||||
# This is a workaround for the TPU SPMD mode. The get_memory_info
|
|
||||||
# API doesn't work with SPMD mode in PyTorch/XLA.
|
|
||||||
# TODO: use xm.get_memory_info for SPMD once it's supported in
|
|
||||||
# PyTorch/XLA.
|
|
||||||
import tpu_info
|
|
||||||
|
|
||||||
chip_type, _ = tpu_info.device.get_local_chips()
|
|
||||||
device_usage = tpu_info.metrics.get_chip_usage(chip_type)
|
|
||||||
total_memory_size = device_usage[0].total_memory
|
|
||||||
current_mem = device_usage[0].memory_usage
|
|
||||||
else:
|
|
||||||
m = xm.get_memory_info(self.device)
|
|
||||||
total_memory_size = m["bytes_limit"]
|
|
||||||
current_mem = m["bytes_used"]
|
|
||||||
# Ideally we would use profiled = m["peak_bytes_used"] to
|
|
||||||
# get weights + activations. But there is memory used during
|
|
||||||
# compilation / weight loading that impacts the peak and
|
|
||||||
# there is no way to reset peak memory in XLA, So we
|
|
||||||
# use the heuristic of 2% of weights.
|
|
||||||
profiled = current_mem * 1.02
|
|
||||||
|
|
||||||
# Calculate the TPU KV cache size based on profiling.
|
|
||||||
usable_memory_size = int(
|
|
||||||
total_memory_size * self.cache_config.gpu_memory_utilization
|
|
||||||
)
|
|
||||||
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
|
|
||||||
head_size = self.model_config.get_head_size()
|
|
||||||
if head_size > 0:
|
|
||||||
padded_head_size = (
|
|
||||||
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
|
||||||
)
|
|
||||||
if padded_head_size != head_size:
|
|
||||||
logger.warning_once("head size is padded to %d", padded_head_size)
|
|
||||||
# We adjust the usable memory size for the KV cache to prevent OOM
|
|
||||||
# errors, even after padding the head_size.
|
|
||||||
tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
|
|
||||||
return int(tpu_kv_cache_bytes)
|
|
||||||
|
|
||||||
def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput:
|
|
||||||
return self.model_runner.sample_tokens(grammar_output)
|
|
||||||
|
|
||||||
def execute_model(
|
|
||||||
self, scheduler_output: "SchedulerOutput"
|
|
||||||
) -> ModelRunnerOutput | None:
|
|
||||||
return self.model_runner.execute_model(scheduler_output)
|
|
||||||
|
|
||||||
def profile(self, is_start: bool = True):
|
|
||||||
if self.rank < 1:
|
|
||||||
if self.profile_dir is None:
|
|
||||||
raise RuntimeError("Profiler is not enabled.")
|
|
||||||
if is_start:
|
|
||||||
if self.profiler is None:
|
|
||||||
self.profiler = xp.start_server(9012)
|
|
||||||
xp.start_trace(self.profile_dir)
|
|
||||||
else:
|
|
||||||
xp.stop_trace()
|
|
||||||
|
|
||||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
|
||||||
return self.model_runner.add_lora(lora_request)
|
|
||||||
|
|
||||||
def load_model(self) -> None:
|
|
||||||
self.model_runner.load_model()
|
|
||||||
|
|
||||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
|
||||||
self.model_runner.update_config(overrides)
|
|
||||||
|
|
||||||
def reload_weights(self) -> None:
|
|
||||||
self.model_runner.reload_weights()
|
|
||||||
|
|
||||||
def compile_or_warm_up_model(self) -> None:
|
|
||||||
if not self.model_config.enforce_eager:
|
|
||||||
self.model_runner.capture_model()
|
|
||||||
|
|
||||||
# Reset the seed to ensure that the random state is not affected by
|
|
||||||
# the model initialization and profiling.
|
|
||||||
set_random_seed(self.model_config.seed)
|
|
||||||
|
|
||||||
def reset_mm_cache(self) -> None:
|
|
||||||
self.model_runner.reset_mm_cache()
|
|
||||||
|
|
||||||
def get_model(self) -> nn.Module:
|
|
||||||
return self.model_runner.get_model()
|
|
||||||
|
|
||||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
|
||||||
return self.model_runner.get_supported_tasks()
|
|
||||||
|
|
||||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
|
||||||
return self.model_runner.get_kv_cache_spec()
|
|
||||||
|
|
||||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
|
||||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
|
||||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
|
||||||
|
|
||||||
def check_health(self) -> None:
|
|
||||||
# worker will always be healthy as long as it's running.
|
|
||||||
return
|
|
||||||
|
|
||||||
def _init_tpu_worker_distributed_environment(
|
|
||||||
self,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
rank: int,
|
|
||||||
distributed_init_method: str | None = None,
|
|
||||||
local_rank: int = -1,
|
|
||||||
) -> None:
|
|
||||||
"""Initialize the distributed environment."""
|
|
||||||
if self.use_spmd:
|
|
||||||
xr.use_spmd()
|
|
||||||
# NOTE(woosuk): This is just to initialize the TP group and broadcast
|
|
||||||
# the input objects on CPU. The all-reduce and all-gather ops on TPU
|
|
||||||
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
|
|
||||||
# own context.
|
|
||||||
parallel_config = vllm_config.parallel_config
|
|
||||||
init_distributed_environment(
|
|
||||||
world_size=parallel_config.world_size,
|
|
||||||
rank=rank,
|
|
||||||
local_rank=local_rank,
|
|
||||||
distributed_init_method=distributed_init_method or "env://",
|
|
||||||
backend=current_platform.dist_backend,
|
|
||||||
)
|
|
||||||
ensure_model_parallel_initialized(
|
|
||||||
parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
|
|
||||||
)
|
|
||||||
|
|
||||||
ensure_kv_transfer_initialized(vllm_config)
|
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
|
||||||
self.model_runner.ensure_kv_transfer_shutdown()
|
|
||||||
|
|
||||||
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
|
|
||||||
"""Apply a function on the model inside this worker."""
|
|
||||||
return fn(self.get_model())
|
|
||||||
|
|
||||||
|
|
||||||
if USE_TPU_INFERENCE:
|
if USE_TPU_INFERENCE:
|
||||||
from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker
|
from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user