mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:34:27 +08:00
[Hardware][TPU][V1] Multi-LoRA Optimisations for the V1 TPU backend (#15655)
Signed-off-by: Akshat Tripathi <akshat@krai.ai> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: xihajun <junfan@krai.ai> Signed-off-by: Jorge de Freitas <jorge.de-freitas22@imperial.ac.uk> Signed-off-by: Jorge de Freitas <jorge@krai.ai> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: xihajun <junfan@krai.ai> Co-authored-by: Jorge de Freitas <jorge.de-freitas22@imperial.ac.uk> Co-authored-by: Jorge de Freitas <jorge@krai.ai>
This commit is contained in:
parent
a09c7ca9f2
commit
643622ba46
@ -122,10 +122,8 @@ run_and_track_test 11 "test_struct_output_generate.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py"
|
||||
run_and_track_test 12 "test_moe_pallas.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
|
||||
|
||||
# Disable the TPU LoRA tests until the feature is activated
|
||||
# run_and_track_test 13 "test_lora (directory)" \
|
||||
# "python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/"
|
||||
run_and_track_test 13 "test_lora.py" \
|
||||
"VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py"
|
||||
|
||||
# After all tests have been attempted, exit with the overall status.
|
||||
if [ "$overall_script_exit_code" -ne 0 ]; then
|
||||
|
||||
@ -1,73 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Required to register the custom ops
|
||||
import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import
|
||||
|
||||
N_TOKENS = [16, 1024, 4096]
|
||||
HIDDEN_SIZES = [1024, 2048, 4096]
|
||||
|
||||
DTYPES = [torch.bfloat16]
|
||||
NUM_LORA = [1, 4, 16]
|
||||
RANKS = [32, 256, 512]
|
||||
|
||||
|
||||
def generate_test_data(T, D, L, N, seed, dtype=torch.float32):
|
||||
"""
|
||||
Inputs: (All integers)
|
||||
T: Total number of tokens
|
||||
D: Input dim
|
||||
L: LoRA Dim
|
||||
N: N LoRAs
|
||||
|
||||
Outputs:
|
||||
inputs: torch.Tensor - shape (T, D)
|
||||
loras: torch.Tensor - shape (N, 1, L, D)
|
||||
idxs: torch.Tensor - shape (T, ) - all values must be in [0, N)
|
||||
|
||||
ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T
|
||||
"""
|
||||
torch.manual_seed(seed)
|
||||
|
||||
inputs = torch.randn((T, D), device="xla", dtype=dtype)
|
||||
loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype)
|
||||
idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla")
|
||||
|
||||
ref_output = ref_bgmv(inputs, loras, idxs)
|
||||
return inputs, loras, idxs, ref_output
|
||||
|
||||
|
||||
def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor):
|
||||
selected_loras = loras[idxs]
|
||||
if len(selected_loras.shape) == 4:
|
||||
selected_loras = selected_loras.squeeze(axis=1)
|
||||
|
||||
batch_size, output_size, input_size = selected_loras.shape
|
||||
return (selected_loras @ inputs.reshape(
|
||||
(batch_size, input_size, 1))).reshape((batch_size, output_size))
|
||||
|
||||
|
||||
# Parameterize tests with various shapes and dtypes
|
||||
@pytest.mark.parametrize("T", N_TOKENS)
|
||||
@pytest.mark.parametrize("D", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("L", RANKS)
|
||||
@pytest.mark.parametrize("N", NUM_LORA)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
@pytest.mark.parametrize("seed", [0])
|
||||
def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed):
|
||||
if op_type == "expand":
|
||||
D, L = L, D
|
||||
|
||||
inputs, loras, idxs, ref_output = generate_test_data(
|
||||
T, D, L, N, seed, dtype)
|
||||
|
||||
# Run bgmv
|
||||
output = torch.ops.xla.bgmv(inputs, loras, idxs)
|
||||
|
||||
# Make sure we have no NaNs
|
||||
assert not torch.any(torch.isnan(output))
|
||||
|
||||
# Compare with reference output
|
||||
assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2)
|
||||
@ -200,7 +200,7 @@ class LoRAModel(AdapterModel):
|
||||
weights_mapper: Optional[WeightsMapper] = None,
|
||||
tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel":
|
||||
"""Create a LoRAModel from a local checkpoint.
|
||||
|
||||
|
||||
Args:
|
||||
lora_dir: The local path that has lora data.
|
||||
expected_lora_modules: Name of modules that are expected to be
|
||||
@ -620,7 +620,7 @@ class LoRAModelManager(AdapterModelManager):
|
||||
def _filter_unsupported_mm_module(self, module_name: str) -> bool:
|
||||
"""
|
||||
Regarding multimodal models, vLLM currently only supports adding LoRA to
|
||||
language model. LoRA for other modules, such as the vision tower, will
|
||||
language model. LoRA for other modules, such as the vision tower, will
|
||||
be filtered out.
|
||||
"""
|
||||
if self.supports_mm:
|
||||
|
||||
@ -1,63 +1,99 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import torch
|
||||
|
||||
# Required to register the custom ops
|
||||
import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import
|
||||
import torch.nn.functional as F
|
||||
import torch_xla.core.xla_builder as xb
|
||||
from torch.library import impl
|
||||
from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard
|
||||
|
||||
|
||||
def bgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True):
|
||||
@jax.jit
|
||||
def bgmv_jax(inputs, loras, idxs):
|
||||
return jnp.einsum(
|
||||
"td,tX,Xld->tl",
|
||||
inputs,
|
||||
jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
|
||||
loras,
|
||||
)
|
||||
|
||||
|
||||
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")
|
||||
|
||||
|
||||
@impl(XLA_LIB, "bgmv", "XLA")
|
||||
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
|
||||
if len(loras.shape) == 4:
|
||||
loras = loras.squeeze(axis=1)
|
||||
|
||||
jax_import_guard()
|
||||
return xb.call_jax(bgmv_jax, (inputs, loras, idxs))
|
||||
|
||||
|
||||
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
|
||||
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
|
||||
idxs: torch.IntTensor):
|
||||
T, _ = inputs.shape
|
||||
if len(loras.shape) == 4:
|
||||
loras = loras.squeeze(axis=1)
|
||||
_, L, _ = loras.shape
|
||||
|
||||
return torch.empty((T, L), device=inputs.device)
|
||||
|
||||
|
||||
def bgmv_expand(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
|
||||
|
||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
||||
|
||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
||||
[num_loras, lora_rank, hidden_size].
|
||||
|
||||
output_tensor (torch.Tensor): output tensor of shape
|
||||
|
||||
output_tensor (torch.Tensor): output tensor of shape
|
||||
[num_tokens, hidden_size * num_slices].
|
||||
|
||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
||||
|
||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
||||
indicating which LoRA matrix to use for each token.
|
||||
add_inputs (bool): Whether or not to add the input tensor to the output
|
||||
add_inputs (bool): Whether or not to add the input tensor to the output
|
||||
tensor.
|
||||
"""
|
||||
|
||||
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
|
||||
n_tokens = outputs.size(0)
|
||||
|
||||
limit = output_tensor.shape[0]
|
||||
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
|
||||
limit = 1
|
||||
|
||||
outputs = torch.cat(
|
||||
(outputs,
|
||||
torch.zeros((n_tokens, output_tensor.shape[1] - outputs.shape[1]),
|
||||
device=outputs.device)),
|
||||
dim=1)
|
||||
if output_tensor.shape[1] > outputs.shape[1]:
|
||||
outputs = F.pad(outputs,
|
||||
(0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
|
||||
|
||||
if add_inputs:
|
||||
return output_tensor + outputs[:limit, :]
|
||||
return output_tensor + outputs[:limit, :output_tensor.shape[1]]
|
||||
else:
|
||||
return outputs[:limit, :]
|
||||
return outputs[:limit, :output_tensor.shape[1]]
|
||||
|
||||
|
||||
def bgmv_shrink(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0):
|
||||
def bgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
|
||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
||||
[num_loras, lora_rank, hidden_size].
|
||||
output_tensor (torch.Tensor): (Unused) output tensor (placeholder).
|
||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
||||
indicating which LoRA matrix to use for each token.
|
||||
scaling (float, optional): Scalar multiplier applied to the output.
|
||||
"""
|
||||
@ -66,39 +102,41 @@ def bgmv_shrink(inputs: torch.Tensor,
|
||||
lora_indices_tensor)
|
||||
|
||||
|
||||
def bgmv_expand_slice(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True):
|
||||
def bgmv_expand_slice(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
|
||||
|
||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
||||
|
||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
||||
[num_loras, lora_rank, hidden_size].
|
||||
|
||||
output_tensor (torch.Tensor): output tensor of shape
|
||||
|
||||
output_tensor (torch.Tensor): output tensor of shape
|
||||
[num_tokens, hidden_size * num_slices].
|
||||
|
||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
||||
|
||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
||||
indicating which LoRA matrix to use for each token.
|
||||
add_inputs (bool): Whether or not to add the input tensor to the output
|
||||
add_inputs (bool): Whether or not to add the input tensor to the output
|
||||
tensor.
|
||||
"""
|
||||
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
|
||||
n_tokens = outputs.size(0)
|
||||
|
||||
outputs = torch.cat((
|
||||
torch.zeros((n_tokens, slice_offset), device=outputs.device),
|
||||
outputs = F.pad(
|
||||
outputs,
|
||||
torch.zeros(
|
||||
(n_tokens, output_tensor.shape[1] - (slice_offset + slice_size)),
|
||||
device=outputs.device),
|
||||
),
|
||||
dim=1)
|
||||
(
|
||||
slice_offset,
|
||||
output_tensor.shape[1] - (slice_offset + slice_size),
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
|
||||
if add_inputs:
|
||||
return output_tensor + outputs
|
||||
|
||||
@ -1,133 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import functools
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import torch
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import tpu as pltpu
|
||||
from torch.library import impl
|
||||
from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard,
|
||||
make_kernel_from_pallas)
|
||||
|
||||
# TODO: Tune these
|
||||
TOKENS_BLOCK = 16
|
||||
LORA_RANK_BLOCK = 128
|
||||
DIM_BLOCK_SIZE = 128
|
||||
|
||||
|
||||
def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref,
|
||||
acc_ref, mask_ref):
|
||||
|
||||
@pl.when(pl.program_id(2) == 0)
|
||||
def _():
|
||||
acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32)
|
||||
|
||||
t = pl.program_id(0)
|
||||
|
||||
for i in range(bT):
|
||||
idx = idx_ref[i + bT * t]
|
||||
mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32)
|
||||
mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32)
|
||||
|
||||
acc_ref[...] += jax.lax.dot_general(
|
||||
inp_ref[...],
|
||||
lora_ref[idx, ...], (((1, ), (1, )), ((), ())),
|
||||
preferred_element_type=jnp.float32) * mask_ref[...]
|
||||
|
||||
@pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
|
||||
def _():
|
||||
out_ref[...] = acc_ref[...].astype(out_ref.dtype)
|
||||
|
||||
|
||||
@jax.jit
|
||||
def _bgmv(
|
||||
idxs: jax.Array, # (T, ) int32
|
||||
inputs: jax.Array, # (T, D) model dtype
|
||||
loras: jax.Array # (N, L, D) model dtype
|
||||
) -> jax.Array: # (T, L) model dtype
|
||||
T, D = inputs.shape
|
||||
N, L, _ = loras.shape
|
||||
|
||||
return pl.pallas_call(
|
||||
kernel=functools.partial(_bgmv_kernel, TOKENS_BLOCK, LORA_RANK_BLOCK),
|
||||
out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype),
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=1,
|
||||
grid=(T // TOKENS_BLOCK, L // LORA_RANK_BLOCK,
|
||||
D // DIM_BLOCK_SIZE),
|
||||
in_specs=[
|
||||
pl.BlockSpec((TOKENS_BLOCK, DIM_BLOCK_SIZE),
|
||||
lambda i, j, k, block_idx: (i, k)),
|
||||
pl.BlockSpec((N, LORA_RANK_BLOCK, DIM_BLOCK_SIZE),
|
||||
lambda i, j, k, block_idx: (0, j, k)),
|
||||
],
|
||||
out_specs=pl.BlockSpec((TOKENS_BLOCK, LORA_RANK_BLOCK),
|
||||
lambda i, j, k, block_idx: (i, j)),
|
||||
scratch_shapes=[
|
||||
pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32),
|
||||
pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32)
|
||||
]),
|
||||
compiler_params=pltpu.TPUCompilerParams(
|
||||
dimension_semantics=("parallel", "parallel", "arbitrary")),
|
||||
name="bgmv")(idxs, inputs, loras)
|
||||
|
||||
|
||||
def bgmv_shape_function(idxs, inputs, loras):
|
||||
T, _ = inputs.shape
|
||||
_, L, _ = loras.shape
|
||||
|
||||
return [((T, L), inputs.dtype)]
|
||||
|
||||
|
||||
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", )
|
||||
|
||||
|
||||
@impl(XLA_LIB, "bgmv", "XLA")
|
||||
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
|
||||
inputs = inputs.to(dtype=loras.dtype)
|
||||
|
||||
if len(loras.shape) == 4:
|
||||
loras = loras.squeeze(axis=1)
|
||||
|
||||
jax_import_guard()
|
||||
kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function)
|
||||
|
||||
T, _ = inputs.shape
|
||||
_, L, D = loras.shape
|
||||
|
||||
# Pad the loras' rank if it's too low. This is to allow it to fit in a TPU
|
||||
# register. This has to happen in pytorch, doing it in Jax will lead to NaNs
|
||||
L1 = L
|
||||
if LORA_RANK_BLOCK > L or L % LORA_RANK_BLOCK != 0:
|
||||
L1 = (L // LORA_RANK_BLOCK + 1) * LORA_RANK_BLOCK
|
||||
|
||||
D1 = D
|
||||
if DIM_BLOCK_SIZE > D or D % DIM_BLOCK_SIZE != 0:
|
||||
D1 = (D // DIM_BLOCK_SIZE + 1) * DIM_BLOCK_SIZE
|
||||
|
||||
T1 = T
|
||||
if TOKENS_BLOCK > T or T % TOKENS_BLOCK != 0:
|
||||
T1 = (T // TOKENS_BLOCK + 1) * TOKENS_BLOCK
|
||||
|
||||
if D1 != D or L1 != L:
|
||||
loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0))
|
||||
if D1 != D or T1 != T:
|
||||
inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T))
|
||||
if T1 != T:
|
||||
idxs = torch.nn.functional.pad(idxs, ((0, T1 - T)))
|
||||
|
||||
return kernel(idxs, inputs, loras)[:T, :L]
|
||||
|
||||
|
||||
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
|
||||
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
|
||||
idxs: torch.IntTensor):
|
||||
T, _ = inputs.shape
|
||||
|
||||
if len(loras.shape) == 4:
|
||||
loras = loras.squeeze(axis=1)
|
||||
|
||||
_, L, _ = loras.shape
|
||||
|
||||
return torch.empty((T, L), device=inputs.device)
|
||||
@ -1,11 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional, Union
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
||||
from vllm.lora.punica_wrapper.utils import convert_mapping
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# avoid circuit import
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.models import LongContextLoRAContext
|
||||
|
||||
from .punica_base import PunicaWrapperBase
|
||||
|
||||
@ -31,6 +39,15 @@ class PunicaWrapperTPU(PunicaWrapperBase):
|
||||
self._sampler_indices_padded = self._sampler_indices_padded.to(
|
||||
dtype=torch.int32)
|
||||
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded,
|
||||
True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch,
|
||||
True)
|
||||
|
||||
torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
|
||||
torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
|
||||
torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
|
||||
@ -55,15 +72,11 @@ class PunicaWrapperTPU(PunicaWrapperBase):
|
||||
|
||||
def shrink(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
if self.no_lora:
|
||||
return y
|
||||
return bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x),
|
||||
scale)
|
||||
return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale)
|
||||
|
||||
def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor,
|
||||
add_inputs: bool):
|
||||
@ -72,7 +85,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
|
||||
|
||||
def expand_slice(self, y: torch.Tensor, x: torch.Tensor,
|
||||
w_t_all: torch.Tensor, y_offset: int, y_slice_size: int,
|
||||
y_total_size: int, add_inputs: bool) -> torch.Tensor:
|
||||
add_inputs: bool) -> torch.Tensor:
|
||||
return bgmv_expand_slice(x, w_t_all, y,
|
||||
self._get_token_lora_indices(x), y_offset,
|
||||
y_slice_size, add_inputs)
|
||||
@ -98,9 +111,8 @@ class PunicaWrapperTPU(PunicaWrapperBase):
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
for slice_idx in range(len(lora_a_stacked)):
|
||||
y_s = y[slice_idx]
|
||||
lora_s = lora_a_stacked[slice_idx]
|
||||
y_s = self.shrink(y_s, x, lora_s, scale)
|
||||
y_s = self.shrink(x, lora_s, scale)
|
||||
y[slice_idx, :, :] = y_s # type: ignore[index]
|
||||
return y
|
||||
|
||||
@ -140,15 +152,12 @@ class PunicaWrapperTPU(PunicaWrapperBase):
|
||||
y = self._apply_bias(self._get_token_lora_indices(y), y,
|
||||
output_slices, lora_bias_stacked)
|
||||
for slice_idx in range(len(lora_b_stacked)):
|
||||
y = self.expand_slice(
|
||||
y,
|
||||
x[slice_idx],
|
||||
lora_b_stacked[slice_idx],
|
||||
offset_left,
|
||||
output_slices[slice_idx],
|
||||
y_total_size=sum(output_slices),
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
y = self.expand_slice(y,
|
||||
x[slice_idx],
|
||||
lora_b_stacked[slice_idx],
|
||||
offset_left,
|
||||
output_slices[slice_idx],
|
||||
add_inputs=add_inputs)
|
||||
offset_left += output_slices[slice_idx]
|
||||
return y.view_as(y_org)
|
||||
|
||||
@ -216,12 +225,10 @@ class PunicaWrapperTPU(PunicaWrapperBase):
|
||||
|
||||
if buffer is None:
|
||||
r = lora_b_stacked[0].size(-1)
|
||||
# We set the buffer to be float32 by default, consistent with the
|
||||
# triton op
|
||||
T = x.size(0)
|
||||
buffer = torch.zeros(
|
||||
(len(output_slices), T, r),
|
||||
dtype=torch.float32,
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
)
|
||||
buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
|
||||
@ -257,26 +264,16 @@ class PunicaWrapperTPU(PunicaWrapperBase):
|
||||
scale (float): Scaling factor.
|
||||
buffer (Optional[torch.Tensor]):Default to None.
|
||||
"""
|
||||
if self.no_lora:
|
||||
return y
|
||||
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
x = x.view(-1, x.shape[-1])
|
||||
r = lora_b_stacked.size(-1)
|
||||
if buffer is None:
|
||||
# We set the buffer to be float32 by default, consistent with the
|
||||
# triton op
|
||||
buffer = torch.zeros((x.size(0), r),
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
|
||||
buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices,
|
||||
scale)
|
||||
sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
|
||||
buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale)
|
||||
y = bgmv_expand(buffer,
|
||||
lora_b_stacked,
|
||||
y,
|
||||
self.sampler_indices,
|
||||
sampler_indices,
|
||||
add_inputs=True)
|
||||
return y.view_as(y_org)
|
||||
|
||||
@ -316,10 +313,92 @@ class PunicaWrapperTPU(PunicaWrapperBase):
|
||||
|
||||
return output.view_as(org_output)
|
||||
|
||||
# This performs the same tensor ops as the base method, except it does them
|
||||
# on the CPU then transfers the results to the TPU
|
||||
def _update_base_metadata(
|
||||
self,
|
||||
mapping: "LoRAMapping",
|
||||
lora_index_to_id: list[Optional[int]],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
long_lora_context: Optional["LongContextLoRAContext"] = None,
|
||||
):
|
||||
# Make sure we don't accidentally collect outside operations
|
||||
xm.mark_step()
|
||||
|
||||
# Pad the prompt mapping to avoid running into recompiles on the TPU
|
||||
# TODO: Should this happen inside mapping internally? If so how can we
|
||||
# avoid having backend specific LoRAMapping classes?
|
||||
mapping.prompt_mapping = self._pad_prompt_mapping(
|
||||
mapping.prompt_mapping)
|
||||
|
||||
(
|
||||
base_indices,
|
||||
sampler_indices,
|
||||
sampler_indices_padded,
|
||||
embeddings_indices,
|
||||
long_lora_offsets_tensor,
|
||||
indices_len,
|
||||
) = convert_mapping(
|
||||
mapping,
|
||||
lora_index_to_id,
|
||||
max_loras,
|
||||
vocab_size,
|
||||
extra_vocab_size,
|
||||
"cpu",
|
||||
long_lora_context,
|
||||
)
|
||||
self._token_lora_indices = self._pad_to_shape(
|
||||
base_indices, self._token_lora_indices.shape,
|
||||
dims=1).to(self.device)
|
||||
self._sampler_indices = self._pad_to_shape(sampler_indices,
|
||||
self._sampler_indices.shape,
|
||||
dims=1).to(self.device)
|
||||
self._sampler_indices_padded = self._pad_to_shape(
|
||||
sampler_indices_padded, self._sampler_indices_padded.shape,
|
||||
dims=1).to(self.device)
|
||||
self._embeddings_indices = self._pad_to_shape(
|
||||
embeddings_indices, self._embeddings_indices.shape,
|
||||
dims=2).to(self.device)
|
||||
if long_lora_offsets_tensor is not None:
|
||||
self._long_lora_indices = self._pad_to_shape(
|
||||
long_lora_offsets_tensor,
|
||||
self._long_lora_indices.shape,
|
||||
dims=1).to(self.device)
|
||||
else:
|
||||
zeroed = torch.zeros_like(self._long_lora_indices.cpu(),
|
||||
dtype=torch.int32)
|
||||
self._long_lora_indices = zeroed.to(self.device)
|
||||
self.indices_len[:] = indices_len
|
||||
|
||||
def _update_prefill_metadata(self,
|
||||
token_lora_tensor: torch.Tensor) -> None:
|
||||
self.batch_size = 1
|
||||
self._lora_indices_per_batch[:self.batch_size].copy_(
|
||||
token_lora_tensor[:self.batch_size])
|
||||
# TODO: .item() is extremely inefficient on TPU, so find a way around it
|
||||
self.no_lora = torch.all(token_lora_tensor == -1).item()
|
||||
self._lora_indices_per_batch[:self.
|
||||
batch_size] = token_lora_tensor[:self.
|
||||
batch_size]
|
||||
|
||||
def _pad_prompt_mapping(
|
||||
self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
|
||||
num_reqs = len(prompt_mapping)
|
||||
|
||||
# From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
|
||||
# import
|
||||
MIN_NUM_SEQS = 8
|
||||
|
||||
padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
|
||||
pad_len = padded_num_reqs - num_reqs
|
||||
|
||||
padding = [-1] * pad_len
|
||||
return tuple(list(prompt_mapping) + padding)
|
||||
|
||||
def _pad_to_shape(self, src, target_shape, dims=1):
|
||||
if dims == 1:
|
||||
pad_len = target_shape[0] - src.shape[0]
|
||||
return F.pad(src, (0, pad_len), value=0).to(torch.int32)
|
||||
else:
|
||||
pad_rows = target_shape[0] - src.shape[0]
|
||||
pad_cols = target_shape[1] - src.shape[1]
|
||||
return F.pad(src, (0, pad_cols, 0, pad_rows),
|
||||
value=0).to(torch.int32)
|
||||
|
||||
@ -80,8 +80,38 @@ class LoRAModelRunnerMixin:
|
||||
lora_requests)
|
||||
|
||||
@contextmanager
|
||||
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
|
||||
num_scheduled_tokens: np.ndarray):
|
||||
def maybe_setup_dummy_loras(self, lora_config):
|
||||
if lora_config is None:
|
||||
yield
|
||||
else:
|
||||
# __enter__ code
|
||||
assert self.lora_manager is not None, "LoRA is not enabled"
|
||||
|
||||
num_loras = lora_config.max_loras
|
||||
|
||||
# Make dummy lora requests
|
||||
lora_requests: set[LoRARequest] = {
|
||||
LoRARequest(lora_name=f"warmup_{lora_id}",
|
||||
lora_int_id=lora_id,
|
||||
lora_path="/not/a/real/path")
|
||||
for lora_id in range(1, num_loras + 1)
|
||||
}
|
||||
|
||||
with self.lora_manager.dummy_lora_cache():
|
||||
# Add the dummy LoRAs here so _set_active_loras doesn't try to
|
||||
# load from disk.
|
||||
for lr in lora_requests:
|
||||
self.lora_manager.add_dummy_lora(
|
||||
lr, rank=self.LORA_WARMUP_RANK)
|
||||
|
||||
yield
|
||||
|
||||
# __exit__ code
|
||||
self.lora_manager.remove_all_adapters()
|
||||
|
||||
@contextmanager
|
||||
def maybe_select_dummy_loras(self, lora_config: LoRAConfig,
|
||||
num_scheduled_tokens: np.ndarray):
|
||||
if lora_config is None:
|
||||
yield
|
||||
else:
|
||||
@ -108,21 +138,18 @@ class LoRAModelRunnerMixin:
|
||||
for lora_id in range(1, num_loras + 1)
|
||||
}
|
||||
|
||||
with self.lora_manager.dummy_lora_cache():
|
||||
# Add the dummy LoRAs here so _set_active_loras doesn't try to
|
||||
# load from disk.
|
||||
for lr in lora_requests:
|
||||
self.lora_manager.add_dummy_lora(
|
||||
lr, rank=self.LORA_WARMUP_RANK)
|
||||
self._set_active_loras(tuple(prompt_lora_mapping),
|
||||
tuple(token_lora_mapping), lora_requests)
|
||||
|
||||
self._set_active_loras(tuple(prompt_lora_mapping),
|
||||
tuple(token_lora_mapping),
|
||||
lora_requests)
|
||||
yield
|
||||
|
||||
yield
|
||||
|
||||
# __exit__ code
|
||||
self.lora_manager.remove_all_adapters()
|
||||
@contextmanager
|
||||
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
|
||||
num_scheduled_tokens: np.ndarray):
|
||||
with self.maybe_setup_dummy_loras(
|
||||
lora_config), self.maybe_select_dummy_loras(
|
||||
lora_config, num_scheduled_tokens):
|
||||
yield
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
if not self.lora_manager:
|
||||
|
||||
@ -20,6 +20,7 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import BaseLayerWithLoRA
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
|
||||
@ -152,6 +153,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.hidden_size = model_config.get_hidden_size()
|
||||
self.vocab_size = model_config.get_vocab_size()
|
||||
|
||||
if self.lora_config is not None:
|
||||
self.vocab_size += self.lora_config.lora_extra_vocab_size
|
||||
|
||||
# Multi-modal data support
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.uses_mrope = model_config.uses_mrope
|
||||
@ -591,6 +595,17 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
|
||||
logits_indices = logits_indices.to(self.device)
|
||||
|
||||
if self.lora_config is not None:
|
||||
# We need to respect padding when activating LoRA adapters
|
||||
padded_num_scheduled_tokens_per_req = np.copy(
|
||||
num_scheduled_tokens_per_req
|
||||
) # Copying to avoid accidental state corruption bugs
|
||||
padded_num_scheduled_tokens_per_req[-1] += \
|
||||
padded_total_num_scheduled_tokens - total_num_scheduled_tokens
|
||||
|
||||
self.set_active_loras(self.input_batch,
|
||||
padded_num_scheduled_tokens_per_req)
|
||||
|
||||
layer_names = get_layers_from_vllm_config(self.vllm_config,
|
||||
Attention).keys()
|
||||
per_layer_attn_metadata = {
|
||||
@ -916,6 +931,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
model = self.load_lora_model(model, self.model_config,
|
||||
self.scheduler_config,
|
||||
self.lora_config, self.device)
|
||||
replace_set_lora(model)
|
||||
|
||||
# Sync all pending XLA execution during model initialization and weight
|
||||
# loading.
|
||||
@ -980,7 +996,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
for layer_name in layer_names
|
||||
}
|
||||
|
||||
with self.maybe_dummy_run_with_lora(
|
||||
with self.maybe_select_dummy_loras(
|
||||
self.lora_config,
|
||||
np.array([num_tokens], dtype=np.int32)), set_forward_context(
|
||||
per_layer_attn_metadata, self.vllm_config, 0):
|
||||
@ -989,6 +1005,13 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
inputs_embeds=inputs_embeds)
|
||||
self._hidden_states_dtype = out.dtype
|
||||
|
||||
def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
|
||||
lora_requests) -> None:
|
||||
xm.mark_step() # Captures input updates
|
||||
super()._set_active_loras(prompt_lora_mapping, token_lora_mapping,
|
||||
lora_requests)
|
||||
xm.mark_step() # Captures metadata updates
|
||||
|
||||
def _precompile_mm_encoder(self) -> None:
|
||||
# Pre-compile MM encoder for all supported data modalities.
|
||||
hf_config = self.vllm_config.model_config.hf_config
|
||||
@ -1151,7 +1174,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
generate_params_if_all_greedy,
|
||||
))
|
||||
sampling_metadata.all_greedy = all_greedy
|
||||
self.sample_from_logits(dummy_logits, sampling_metadata)
|
||||
with self.maybe_select_dummy_loras(
|
||||
self.lora_config, np.array([num_reqs],
|
||||
dtype=np.int32)):
|
||||
self.sample_from_logits(dummy_logits, sampling_metadata)
|
||||
logger.info(" -- num_seqs: %d", num_reqs)
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
@ -1167,7 +1193,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
dtype=self._hidden_states_dtype)
|
||||
dummy_tokens = torch.zeros((num_reqs, 1),
|
||||
dtype=torch.int64).to(self.device)
|
||||
self.gather_logprobs(dummy_logits, dummy_tokens)
|
||||
with self.maybe_select_dummy_loras(
|
||||
self.lora_config, np.array([num_reqs], dtype=np.int32)):
|
||||
self.gather_logprobs(dummy_logits, dummy_tokens)
|
||||
logger.info(" -- num_seqs: %d", num_reqs)
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
@ -1178,13 +1206,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
"""
|
||||
Precompile all the subgraphs with possible input shapes.
|
||||
"""
|
||||
self._precompile_mm_encoder()
|
||||
self._precompile_backbone()
|
||||
self._precompile_select_hidden_states()
|
||||
self._precompile_compute_logits()
|
||||
self._precompile_structured_decoding()
|
||||
self._precompile_sample_from_logits()
|
||||
self._precompile_gather_logprobs()
|
||||
with self.maybe_setup_dummy_loras(self.lora_config):
|
||||
self._precompile_mm_encoder()
|
||||
self._precompile_backbone()
|
||||
self._precompile_select_hidden_states()
|
||||
self._precompile_compute_logits()
|
||||
self._precompile_structured_decoding()
|
||||
self._precompile_sample_from_logits()
|
||||
self._precompile_gather_logprobs()
|
||||
|
||||
def profile_run(
|
||||
self,
|
||||
@ -1467,11 +1496,11 @@ def _get_token_paddings(min_token_size: int, max_token_size: int,
|
||||
padding_gap: int) -> list[int]:
|
||||
"""Generate a list of padding size, starting from min_token_size,
|
||||
ending with a number that can cover max_token_size
|
||||
|
||||
|
||||
If padding_gap == 0 then:
|
||||
increase 2X each time (exponential)
|
||||
else:
|
||||
first increase the size to twice,
|
||||
first increase the size to twice,
|
||||
then increase the padding size by padding_gap.
|
||||
"""
|
||||
# assert min_token_size is power of 2
|
||||
@ -1508,3 +1537,32 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
|
||||
index = bisect.bisect_left(paddings, x)
|
||||
assert index < len(paddings)
|
||||
return paddings[index]
|
||||
|
||||
|
||||
def replace_set_lora(model):
|
||||
|
||||
def _tpu_set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# TODO: The integer index leads to a recompilation, but converting it
|
||||
# to a tensor doesn't seem to work anymore. This might be fixed with a
|
||||
# later release of torch_xla.
|
||||
self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias)
|
||||
xm.mark_step()
|
||||
|
||||
def _tpu_reset_lora(self, index: int):
|
||||
self._original_reset_lora(index)
|
||||
xm.mark_step()
|
||||
|
||||
for _, module in model.named_modules():
|
||||
if isinstance(module, BaseLayerWithLoRA):
|
||||
module._original_set_lora = module.set_lora
|
||||
module._original_reset_lora = module.reset_lora
|
||||
module.set_lora = _tpu_set_lora.__get__(module, module.__class__)
|
||||
module.reset_lora = _tpu_reset_lora.__get__(
|
||||
module, module.__class__)
|
||||
|
||||
@ -83,10 +83,6 @@ class TPUWorker:
|
||||
if self.model_config.seed is None:
|
||||
self.model_config.seed = 0
|
||||
|
||||
if vllm_config.lora_config is not None:
|
||||
raise NotImplementedError(
|
||||
"The V1 TPU backend doesn't support LoRA serving")
|
||||
|
||||
def init_device(self):
|
||||
os.environ["PJRT_DEVICE"] = "TPU"
|
||||
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
|
||||
@ -166,7 +162,8 @@ class TPUWorker:
|
||||
runner_kv_caches)
|
||||
|
||||
# `max_num_tokens >= max_num_batched_tokens` due to padding.
|
||||
self.model_runner.profile_run(self.model_runner.max_num_tokens)
|
||||
with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
|
||||
self.model_runner.profile_run(self.model_runner.max_num_tokens)
|
||||
|
||||
# Synchronize before measuring the memory usage.
|
||||
xm.wait_device_ops()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user