[Hardware][TPU][V1] Multi-LoRA Optimisations for the V1 TPU backend (#15655)

Signed-off-by: Akshat Tripathi <akshat@krai.ai>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: xihajun <junfan@krai.ai>
Signed-off-by: Jorge de Freitas <jorge.de-freitas22@imperial.ac.uk>
Signed-off-by: Jorge de Freitas <jorge@krai.ai>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: xihajun <junfan@krai.ai>
Co-authored-by: Jorge de Freitas <jorge.de-freitas22@imperial.ac.uk>
Co-authored-by: Jorge de Freitas <jorge@krai.ai>
Akshat Tripathi 2025-05-28 20:59:09 +01:00 committed by GitHub
parent a09c7ca9f2
commit 643622ba46
9 changed files with 325 additions and 334 deletions


@@ -122,10 +122,8 @@ run_and_track_test 11 "test_struct_output_generate.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py"
run_and_track_test 12 "test_moe_pallas.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
# Disable the TPU LoRA tests until the feature is activated
# run_and_track_test 13 "test_lora (directory)" \
# "python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/"
run_and_track_test 13 "test_lora.py" \
"VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py"
# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then


@@ -1,73 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
# Required to register the custom ops
import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import
N_TOKENS = [16, 1024, 4096]
HIDDEN_SIZES = [1024, 2048, 4096]
DTYPES = [torch.bfloat16]
NUM_LORA = [1, 4, 16]
RANKS = [32, 256, 512]
def generate_test_data(T, D, L, N, seed, dtype=torch.float32):
"""
Inputs: (All integers)
T: Total number of tokens
D: Input dim
L: LoRA Dim
N: N LoRAs
Outputs:
inputs: torch.Tensor - shape (T, D)
loras: torch.Tensor - shape (N, 1, L, D)
idxs: torch.Tensor - shape (T, ) - all values must be in [0, N)
ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T
"""
torch.manual_seed(seed)
inputs = torch.randn((T, D), device="xla", dtype=dtype)
loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype)
idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla")
ref_output = ref_bgmv(inputs, loras, idxs)
return inputs, loras, idxs, ref_output
def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor):
selected_loras = loras[idxs]
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(axis=1)
batch_size, output_size, input_size = selected_loras.shape
return (selected_loras @ inputs.reshape(
(batch_size, input_size, 1))).reshape((batch_size, output_size))
# Parameterize tests with various shapes and dtypes
@pytest.mark.parametrize("T", N_TOKENS)
@pytest.mark.parametrize("D", HIDDEN_SIZES)
@pytest.mark.parametrize("L", RANKS)
@pytest.mark.parametrize("N", NUM_LORA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", [0])
def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed):
if op_type == "expand":
D, L = L, D
inputs, loras, idxs, ref_output = generate_test_data(
T, D, L, N, seed, dtype)
# Run bgmv
output = torch.ops.xla.bgmv(inputs, loras, idxs)
# Make sure we have no NaNs
assert not torch.any(torch.isnan(output))
# Compare with reference output
assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2)


@@ -1,16 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
import jax
import jax.numpy as jnp
import torch
# Required to register the custom ops
import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import
import torch.nn.functional as F
import torch_xla.core.xla_builder as xb
from torch.library import impl
from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard
def bgmv_expand(inputs: torch.Tensor,
@jax.jit
def bgmv_jax(inputs, loras, idxs):
return jnp.einsum(
"td,tX,Xld->tl",
inputs,
jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
loras,
)
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")
@impl(XLA_LIB, "bgmv", "XLA")
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
jax_import_guard()
return xb.call_jax(bgmv_jax, (inputs, loras, idxs))
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
idxs: torch.IntTensor):
T, _ = inputs.shape
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
_, L, _ = loras.shape
return torch.empty((T, L), device=inputs.device)
def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True):
add_inputs: bool = True,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
@@ -28,29 +66,27 @@ def bgmv_expand(inputs: torch.Tensor,
"""
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
n_tokens = outputs.size(0)
limit = output_tensor.shape[0]
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
limit = 1
outputs = torch.cat(
(outputs,
torch.zeros((n_tokens, output_tensor.shape[1] - outputs.shape[1]),
device=outputs.device)),
dim=1)
if output_tensor.shape[1] > outputs.shape[1]:
outputs = F.pad(outputs,
(0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
if add_inputs:
return output_tensor + outputs[:limit, :]
return output_tensor + outputs[:limit, :output_tensor.shape[1]]
else:
return outputs[:limit, :]
return outputs[:limit, :output_tensor.shape[1]]
def bgmv_shrink(inputs: torch.Tensor,
def bgmv_shrink(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0):
scaling: float = 1.0,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
@@ -66,13 +102,15 @@ def bgmv_shrink(inputs: torch.Tensor,
lora_indices_tensor)
def bgmv_expand_slice(inputs: torch.Tensor,
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True):
add_inputs: bool = True,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
@@ -89,16 +127,16 @@ def bgmv_expand_slice(inputs: torch.Tensor,
tensor.
"""
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
n_tokens = outputs.size(0)
outputs = torch.cat((
torch.zeros((n_tokens, slice_offset), device=outputs.device),
outputs = F.pad(
outputs,
torch.zeros(
(n_tokens, output_tensor.shape[1] - (slice_offset + slice_size)),
device=outputs.device),
(
slice_offset,
output_tensor.shape[1] - (slice_offset + slice_size),
0,
0,
),
dim=1)
)
if add_inputs:
return output_tensor + outputs
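
The new XLA-registered bgmv op above replaces the Pallas kernel with a plain one-hot einsum: the "tX,Xld" contraction selects each token's LoRA matrix, and the "td" contraction applies it. A minimal CPU-only sketch of that equivalence, mirroring the shapes used by the deleted ref_bgmv test helper (illustrative only, not part of the diff):

import torch

def gather_bgmv(inputs, loras, idxs):
    # Gather one (L, D) LoRA matrix per token, then batch-multiply:
    # (T, L, D) @ (T, D, 1) -> (T, L)
    selected = loras[idxs]
    return torch.bmm(selected, inputs.unsqueeze(-1)).squeeze(-1)

def one_hot_bgmv(inputs, loras, idxs):
    # Same contraction the JAX op performs ("td,tX,Xld->tl").
    one_hot = torch.nn.functional.one_hot(idxs, loras.shape[0]).to(inputs.dtype)
    return torch.einsum("td,tn,nld->tl", inputs, one_hot, loras)

T, D, L, N = 16, 64, 8, 4
inputs = torch.randn(T, D)
loras = torch.randn(N, L, D)
idxs = torch.randint(0, N, (T,))
assert torch.allclose(gather_bgmv(inputs, loras, idxs),
                      one_hot_bgmv(inputs, loras, idxs), atol=1e-4)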


@@ -1,133 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import functools
import jax
import jax.numpy as jnp
import torch
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from torch.library import impl
from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard,
make_kernel_from_pallas)
# TODO: Tune these
TOKENS_BLOCK = 16
LORA_RANK_BLOCK = 128
DIM_BLOCK_SIZE = 128
def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref,
acc_ref, mask_ref):
@pl.when(pl.program_id(2) == 0)
def _():
acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32)
t = pl.program_id(0)
for i in range(bT):
idx = idx_ref[i + bT * t]
mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32)
mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32)
acc_ref[...] += jax.lax.dot_general(
inp_ref[...],
lora_ref[idx, ...], (((1, ), (1, )), ((), ())),
preferred_element_type=jnp.float32) * mask_ref[...]
@pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
def _():
out_ref[...] = acc_ref[...].astype(out_ref.dtype)
@jax.jit
def _bgmv(
idxs: jax.Array, # (T, ) int32
inputs: jax.Array, # (T, D) model dtype
loras: jax.Array # (N, L, D) model dtype
) -> jax.Array: # (T, L) model dtype
T, D = inputs.shape
N, L, _ = loras.shape
return pl.pallas_call(
kernel=functools.partial(_bgmv_kernel, TOKENS_BLOCK, LORA_RANK_BLOCK),
out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
grid=(T // TOKENS_BLOCK, L // LORA_RANK_BLOCK,
D // DIM_BLOCK_SIZE),
in_specs=[
pl.BlockSpec((TOKENS_BLOCK, DIM_BLOCK_SIZE),
lambda i, j, k, block_idx: (i, k)),
pl.BlockSpec((N, LORA_RANK_BLOCK, DIM_BLOCK_SIZE),
lambda i, j, k, block_idx: (0, j, k)),
],
out_specs=pl.BlockSpec((TOKENS_BLOCK, LORA_RANK_BLOCK),
lambda i, j, k, block_idx: (i, j)),
scratch_shapes=[
pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32),
pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32)
]),
compiler_params=pltpu.TPUCompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
name="bgmv")(idxs, inputs, loras)
def bgmv_shape_function(idxs, inputs, loras):
T, _ = inputs.shape
_, L, _ = loras.shape
return [((T, L), inputs.dtype)]
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", )
@impl(XLA_LIB, "bgmv", "XLA")
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
inputs = inputs.to(dtype=loras.dtype)
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
jax_import_guard()
kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function)
T, _ = inputs.shape
_, L, D = loras.shape
# Pad the loras' rank if it's too low. This is to allow it to fit in a TPU
# register. This has to happen in pytorch, doing it in Jax will lead to NaNs
L1 = L
if LORA_RANK_BLOCK > L or L % LORA_RANK_BLOCK != 0:
L1 = (L // LORA_RANK_BLOCK + 1) * LORA_RANK_BLOCK
D1 = D
if DIM_BLOCK_SIZE > D or D % DIM_BLOCK_SIZE != 0:
D1 = (D // DIM_BLOCK_SIZE + 1) * DIM_BLOCK_SIZE
T1 = T
if TOKENS_BLOCK > T or T % TOKENS_BLOCK != 0:
T1 = (T // TOKENS_BLOCK + 1) * TOKENS_BLOCK
if D1 != D or L1 != L:
loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0))
if D1 != D or T1 != T:
inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T))
if T1 != T:
idxs = torch.nn.functional.pad(idxs, ((0, T1 - T)))
return kernel(idxs, inputs, loras)[:T, :L]
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
idxs: torch.IntTensor):
T, _ = inputs.shape
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
_, L, _ = loras.shape
return torch.empty((T, L), device=inputs.device)
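
For context on what the deletion drops: the wrapper around this kernel padded the token, rank, and hidden dimensions up to multiples of TOKENS_BLOCK, LORA_RANK_BLOCK, and DIM_BLOCK_SIZE before calling Pallas, then sliced the result back to [:T, :L]. A small standalone sketch of that round-up rule under the same block constants (illustrative only):

def round_up(x: int, block: int) -> int:
    # Smallest multiple of block that is >= x; matches the guarded
    # (x // block + 1) * block branch above for non-multiples.
    return ((x + block - 1) // block) * block

TOKENS_BLOCK, LORA_RANK_BLOCK, DIM_BLOCK_SIZE = 16, 128, 128

T, L, D = 10, 40, 3072
print(round_up(T, TOKENS_BLOCK),      # 16
      round_up(L, LORA_RANK_BLOCK),   # 128
      round_up(D, DIM_BLOCK_SIZE))    # 3072 (already a multiple)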


@@ -1,11 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Union
import math
from typing import TYPE_CHECKING, Optional, Union
import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from vllm.lora.punica_wrapper.utils import convert_mapping
if TYPE_CHECKING:
# avoid circular import
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import LongContextLoRAContext
from .punica_base import PunicaWrapperBase
@@ -31,6 +39,15 @@ class PunicaWrapperTPU(PunicaWrapperBase):
self._sampler_indices_padded = self._sampler_indices_padded.to(
dtype=torch.int32)
torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded,
True)
torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch,
True)
torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
@@ -55,15 +72,11 @@ class PunicaWrapperTPU(PunicaWrapperBase):
def shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
if self.no_lora:
return y
return bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x),
scale)
return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale)
def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor,
add_inputs: bool):
@@ -72,7 +85,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
def expand_slice(self, y: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, y_offset: int, y_slice_size: int,
y_total_size: int, add_inputs: bool) -> torch.Tensor:
add_inputs: bool) -> torch.Tensor:
return bgmv_expand_slice(x, w_t_all, y,
self._get_token_lora_indices(x), y_offset,
y_slice_size, add_inputs)
@@ -98,9 +111,8 @@ class PunicaWrapperTPU(PunicaWrapperBase):
x = x.view(-1, x.shape[-1])
for slice_idx in range(len(lora_a_stacked)):
y_s = y[slice_idx]
lora_s = lora_a_stacked[slice_idx]
y_s = self.shrink(y_s, x, lora_s, scale)
y_s = self.shrink(x, lora_s, scale)
y[slice_idx, :, :] = y_s # type: ignore[index]
return y
@@ -140,15 +152,12 @@ class PunicaWrapperTPU(PunicaWrapperBase):
y = self._apply_bias(self._get_token_lora_indices(y), y,
output_slices, lora_bias_stacked)
for slice_idx in range(len(lora_b_stacked)):
y = self.expand_slice(
y,
y = self.expand_slice(y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
y_total_size=sum(output_slices),
add_inputs=add_inputs,
)
add_inputs=add_inputs)
offset_left += output_slices[slice_idx]
return y.view_as(y_org)
@@ -216,12 +225,10 @@ class PunicaWrapperTPU(PunicaWrapperBase):
if buffer is None:
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default, consistent with the
# triton op
T = x.size(0)
buffer = torch.zeros(
(len(output_slices), T, r),
dtype=torch.float32,
dtype=x.dtype,
device=x.device,
)
buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
@@ -257,26 +264,16 @@ class PunicaWrapperTPU(PunicaWrapperBase):
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
if self.no_lora:
return y
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if buffer is None:
# We set the buffer to be float32 by default, consistent with the
# triton op
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices,
scale)
sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale)
y = bgmv_expand(buffer,
lora_b_stacked,
y,
self.sampler_indices,
sampler_indices,
add_inputs=True)
return y.view_as(y_org)
@@ -316,10 +313,92 @@ class PunicaWrapperTPU(PunicaWrapperBase):
return output.view_as(org_output)
# This performs the same tensor ops as the base method, except it does them
# on the CPU then transfers the results to the TPU
def _update_base_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
):
# Make sure we don't accidentally collect outside operations
xm.mark_step()
# Pad the prompt mapping to avoid running into recompiles on the TPU
# TODO: Should this happen inside mapping internally? If so how can we
# avoid having backend specific LoRAMapping classes?
mapping.prompt_mapping = self._pad_prompt_mapping(
mapping.prompt_mapping)
(
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
long_lora_offsets_tensor,
indices_len,
) = convert_mapping(
mapping,
lora_index_to_id,
max_loras,
vocab_size,
extra_vocab_size,
"cpu",
long_lora_context,
)
self._token_lora_indices = self._pad_to_shape(
base_indices, self._token_lora_indices.shape,
dims=1).to(self.device)
self._sampler_indices = self._pad_to_shape(sampler_indices,
self._sampler_indices.shape,
dims=1).to(self.device)
self._sampler_indices_padded = self._pad_to_shape(
sampler_indices_padded, self._sampler_indices_padded.shape,
dims=1).to(self.device)
self._embeddings_indices = self._pad_to_shape(
embeddings_indices, self._embeddings_indices.shape,
dims=2).to(self.device)
if long_lora_offsets_tensor is not None:
self._long_lora_indices = self._pad_to_shape(
long_lora_offsets_tensor,
self._long_lora_indices.shape,
dims=1).to(self.device)
else:
zeroed = torch.zeros_like(self._long_lora_indices.cpu(),
dtype=torch.int32)
self._long_lora_indices = zeroed.to(self.device)
self.indices_len[:] = indices_len
def _update_prefill_metadata(self,
token_lora_tensor: torch.Tensor) -> None:
self.batch_size = 1
self._lora_indices_per_batch[:self.batch_size].copy_(
token_lora_tensor[:self.batch_size])
# TODO: .item() is extremely inefficient on TPU, so find a way around it
self.no_lora = torch.all(token_lora_tensor == -1).item()
self._lora_indices_per_batch[:self.
batch_size] = token_lora_tensor[:self.
batch_size]
def _pad_prompt_mapping(
self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
num_reqs = len(prompt_mapping)
# From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
# import
MIN_NUM_SEQS = 8
padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
pad_len = padded_num_reqs - num_reqs
padding = [-1] * pad_len
return tuple(list(prompt_mapping) + padding)
def _pad_to_shape(self, src, target_shape, dims=1):
if dims == 1:
pad_len = target_shape[0] - src.shape[0]
return F.pad(src, (0, pad_len), value=0).to(torch.int32)
else:
pad_rows = target_shape[0] - src.shape[0]
pad_cols = target_shape[1] - src.shape[1]
return F.pad(src, (0, pad_cols, 0, pad_rows),
value=0).to(torch.int32)
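
Both helpers keep tensor shapes inside a small, fixed set so the TPU does not recompile per batch: _pad_prompt_mapping rounds the number of requests up to a power of two (with a floor of MIN_NUM_SEQS) and fills the tail with -1, while _pad_to_shape zero-pads the index tensors to their preallocated shapes before moving them to the device. A quick sketch of the request padding, assuming the same MIN_NUM_SEQS = 8 (illustrative only):

import math

MIN_NUM_SEQS = 8  # mirrors the constant referenced above

def padded_num_reqs(num_reqs: int) -> int:
    # Next power of two, but never below MIN_NUM_SEQS.
    return max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)

for n in (1, 3, 8, 9, 33):
    print(n, "->", padded_num_reqs(n))
# 1 -> 8, 3 -> 8, 8 -> 8, 9 -> 16, 33 -> 64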


@@ -80,7 +80,37 @@ class LoRAModelRunnerMixin:
lora_requests)
@contextmanager
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
def maybe_setup_dummy_loras(self, lora_config):
if lora_config is None:
yield
else:
# __enter__ code
assert self.lora_manager is not None, "LoRA is not enabled"
num_loras = lora_config.max_loras
# Make dummy lora requests
lora_requests: set[LoRARequest] = {
LoRARequest(lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_path="/not/a/real/path")
for lora_id in range(1, num_loras + 1)
}
with self.lora_manager.dummy_lora_cache():
# Add the dummy LoRAs here so _set_active_loras doesn't try to
# load from disk.
for lr in lora_requests:
self.lora_manager.add_dummy_lora(
lr, rank=self.LORA_WARMUP_RANK)
yield
# __exit__ code
self.lora_manager.remove_all_adapters()
@contextmanager
def maybe_select_dummy_loras(self, lora_config: LoRAConfig,
num_scheduled_tokens: np.ndarray):
if lora_config is None:
yield
@@ -108,21 +138,18 @@ class LoRAModelRunnerMixin:
for lora_id in range(1, num_loras + 1)
}
with self.lora_manager.dummy_lora_cache():
# Add the dummy LoRAs here so _set_active_loras doesn't try to
# load from disk.
for lr in lora_requests:
self.lora_manager.add_dummy_lora(
lr, rank=self.LORA_WARMUP_RANK)
self._set_active_loras(tuple(prompt_lora_mapping),
tuple(token_lora_mapping),
lora_requests)
tuple(token_lora_mapping), lora_requests)
yield
# __exit__ code
self.lora_manager.remove_all_adapters()
@contextmanager
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
num_scheduled_tokens: np.ndarray):
with self.maybe_setup_dummy_loras(
lora_config), self.maybe_select_dummy_loras(
lora_config, num_scheduled_tokens):
yield
def add_lora(self, lora_request: LoRARequest) -> bool:
if not self.lora_manager:


@@ -20,6 +20,7 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
@@ -152,6 +153,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.hidden_size = model_config.get_hidden_size()
self.vocab_size = model_config.get_vocab_size()
if self.lora_config is not None:
self.vocab_size += self.lora_config.lora_extra_vocab_size
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
@@ -591,6 +595,17 @@ class TPUModelRunner(LoRAModelRunnerMixin):
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
logits_indices = logits_indices.to(self.device)
if self.lora_config is not None:
# We need to respect padding when activating LoRA adapters
padded_num_scheduled_tokens_per_req = np.copy(
num_scheduled_tokens_per_req
) # Copying to avoid accidental state corruption bugs
padded_num_scheduled_tokens_per_req[-1] += \
padded_total_num_scheduled_tokens - total_num_scheduled_tokens
self.set_active_loras(self.input_batch,
padded_num_scheduled_tokens_per_req)
layer_names = get_layers_from_vllm_config(self.vllm_config,
Attention).keys()
per_layer_attn_metadata = {
@@ -916,6 +931,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
model = self.load_lora_model(model, self.model_config,
self.scheduler_config,
self.lora_config, self.device)
replace_set_lora(model)
# Sync all pending XLA execution during model initialization and weight
# loading.
@@ -980,7 +996,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
for layer_name in layer_names
}
with self.maybe_dummy_run_with_lora(
with self.maybe_select_dummy_loras(
self.lora_config,
np.array([num_tokens], dtype=np.int32)), set_forward_context(
per_layer_attn_metadata, self.vllm_config, 0):
@@ -989,6 +1005,13 @@ class TPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds=inputs_embeds)
self._hidden_states_dtype = out.dtype
def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
lora_requests) -> None:
xm.mark_step() # Captures input updates
super()._set_active_loras(prompt_lora_mapping, token_lora_mapping,
lora_requests)
xm.mark_step() # Captures metadata updates
def _precompile_mm_encoder(self) -> None:
# Pre-compile MM encoder for all supported data modalities.
hf_config = self.vllm_config.model_config.hf_config
@@ -1151,6 +1174,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
generate_params_if_all_greedy,
))
sampling_metadata.all_greedy = all_greedy
with self.maybe_select_dummy_loras(
self.lora_config, np.array([num_reqs],
dtype=np.int32)):
self.sample_from_logits(dummy_logits, sampling_metadata)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
@@ -1167,6 +1193,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
dtype=self._hidden_states_dtype)
dummy_tokens = torch.zeros((num_reqs, 1),
dtype=torch.int64).to(self.device)
with self.maybe_select_dummy_loras(
self.lora_config, np.array([num_reqs], dtype=np.int32)):
self.gather_logprobs(dummy_logits, dummy_tokens)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
@@ -1178,6 +1206,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
"""
Precompile all the subgraphs with possible input shapes.
"""
with self.maybe_setup_dummy_loras(self.lora_config):
self._precompile_mm_encoder()
self._precompile_backbone()
self._precompile_select_hidden_states()
@@ -1508,3 +1537,32 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
index = bisect.bisect_left(paddings, x)
assert index < len(paddings)
return paddings[index]
def replace_set_lora(model):
def _tpu_set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
# TODO: The integer index leads to a recompilation, but converting it
# to a tensor doesn't seem to work anymore. This might be fixed with a
# later release of torch_xla.
self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias)
xm.mark_step()
def _tpu_reset_lora(self, index: int):
self._original_reset_lora(index)
xm.mark_step()
for _, module in model.named_modules():
if isinstance(module, BaseLayerWithLoRA):
module._original_set_lora = module.set_lora
module._original_reset_lora = module.reset_lora
module.set_lora = _tpu_set_lora.__get__(module, module.__class__)
module.reset_lora = _tpu_reset_lora.__get__(
module, module.__class__)
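
The replace_set_lora helper above rebinds set_lora and reset_lora on every BaseLayerWithLoRA module so that each adapter write is followed by xm.mark_step(), splitting weight updates into separate XLA graphs. The rebinding is ordinary descriptor access through __get__; a minimal standalone sketch with made-up class and method bodies (the print stands in for xm.mark_step()):

class Layer:
    def set_lora(self, index):
        print("original set_lora", index)

def _patched_set_lora(self, index):
    self._original_set_lora(index)
    print("graph cut after update")  # stand-in for xm.mark_step()

layer = Layer()
layer._original_set_lora = layer.set_lora
# __get__ binds the free function to this instance, so self is passed.
layer.set_lora = _patched_set_lora.__get__(layer, Layer)
layer.set_lora(0)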


@@ -83,10 +83,6 @@ class TPUWorker:
if self.model_config.seed is None:
self.model_config.seed = 0
if vllm_config.lora_config is not None:
raise NotImplementedError(
"The V1 TPU backend doesn't support LoRA serving")
def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
@@ -166,6 +162,7 @@ class TPUWorker:
runner_kv_caches)
# `max_num_tokens >= max_num_batched_tokens` due to padding.
with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
self.model_runner.profile_run(self.model_runner.max_num_tokens)
# Synchronize before measuring the memory usage.
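
With the NotImplementedError in TPUWorker gone, LoRA adapters can be served on the V1 TPU backend through the usual vLLM entry points. A hedged end-to-end sketch; the model name, adapter path, and limits below are placeholders rather than values taken from this change:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder base model and adapter; any LoRA-compatible pair works.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          enable_lora=True,
          max_loras=4,
          max_lora_rank=16)

outputs = llm.generate(
    ["Write a haiku about TPUs."],
    SamplingParams(max_tokens=64),
    lora_request=LoRARequest("my_adapter", 1, "/path/to/adapter"),
)
print(outputs[0].outputs[0].text)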