[Hardware][TPU][V1] Multi-LoRA Optimisations for the V1 TPU backend (#15655)

Signed-off-by: Akshat Tripathi <akshat@krai.ai>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: xihajun <junfan@krai.ai>
Signed-off-by: Jorge de Freitas <jorge.de-freitas22@imperial.ac.uk>
Signed-off-by: Jorge de Freitas <jorge@krai.ai>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: xihajun <junfan@krai.ai>
Co-authored-by: Jorge de Freitas <jorge.de-freitas22@imperial.ac.uk>
Co-authored-by: Jorge de Freitas <jorge@krai.ai>
Akshat Tripathi 2025-05-28 20:59:09 +01:00, committed by GitHub
parent a09c7ca9f2
commit 643622ba46
9 changed files with 325 additions and 334 deletions


@ -122,10 +122,8 @@ run_and_track_test 11 "test_struct_output_generate.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py"
run_and_track_test 12 "test_moe_pallas.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
# Disable the TPU LoRA tests until the feature is activated
# run_and_track_test 13 "test_lora (directory)" \
# "python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/"
run_and_track_test 13 "test_lora.py" \
"VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py"
# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then


@ -1,73 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
# Required to register the custom ops
import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import
N_TOKENS = [16, 1024, 4096]
HIDDEN_SIZES = [1024, 2048, 4096]
DTYPES = [torch.bfloat16]
NUM_LORA = [1, 4, 16]
RANKS = [32, 256, 512]
def generate_test_data(T, D, L, N, seed, dtype=torch.float32):
"""
Inputs: (All integers)
T: Total number of tokens
D: Input dim
L: LoRA Dim
N: N LoRAs
Outputs:
inputs: torch.Tensor - shape (T, D)
loras: torch.Tensor - shape (N, 1, L, D)
idxs: torch.Tensor - shape (T, ) - all values must be in [0, N)
ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T
"""
torch.manual_seed(seed)
inputs = torch.randn((T, D), device="xla", dtype=dtype)
loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype)
idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla")
ref_output = ref_bgmv(inputs, loras, idxs)
return inputs, loras, idxs, ref_output
def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor):
selected_loras = loras[idxs]
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(axis=1)
batch_size, output_size, input_size = selected_loras.shape
return (selected_loras @ inputs.reshape(
(batch_size, input_size, 1))).reshape((batch_size, output_size))
# Parameterize tests with various shapes and dtypes
@pytest.mark.parametrize("T", N_TOKENS)
@pytest.mark.parametrize("D", HIDDEN_SIZES)
@pytest.mark.parametrize("L", RANKS)
@pytest.mark.parametrize("N", NUM_LORA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", [0])
def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed):
if op_type == "expand":
D, L = L, D
inputs, loras, idxs, ref_output = generate_test_data(
T, D, L, N, seed, dtype)
# Run bgmv
output = torch.ops.xla.bgmv(inputs, loras, idxs)
# Make sure we have no NaNs
assert not torch.any(torch.isnan(output))
# Compare with reference output
assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2)


@ -200,7 +200,7 @@ class LoRAModel(AdapterModel):
weights_mapper: Optional[WeightsMapper] = None,
tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint.
Args:
lora_dir: The local path that has lora data.
expected_lora_modules: Name of modules that are expected to be
@ -620,7 +620,7 @@ class LoRAModelManager(AdapterModelManager):
def _filter_unsupported_mm_module(self, module_name: str) -> bool:
"""
Regarding multimodal models, vLLM currently only supports adding LoRA to
language model. LoRA for other modules, such as the vision tower, will
be filtered out.
"""
if self.supports_mm:


@ -1,63 +1,99 @@
# SPDX-License-Identifier: Apache-2.0
import jax
import jax.numpy as jnp
import torch
# Required to register the custom ops
import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import
import torch.nn.functional as F
import torch_xla.core.xla_builder as xb
from torch.library import impl
from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard
def bgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True):
@jax.jit
def bgmv_jax(inputs, loras, idxs):
return jnp.einsum(
"td,tX,Xld->tl",
inputs,
jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
loras,
)
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")
@impl(XLA_LIB, "bgmv", "XLA")
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
jax_import_guard()
return xb.call_jax(bgmv_jax, (inputs, loras, idxs))
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
idxs: torch.IntTensor):
T, _ = inputs.shape
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
_, L, _ = loras.shape
return torch.empty((T, L), device=inputs.device)
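For context, here is a small standalone sketch (not part of the diff; names and shapes are illustrative) of why the one-hot einsum used by bgmv_jax matches the naive per-token gather-and-matmul:

# Minimal sketch: the one-hot contraction equals gathering each token's LoRA
# matrix and applying it as a per-token matrix-vector product.
import jax
import jax.numpy as jnp

def bgmv_gather(inputs, loras, idxs):
    # inputs: (T, D), loras: (N, L, D), idxs: (T,) -> (T, L)
    return jnp.einsum("tld,td->tl", loras[idxs], inputs)

def bgmv_one_hot(inputs, loras, idxs):
    one_hot = jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype)  # (T, N)
    return jnp.einsum("td,tX,Xld->tl", inputs, one_hot, loras)

key = jax.random.PRNGKey(0)
inputs = jax.random.normal(key, (8, 16))       # T=8 tokens, D=16
loras = jax.random.normal(key, (3, 4, 16))     # N=3 adapters, rank L=4
idxs = jnp.array([0, 1, 2, 0, 1, 2, 0, 1], dtype=jnp.int32)
assert jnp.allclose(bgmv_gather(inputs, loras, idxs),
                    bgmv_one_hot(inputs, loras, idxs), atol=1e-4)

Expressing the gather as a dense contraction keeps the op in plain matmuls that XLA handles well, which is presumably why the hand-written Pallas kernel deleted later in this diff could be dropped.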
def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): output tensor of shape
[num_tokens, hidden_size * num_slices].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
add_inputs (bool): Whether or not to add the input tensor to the output
tensor.
"""
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
n_tokens = outputs.size(0)
limit = output_tensor.shape[0]
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
limit = 1
outputs = torch.cat(
(outputs,
torch.zeros((n_tokens, output_tensor.shape[1] - outputs.shape[1]),
device=outputs.device)),
dim=1)
if output_tensor.shape[1] > outputs.shape[1]:
outputs = F.pad(outputs,
(0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
if add_inputs:
return output_tensor + outputs[:limit, :]
return output_tensor + outputs[:limit, :output_tensor.shape[1]]
else:
return outputs[:limit, :]
return outputs[:limit, :output_tensor.shape[1]]
def bgmv_shrink(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0):
def bgmv_shrink(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): (Unused) output tensor (placeholder).
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
scaling (float, optional): Scalar multiplier applied to the output.
"""
@ -66,39 +102,41 @@ def bgmv_shrink(inputs: torch.Tensor,
lora_indices_tensor)
def bgmv_expand_slice(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True):
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): output tensor of shape
[num_tokens, hidden_size * num_slices].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
add_inputs (bool): Whether or not to add the input tensor to the output
tensor.
"""
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
n_tokens = outputs.size(0)
outputs = torch.cat((
torch.zeros((n_tokens, slice_offset), device=outputs.device),
outputs = F.pad(
outputs,
torch.zeros(
(n_tokens, output_tensor.shape[1] - (slice_offset + slice_size)),
device=outputs.device),
),
dim=1)
(
slice_offset,
output_tensor.shape[1] - (slice_offset + slice_size),
0,
0,
),
)
if add_inputs:
return output_tensor + outputs
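As a side note, a tiny standalone sketch (illustrative shapes only) of how the F.pad call above places the bgmv result into the correct column range: torch pads the last dimension first, so the 4-tuple means slice_offset zero columns on the left, total - (slice_offset + slice_size) on the right, and no padding over rows:

# Minimal sketch of the slice placement done with F.pad.
import torch
import torch.nn.functional as F

num_tokens, total_width = 4, 10            # stand-in for output_tensor.shape
slice_offset, slice_size = 3, 5
outputs = torch.ones(num_tokens, slice_size)   # stand-in for the bgmv result
padded = F.pad(outputs,
               (slice_offset, total_width - (slice_offset + slice_size), 0, 0))
assert padded.shape == (num_tokens, total_width)
assert torch.all(padded[:, :slice_offset] == 0)                       # left zeros
assert torch.all(padded[:, slice_offset:slice_offset + slice_size] == 1)
assert torch.all(padded[:, slice_offset + slice_size:] == 0)          # right zeros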


@ -1,133 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import functools
import jax
import jax.numpy as jnp
import torch
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from torch.library import impl
from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard,
make_kernel_from_pallas)
# TODO: Tune these
TOKENS_BLOCK = 16
LORA_RANK_BLOCK = 128
DIM_BLOCK_SIZE = 128
def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref,
acc_ref, mask_ref):
@pl.when(pl.program_id(2) == 0)
def _():
acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32)
t = pl.program_id(0)
for i in range(bT):
idx = idx_ref[i + bT * t]
mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32)
mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32)
acc_ref[...] += jax.lax.dot_general(
inp_ref[...],
lora_ref[idx, ...], (((1, ), (1, )), ((), ())),
preferred_element_type=jnp.float32) * mask_ref[...]
@pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
def _():
out_ref[...] = acc_ref[...].astype(out_ref.dtype)
@jax.jit
def _bgmv(
idxs: jax.Array, # (T, ) int32
inputs: jax.Array, # (T, D) model dtype
loras: jax.Array # (N, L, D) model dtype
) -> jax.Array: # (T, L) model dtype
T, D = inputs.shape
N, L, _ = loras.shape
return pl.pallas_call(
kernel=functools.partial(_bgmv_kernel, TOKENS_BLOCK, LORA_RANK_BLOCK),
out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
grid=(T // TOKENS_BLOCK, L // LORA_RANK_BLOCK,
D // DIM_BLOCK_SIZE),
in_specs=[
pl.BlockSpec((TOKENS_BLOCK, DIM_BLOCK_SIZE),
lambda i, j, k, block_idx: (i, k)),
pl.BlockSpec((N, LORA_RANK_BLOCK, DIM_BLOCK_SIZE),
lambda i, j, k, block_idx: (0, j, k)),
],
out_specs=pl.BlockSpec((TOKENS_BLOCK, LORA_RANK_BLOCK),
lambda i, j, k, block_idx: (i, j)),
scratch_shapes=[
pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32),
pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32)
]),
compiler_params=pltpu.TPUCompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
name="bgmv")(idxs, inputs, loras)
def bgmv_shape_function(idxs, inputs, loras):
T, _ = inputs.shape
_, L, _ = loras.shape
return [((T, L), inputs.dtype)]
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", )
@impl(XLA_LIB, "bgmv", "XLA")
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
inputs = inputs.to(dtype=loras.dtype)
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
jax_import_guard()
kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function)
T, _ = inputs.shape
_, L, D = loras.shape
# Pad the loras' rank if it's too low. This is to allow it to fit in a TPU
# register. This has to happen in pytorch, doing it in Jax will lead to NaNs
L1 = L
if LORA_RANK_BLOCK > L or L % LORA_RANK_BLOCK != 0:
L1 = (L // LORA_RANK_BLOCK + 1) * LORA_RANK_BLOCK
D1 = D
if DIM_BLOCK_SIZE > D or D % DIM_BLOCK_SIZE != 0:
D1 = (D // DIM_BLOCK_SIZE + 1) * DIM_BLOCK_SIZE
T1 = T
if TOKENS_BLOCK > T or T % TOKENS_BLOCK != 0:
T1 = (T // TOKENS_BLOCK + 1) * TOKENS_BLOCK
if D1 != D or L1 != L:
loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0))
if D1 != D or T1 != T:
inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T))
if T1 != T:
idxs = torch.nn.functional.pad(idxs, ((0, T1 - T)))
return kernel(idxs, inputs, loras)[:T, :L]
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
idxs: torch.IntTensor):
T, _ = inputs.shape
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
_, L, _ = loras.shape
return torch.empty((T, L), device=inputs.device)


@ -1,11 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Union
import math
from typing import TYPE_CHECKING, Optional, Union
import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from vllm.lora.punica_wrapper.utils import convert_mapping
if TYPE_CHECKING:
# avoid circular import
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import LongContextLoRAContext
from .punica_base import PunicaWrapperBase
@ -31,6 +39,15 @@ class PunicaWrapperTPU(PunicaWrapperBase):
self._sampler_indices_padded = self._sampler_indices_padded.to(
dtype=torch.int32)
torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded,
True)
torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch,
True)
torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
@ -55,15 +72,11 @@ class PunicaWrapperTPU(PunicaWrapperBase):
def shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
if self.no_lora:
return y
return bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x),
scale)
return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale)
def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor,
add_inputs: bool):
@ -72,7 +85,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
def expand_slice(self, y: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, y_offset: int, y_slice_size: int,
y_total_size: int, add_inputs: bool) -> torch.Tensor:
add_inputs: bool) -> torch.Tensor:
return bgmv_expand_slice(x, w_t_all, y,
self._get_token_lora_indices(x), y_offset,
y_slice_size, add_inputs)
@ -98,9 +111,8 @@ class PunicaWrapperTPU(PunicaWrapperBase):
x = x.view(-1, x.shape[-1])
for slice_idx in range(len(lora_a_stacked)):
y_s = y[slice_idx]
lora_s = lora_a_stacked[slice_idx]
y_s = self.shrink(y_s, x, lora_s, scale)
y_s = self.shrink(x, lora_s, scale)
y[slice_idx, :, :] = y_s # type: ignore[index]
return y
@ -140,15 +152,12 @@ class PunicaWrapperTPU(PunicaWrapperBase):
y = self._apply_bias(self._get_token_lora_indices(y), y,
output_slices, lora_bias_stacked)
for slice_idx in range(len(lora_b_stacked)):
y = self.expand_slice(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
y_total_size=sum(output_slices),
add_inputs=add_inputs,
)
y = self.expand_slice(y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
add_inputs=add_inputs)
offset_left += output_slices[slice_idx]
return y.view_as(y_org)
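A quick illustration (standalone, hypothetical slice widths) of how offset_left walks across output_slices so each slice's expanded output lands in an adjacent column range of y:

# Sketch: successive slices occupy adjacent column ranges of the merged output.
output_slices = (4, 4, 8)      # illustrative widths of merged projections
offset_left = 0
column_ranges = []
for size in output_slices:
    column_ranges.append((offset_left, offset_left + size))
    offset_left += size
assert column_ranges == [(0, 4), (4, 8), (8, 16)]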
@ -216,12 +225,10 @@ class PunicaWrapperTPU(PunicaWrapperBase):
if buffer is None:
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default, consistent with the
# triton op
T = x.size(0)
buffer = torch.zeros(
(len(output_slices), T, r),
dtype=torch.float32,
dtype=x.dtype,
device=x.device,
)
buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
@ -257,26 +264,16 @@ class PunicaWrapperTPU(PunicaWrapperBase):
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]): Defaults to None.
"""
if self.no_lora:
return y
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if buffer is None:
# We set the buffer to be float32 by default, consistent with the
# triton op
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices,
scale)
sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale)
y = bgmv_expand(buffer,
lora_b_stacked,
y,
self.sampler_indices,
sampler_indices,
add_inputs=True)
return y.view_as(y_org)
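Conceptually, the shrink followed by expand above is the usual low-rank LoRA update; a minimal single-adapter sketch (illustrative shapes, no per-token indexing):

# Sketch of shrink/expand as a plain low-rank update: y += (x @ A^T) @ B^T * scale.
import torch

num_tokens, hidden, rank, out_dim = 4, 16, 8, 32
x = torch.randn(num_tokens, hidden)
lora_a = torch.randn(rank, hidden)      # "shrink" weights
lora_b = torch.randn(out_dim, rank)     # "expand" weights
scale = 0.5

buffer = (x @ lora_a.T) * scale         # shrink: (num_tokens, rank)
delta = buffer @ lora_b.T               # expand: (num_tokens, out_dim)
y = torch.zeros(num_tokens, out_dim) + delta   # added onto the base logits (zeros as a stand-in)

The real ops additionally select a different (lora_a, lora_b) pair per token via the sampler indices.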
@ -316,10 +313,92 @@ class PunicaWrapperTPU(PunicaWrapperBase):
return output.view_as(org_output)
# This performs the same tensor ops as the base method, except it does them
# on the CPU then transfers the results to the TPU
def _update_base_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
):
# Make sure we don't accidentally collect outside operations
xm.mark_step()
# Pad the prompt mapping to avoid running into recompiles on the TPU
# TODO: Should this happen inside mapping internally? If so how can we
# avoid having backend specific LoRAMapping classes?
mapping.prompt_mapping = self._pad_prompt_mapping(
mapping.prompt_mapping)
(
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
long_lora_offsets_tensor,
indices_len,
) = convert_mapping(
mapping,
lora_index_to_id,
max_loras,
vocab_size,
extra_vocab_size,
"cpu",
long_lora_context,
)
self._token_lora_indices = self._pad_to_shape(
base_indices, self._token_lora_indices.shape,
dims=1).to(self.device)
self._sampler_indices = self._pad_to_shape(sampler_indices,
self._sampler_indices.shape,
dims=1).to(self.device)
self._sampler_indices_padded = self._pad_to_shape(
sampler_indices_padded, self._sampler_indices_padded.shape,
dims=1).to(self.device)
self._embeddings_indices = self._pad_to_shape(
embeddings_indices, self._embeddings_indices.shape,
dims=2).to(self.device)
if long_lora_offsets_tensor is not None:
self._long_lora_indices = self._pad_to_shape(
long_lora_offsets_tensor,
self._long_lora_indices.shape,
dims=1).to(self.device)
else:
zeroed = torch.zeros_like(self._long_lora_indices.cpu(),
dtype=torch.int32)
self._long_lora_indices = zeroed.to(self.device)
self.indices_len[:] = indices_len
def _update_prefill_metadata(self,
token_lora_tensor: torch.Tensor) -> None:
self.batch_size = 1
self._lora_indices_per_batch[:self.batch_size].copy_(
token_lora_tensor[:self.batch_size])
# TODO: .item() is extremely inefficient on TPU, so find a way around it
self.no_lora = torch.all(token_lora_tensor == -1).item()
self._lora_indices_per_batch[:self.
batch_size] = token_lora_tensor[:self.
batch_size]
def _pad_prompt_mapping(
self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
num_reqs = len(prompt_mapping)
# From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
# import
MIN_NUM_SEQS = 8
padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
pad_len = padded_num_reqs - num_reqs
padding = [-1] * pad_len
return tuple(list(prompt_mapping) + padding)
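A worked example (standalone, illustrative) of the padding rule in _pad_prompt_mapping: the request count is rounded up to the next power of two with a floor of MIN_NUM_SEQS, and the new slots are filled with -1, the same sentinel used for "no LoRA" elsewhere in this diff:

# Sketch of the prompt-mapping padding rule.
import math

MIN_NUM_SEQS = 8

def pad_prompt_mapping(prompt_mapping):
    num_reqs = len(prompt_mapping)
    padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
    return tuple(list(prompt_mapping) + [-1] * (padded_num_reqs - num_reqs))

assert pad_prompt_mapping((3,)) == (3, -1, -1, -1, -1, -1, -1, -1)   # 1 -> 8
assert len(pad_prompt_mapping(tuple(range(9)))) == 16                # 9 -> 16

Keeping the set of possible request counts small limits how many distinct shapes XLA has to compile.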
def _pad_to_shape(self, src, target_shape, dims=1):
if dims == 1:
pad_len = target_shape[0] - src.shape[0]
return F.pad(src, (0, pad_len), value=0).to(torch.int32)
else:
pad_rows = target_shape[0] - src.shape[0]
pad_cols = target_shape[1] - src.shape[1]
return F.pad(src, (0, pad_cols, 0, pad_rows),
value=0).to(torch.int32)


@ -80,8 +80,38 @@ class LoRAModelRunnerMixin:
lora_requests)
@contextmanager
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
num_scheduled_tokens: np.ndarray):
def maybe_setup_dummy_loras(self, lora_config):
if lora_config is None:
yield
else:
# __enter__ code
assert self.lora_manager is not None, "LoRA is not enabled"
num_loras = lora_config.max_loras
# Make dummy lora requests
lora_requests: set[LoRARequest] = {
LoRARequest(lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_path="/not/a/real/path")
for lora_id in range(1, num_loras + 1)
}
with self.lora_manager.dummy_lora_cache():
# Add the dummy LoRAs here so _set_active_loras doesn't try to
# load from disk.
for lr in lora_requests:
self.lora_manager.add_dummy_lora(
lr, rank=self.LORA_WARMUP_RANK)
yield
# __exit__ code
self.lora_manager.remove_all_adapters()
@contextmanager
def maybe_select_dummy_loras(self, lora_config: LoRAConfig,
num_scheduled_tokens: np.ndarray):
if lora_config is None:
yield
else:
@ -108,21 +138,18 @@ class LoRAModelRunnerMixin:
for lora_id in range(1, num_loras + 1)
}
with self.lora_manager.dummy_lora_cache():
# Add the dummy LoRAs here so _set_active_loras doesn't try to
# load from disk.
for lr in lora_requests:
self.lora_manager.add_dummy_lora(
lr, rank=self.LORA_WARMUP_RANK)
self._set_active_loras(tuple(prompt_lora_mapping),
tuple(token_lora_mapping), lora_requests)
self._set_active_loras(tuple(prompt_lora_mapping),
tuple(token_lora_mapping),
lora_requests)
yield
# __exit__ code
self.lora_manager.remove_all_adapters()
@contextmanager
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
num_scheduled_tokens: np.ndarray):
with self.maybe_setup_dummy_loras(
lora_config), self.maybe_select_dummy_loras(
lora_config, num_scheduled_tokens):
yield
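A rough sketch (simplified; everything on the runner other than the two context managers is hypothetical) of the usage pattern this split enables: register dummy adapters once, then select a dummy per-token mapping inside each shape-specific compilation pass:

# Sketch: set up dummy LoRAs once, select mappings per compiled shape.
import numpy as np

def precompile_everything(runner, lora_config, token_sizes):
    with runner.maybe_setup_dummy_loras(lora_config):
        for num_tokens in token_sizes:
            with runner.maybe_select_dummy_loras(
                    lora_config, np.array([num_tokens], dtype=np.int32)):
                runner.compile_one_shape(num_tokens)   # hypothetical pass

This mirrors how the TPU model runner later in this diff wraps its _precompile_* passes with maybe_setup_dummy_loras and the individual dummy runs with maybe_select_dummy_loras.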
def add_lora(self, lora_request: LoRARequest) -> bool:
if not self.lora_manager:


@ -20,6 +20,7 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
@ -152,6 +153,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.hidden_size = model_config.get_hidden_size()
self.vocab_size = model_config.get_vocab_size()
if self.lora_config is not None:
self.vocab_size += self.lora_config.lora_extra_vocab_size
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
@ -591,6 +595,17 @@ class TPUModelRunner(LoRAModelRunnerMixin):
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
logits_indices = logits_indices.to(self.device)
if self.lora_config is not None:
# We need to respect padding when activating LoRA adapters
padded_num_scheduled_tokens_per_req = np.copy(
num_scheduled_tokens_per_req
) # Copying to avoid accidental state corruption bugs
padded_num_scheduled_tokens_per_req[-1] += \
padded_total_num_scheduled_tokens - total_num_scheduled_tokens
self.set_active_loras(self.input_batch,
padded_num_scheduled_tokens_per_req)
layer_names = get_layers_from_vllm_config(self.vllm_config,
Attention).keys()
per_layer_attn_metadata = {
@ -916,6 +931,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
model = self.load_lora_model(model, self.model_config,
self.scheduler_config,
self.lora_config, self.device)
replace_set_lora(model)
# Sync all pending XLA execution during model initialization and weight
# loading.
@ -980,7 +996,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
for layer_name in layer_names
}
with self.maybe_dummy_run_with_lora(
with self.maybe_select_dummy_loras(
self.lora_config,
np.array([num_tokens], dtype=np.int32)), set_forward_context(
per_layer_attn_metadata, self.vllm_config, 0):
@ -989,6 +1005,13 @@ class TPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds=inputs_embeds)
self._hidden_states_dtype = out.dtype
def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
lora_requests) -> None:
xm.mark_step() # Captures input updates
super()._set_active_loras(prompt_lora_mapping, token_lora_mapping,
lora_requests)
xm.mark_step() # Captures metadata updates
def _precompile_mm_encoder(self) -> None:
# Pre-compile MM encoder for all supported data modalities.
hf_config = self.vllm_config.model_config.hf_config
@ -1151,7 +1174,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
generate_params_if_all_greedy,
))
sampling_metadata.all_greedy = all_greedy
self.sample_from_logits(dummy_logits, sampling_metadata)
with self.maybe_select_dummy_loras(
self.lora_config, np.array([num_reqs],
dtype=np.int32)):
self.sample_from_logits(dummy_logits, sampling_metadata)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
@ -1167,7 +1193,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
dtype=self._hidden_states_dtype)
dummy_tokens = torch.zeros((num_reqs, 1),
dtype=torch.int64).to(self.device)
self.gather_logprobs(dummy_logits, dummy_tokens)
with self.maybe_select_dummy_loras(
self.lora_config, np.array([num_reqs], dtype=np.int32)):
self.gather_logprobs(dummy_logits, dummy_tokens)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
@ -1178,13 +1206,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
"""
Precompile all the subgraphs with possible input shapes.
"""
self._precompile_mm_encoder()
self._precompile_backbone()
self._precompile_select_hidden_states()
self._precompile_compute_logits()
self._precompile_structured_decoding()
self._precompile_sample_from_logits()
self._precompile_gather_logprobs()
with self.maybe_setup_dummy_loras(self.lora_config):
self._precompile_mm_encoder()
self._precompile_backbone()
self._precompile_select_hidden_states()
self._precompile_compute_logits()
self._precompile_structured_decoding()
self._precompile_sample_from_logits()
self._precompile_gather_logprobs()
def profile_run(
self,
@ -1467,11 +1496,11 @@ def _get_token_paddings(min_token_size: int, max_token_size: int,
padding_gap: int) -> list[int]:
"""Generate a list of padding size, starting from min_token_size,
ending with a number that can cover max_token_size
If padding_gap == 0 then:
increase 2X each time (exponential)
else:
first increase the size to twice,
then increase the padding size by padding_gap.
"""
# assert min_token_size is power of 2
@ -1508,3 +1537,32 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
index = bisect.bisect_left(paddings, x)
assert index < len(paddings)
return paddings[index]
def replace_set_lora(model):
def _tpu_set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
# TODO: The integer index leads to a recompilation, but converting it
# to a tensor doesn't seem to work anymore. This might be fixed with a
# later release of torch_xla.
self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias)
xm.mark_step()
def _tpu_reset_lora(self, index: int):
self._original_reset_lora(index)
xm.mark_step()
for _, module in model.named_modules():
if isinstance(module, BaseLayerWithLoRA):
module._original_set_lora = module.set_lora
module._original_reset_lora = module.reset_lora
module.set_lora = _tpu_set_lora.__get__(module, module.__class__)
module.reset_lora = _tpu_reset_lora.__get__(
module, module.__class__)
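For readers unfamiliar with the __get__ trick in replace_set_lora: calling a plain function's descriptor protocol binds it to a specific instance, which is what lets the patched set_lora receive the module as self without touching the class. A tiny generic illustration (names are made up):

# Minimal sketch of per-instance monkey-patching via __get__.
class Layer:
    def greet(self):
        return "original"

def patched_greet(self):
    return f"patched on {type(self).__name__}"

layer = Layer()
layer.greet = patched_greet.__get__(layer, Layer)   # bound to this instance only
assert layer.greet() == "patched on Layer"
assert Layer().greet() == "original"                # other instances unaffected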


@ -83,10 +83,6 @@ class TPUWorker:
if self.model_config.seed is None:
self.model_config.seed = 0
if vllm_config.lora_config is not None:
raise NotImplementedError(
"The V1 TPU backend doesn't support LoRA serving")
def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
@ -166,7 +162,8 @@ class TPUWorker:
runner_kv_caches)
# `max_num_tokens >= max_num_batched_tokens` due to padding.
self.model_runner.profile_run(self.model_runner.max_num_tokens)
with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
self.model_runner.profile_run(self.model_runner.max_num_tokens)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()