Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Hardware][TPU][V1] Multi-LoRA Optimisations for the V1 TPU backend (#15655)
Signed-off-by: Akshat Tripathi <akshat@krai.ai>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: xihajun <junfan@krai.ai>
Signed-off-by: Jorge de Freitas <jorge.de-freitas22@imperial.ac.uk>
Signed-off-by: Jorge de Freitas <jorge@krai.ai>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: xihajun <junfan@krai.ai>
Co-authored-by: Jorge de Freitas <jorge.de-freitas22@imperial.ac.uk>
Co-authored-by: Jorge de Freitas <jorge@krai.ai>
This commit is contained in:
parent a09c7ca9f2
commit 643622ba46
@@ -122,10 +122,8 @@ run_and_track_test 11 "test_struct_output_generate.py" \
     "python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py"
 run_and_track_test 12 "test_moe_pallas.py" \
     "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
-# Disable the TPU LoRA tests until the feature is activated
-# run_and_track_test 13 "test_lora (directory)" \
-#     "python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/"
+run_and_track_test 13 "test_lora.py" \
+    "VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py"
 
 # After all tests have been attempted, exit with the overall status.
 if [ "$overall_script_exit_code" -ne 0 ]; then
@@ -1,73 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-import pytest
-import torch
-
-# Required to register the custom ops
-import vllm.lora.ops.xla_ops.pallas  # noqa # pylint: disable=unused-import
-
-N_TOKENS = [16, 1024, 4096]
-HIDDEN_SIZES = [1024, 2048, 4096]
-
-DTYPES = [torch.bfloat16]
-NUM_LORA = [1, 4, 16]
-RANKS = [32, 256, 512]
-
-
-def generate_test_data(T, D, L, N, seed, dtype=torch.float32):
-    """
-    Inputs: (All integers)
-    T: Total number of tokens
-    D: Input dim
-    L: LoRA Dim
-    N: N LoRAs
-
-    Outputs:
-    inputs: torch.Tensor - shape (T, D)
-    loras: torch.Tensor - shape (N, 1, L, D)
-    idxs: torch.Tensor - shape (T, ) - all values must be in [0, N)
-
-    ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T
-    """
-    torch.manual_seed(seed)
-
-    inputs = torch.randn((T, D), device="xla", dtype=dtype)
-    loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype)
-    idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla")
-
-    ref_output = ref_bgmv(inputs, loras, idxs)
-    return inputs, loras, idxs, ref_output
-
-
-def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor):
-    selected_loras = loras[idxs]
-    if len(selected_loras.shape) == 4:
-        selected_loras = selected_loras.squeeze(axis=1)
-
-    batch_size, output_size, input_size = selected_loras.shape
-    return (selected_loras @ inputs.reshape(
-        (batch_size, input_size, 1))).reshape((batch_size, output_size))
-
-
-# Parameterize tests with various shapes and dtypes
-@pytest.mark.parametrize("T", N_TOKENS)
-@pytest.mark.parametrize("D", HIDDEN_SIZES)
-@pytest.mark.parametrize("L", RANKS)
-@pytest.mark.parametrize("N", NUM_LORA)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("op_type", ["shrink", "expand"])
-@pytest.mark.parametrize("seed", [0])
-def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed):
-    if op_type == "expand":
-        D, L = L, D
-
-    inputs, loras, idxs, ref_output = generate_test_data(
-        T, D, L, N, seed, dtype)
-
-    # Run bgmv
-    output = torch.ops.xla.bgmv(inputs, loras, idxs)
-
-    # Make sure we have no NaNs
-    assert not torch.any(torch.isnan(output))
-
-    # Compare with reference output
-    assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2)
@@ -1,16 +1,54 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import jax
+import jax.numpy as jnp
 import torch
-# Required to register the custom ops
-import vllm.lora.ops.xla_ops.pallas  # noqa # pylint: disable=unused-import
+import torch.nn.functional as F
+import torch_xla.core.xla_builder as xb
+from torch.library import impl
+from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard
 
 
-def bgmv_expand(inputs: torch.Tensor,
+@jax.jit
+def bgmv_jax(inputs, loras, idxs):
+    return jnp.einsum(
+        "td,tX,Xld->tl",
+        inputs,
+        jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
+        loras,
+    )
+
+
+XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")
+
+
+@impl(XLA_LIB, "bgmv", "XLA")
+def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
+    if len(loras.shape) == 4:
+        loras = loras.squeeze(axis=1)
+
+    jax_import_guard()
+    return xb.call_jax(bgmv_jax, (inputs, loras, idxs))
+
+
+@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
+def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
+                 idxs: torch.IntTensor):
+    T, _ = inputs.shape
+    if len(loras.shape) == 4:
+        loras = loras.squeeze(axis=1)
+    _, L, _ = loras.shape
+
+    return torch.empty((T, L), device=inputs.device)
+
+
+def bgmv_expand(
+    inputs: torch.Tensor,
     lora_b_weights: torch.Tensor,
     output_tensor: torch.Tensor,
     lora_indices_tensor: torch.Tensor,
-    add_inputs: bool = True):
+    add_inputs: bool = True,
+):
     """
     Args:
         inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
@@ -28,29 +66,27 @@ def bgmv_expand(inputs: torch.Tensor,
     """
 
     outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
-    n_tokens = outputs.size(0)
 
     limit = output_tensor.shape[0]
     if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
         limit = 1
 
-    outputs = torch.cat(
-        (outputs,
-         torch.zeros((n_tokens, output_tensor.shape[1] - outputs.shape[1]),
-                     device=outputs.device)),
-        dim=1)
+    if output_tensor.shape[1] > outputs.shape[1]:
+        outputs = F.pad(outputs,
+                        (0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
 
     if add_inputs:
-        return output_tensor + outputs[:limit, :]
+        return output_tensor + outputs[:limit, :output_tensor.shape[1]]
     else:
-        return outputs[:limit, :]
+        return outputs[:limit, :output_tensor.shape[1]]
 
 
-def bgmv_shrink(inputs: torch.Tensor,
+def bgmv_shrink(
+    inputs: torch.Tensor,
     lora_b_weights: torch.Tensor,
-    output_tensor: torch.Tensor,
     lora_indices_tensor: torch.Tensor,
-    scaling: float = 1.0):
+    scaling: float = 1.0,
+):
     """
     Args:
         inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
@@ -66,13 +102,15 @@ def bgmv_shrink(inputs: torch.Tensor,
                              lora_indices_tensor)
 
 
-def bgmv_expand_slice(inputs: torch.Tensor,
+def bgmv_expand_slice(
+    inputs: torch.Tensor,
     lora_b_weights: torch.Tensor,
     output_tensor: torch.Tensor,
     lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
-    add_inputs: bool = True):
+    add_inputs: bool = True,
+):
     """
     Args:
         inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
@@ -89,16 +127,16 @@ def bgmv_expand_slice(inputs: torch.Tensor,
         tensor.
     """
     outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
-    n_tokens = outputs.size(0)
 
-    outputs = torch.cat((
-        torch.zeros((n_tokens, slice_offset), device=outputs.device),
+    outputs = F.pad(
         outputs,
-        torch.zeros(
-            (n_tokens, output_tensor.shape[1] - (slice_offset + slice_size)),
-            device=outputs.device),
+        (
+            slice_offset,
+            output_tensor.shape[1] - (slice_offset + slice_size),
+            0,
+            0,
         ),
-        dim=1)
+    )
 
     if add_inputs:
         return output_tensor + outputs
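For readers unfamiliar with the BGMV primitive defined above, its semantics (also exercised by the deleted test) can be reproduced on CPU with plain PyTorch. This is an illustrative sketch only; the tensor names and sizes are made up and are not part of the diff:

# Illustrative CPU sketch of the bgmv semantics behind
# jnp.einsum("td,tX,Xld->tl", inputs, one_hot(idxs, N), loras):
# each token t selects the LoRA matrix loras[idxs[t]] and applies a mat-vec.
import torch

T, D, L, N = 4, 16, 8, 2  # tokens, hidden size, LoRA rank, number of LoRAs
inputs = torch.randn(T, D)
loras = torch.randn(N, L, D)
idxs = torch.randint(0, N, (T, ))

# Reference loop: out[t] = loras[idxs[t]] @ inputs[t]
ref = torch.stack([loras[idxs[t]] @ inputs[t] for t in range(T)])

# One-hot formulation mirroring the einsum above
one_hot = torch.nn.functional.one_hot(idxs, N).to(inputs.dtype)
out = torch.einsum("td,tX,Xld->tl", inputs, one_hot, loras)

assert torch.allclose(ref, out, atol=1e-5)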
@@ -1,133 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-import functools
-
-import jax
-import jax.numpy as jnp
-import torch
-from jax.experimental import pallas as pl
-from jax.experimental.pallas import tpu as pltpu
-from torch.library import impl
-from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard,
-                                                  make_kernel_from_pallas)
-
-# TODO: Tune these
-TOKENS_BLOCK = 16
-LORA_RANK_BLOCK = 128
-DIM_BLOCK_SIZE = 128
-
-
-def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref,
-                 acc_ref, mask_ref):
-
-    @pl.when(pl.program_id(2) == 0)
-    def _():
-        acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32)
-
-    t = pl.program_id(0)
-
-    for i in range(bT):
-        idx = idx_ref[i + bT * t]
-        mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32)
-        mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32)
-
-        acc_ref[...] += jax.lax.dot_general(
-            inp_ref[...],
-            lora_ref[idx, ...], (((1, ), (1, )), ((), ())),
-            preferred_element_type=jnp.float32) * mask_ref[...]
-
-    @pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
-    def _():
-        out_ref[...] = acc_ref[...].astype(out_ref.dtype)
-
-
-@jax.jit
-def _bgmv(
-        idxs: jax.Array,  # (T, ) int32
-        inputs: jax.Array,  # (T, D) model dtype
-        loras: jax.Array  # (N, L, D) model dtype
-) -> jax.Array:  # (T, L) model dtype
-    T, D = inputs.shape
-    N, L, _ = loras.shape
-
-    return pl.pallas_call(
-        kernel=functools.partial(_bgmv_kernel, TOKENS_BLOCK, LORA_RANK_BLOCK),
-        out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype),
-        grid_spec=pltpu.PrefetchScalarGridSpec(
-            num_scalar_prefetch=1,
-            grid=(T // TOKENS_BLOCK, L // LORA_RANK_BLOCK,
-                  D // DIM_BLOCK_SIZE),
-            in_specs=[
-                pl.BlockSpec((TOKENS_BLOCK, DIM_BLOCK_SIZE),
-                             lambda i, j, k, block_idx: (i, k)),
-                pl.BlockSpec((N, LORA_RANK_BLOCK, DIM_BLOCK_SIZE),
-                             lambda i, j, k, block_idx: (0, j, k)),
-            ],
-            out_specs=pl.BlockSpec((TOKENS_BLOCK, LORA_RANK_BLOCK),
-                                   lambda i, j, k, block_idx: (i, j)),
-            scratch_shapes=[
-                pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32),
-                pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32)
-            ]),
-        compiler_params=pltpu.TPUCompilerParams(
-            dimension_semantics=("parallel", "parallel", "arbitrary")),
-        name="bgmv")(idxs, inputs, loras)
-
-
-def bgmv_shape_function(idxs, inputs, loras):
-    T, _ = inputs.shape
-    _, L, _ = loras.shape
-
-    return [((T, L), inputs.dtype)]
-
-
-XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", )
-
-
-@impl(XLA_LIB, "bgmv", "XLA")
-def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
-    inputs = inputs.to(dtype=loras.dtype)
-
-    if len(loras.shape) == 4:
-        loras = loras.squeeze(axis=1)
-
-    jax_import_guard()
-    kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function)
-
-    T, _ = inputs.shape
-    _, L, D = loras.shape
-
-    # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU
-    # register. This has to happen in pytorch, doing it in Jax will lead to NaNs
-    L1 = L
-    if LORA_RANK_BLOCK > L or L % LORA_RANK_BLOCK != 0:
-        L1 = (L // LORA_RANK_BLOCK + 1) * LORA_RANK_BLOCK
-
-    D1 = D
-    if DIM_BLOCK_SIZE > D or D % DIM_BLOCK_SIZE != 0:
-        D1 = (D // DIM_BLOCK_SIZE + 1) * DIM_BLOCK_SIZE
-
-    T1 = T
-    if TOKENS_BLOCK > T or T % TOKENS_BLOCK != 0:
-        T1 = (T // TOKENS_BLOCK + 1) * TOKENS_BLOCK
-
-    if D1 != D or L1 != L:
-        loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0))
-    if D1 != D or T1 != T:
-        inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T))
-    if T1 != T:
-        idxs = torch.nn.functional.pad(idxs, ((0, T1 - T)))
-
-    return kernel(idxs, inputs, loras)[:T, :L]
-
-
-@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
-def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
-                 idxs: torch.IntTensor):
-    T, _ = inputs.shape
-
-    if len(loras.shape) == 4:
-        loras = loras.squeeze(axis=1)
-
-    _, L, _ = loras.shape
-
-    return torch.empty((T, L), device=inputs.device)
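The deleted Pallas kernel padded every dimension up to a multiple of its block size before calling pallas_call. A small, self-contained illustration of that round-up arithmetic (the helper name is invented for this note, not part of the diff):

# Round a dimension up to the next multiple of a block size, mirroring the
# L1/D1/T1 = (x // BLOCK + 1) * BLOCK logic in the deleted bgmv_xla wrapper.
def round_up_to_block(x: int, block: int) -> int:
    if block > x or x % block != 0:
        return (x // block + 1) * block
    return x

assert round_up_to_block(32, 128) == 128   # rank 32 is padded to one block
assert round_up_to_block(256, 128) == 256  # exact multiples are unchanged
assert round_up_to_block(300, 128) == 384  # otherwise round up (3 blocks)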
@@ -1,11 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Optional, Union
+import math
+from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 import torch.nn.functional as F
+import torch_xla.core.xla_model as xm
 
 from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
+from vllm.lora.punica_wrapper.utils import convert_mapping
+
+if TYPE_CHECKING:
+    # avoid circuit import
+    from vllm.lora.layers import LoRAMapping
+    from vllm.lora.models import LongContextLoRAContext
 
 from .punica_base import PunicaWrapperBase
@@ -31,6 +39,15 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         self._sampler_indices_padded = self._sampler_indices_padded.to(
             dtype=torch.int32)
 
+        torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True)
+        torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True)
+        torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded,
+                                               True)
+        torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
+        torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True)
+        torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch,
+                                               True)
+
         torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
         torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
         torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
@@ -55,15 +72,11 @@ class PunicaWrapperTPU(PunicaWrapperBase):
 
     def shrink(
         self,
-        y: torch.Tensor,
         x: torch.Tensor,
         w_t_all: torch.Tensor,
        scale: float,
     ):
-        if self.no_lora:
-            return y
-        return bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x),
-                           scale)
+        return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale)
 
     def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor,
                add_inputs: bool):
@@ -72,7 +85,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
 
     def expand_slice(self, y: torch.Tensor, x: torch.Tensor,
                      w_t_all: torch.Tensor, y_offset: int, y_slice_size: int,
-                     y_total_size: int, add_inputs: bool) -> torch.Tensor:
+                     add_inputs: bool) -> torch.Tensor:
         return bgmv_expand_slice(x, w_t_all, y,
                                  self._get_token_lora_indices(x), y_offset,
                                  y_slice_size, add_inputs)
@@ -98,9 +111,8 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         x = x.view(-1, x.shape[-1])
 
         for slice_idx in range(len(lora_a_stacked)):
-            y_s = y[slice_idx]
             lora_s = lora_a_stacked[slice_idx]
-            y_s = self.shrink(y_s, x, lora_s, scale)
+            y_s = self.shrink(x, lora_s, scale)
             y[slice_idx, :, :] = y_s  # type: ignore[index]
         return y
@@ -140,15 +152,12 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             y = self._apply_bias(self._get_token_lora_indices(y), y,
                                  output_slices, lora_bias_stacked)
         for slice_idx in range(len(lora_b_stacked)):
-            y = self.expand_slice(
-                y,
+            y = self.expand_slice(y,
                                   x[slice_idx],
                                   lora_b_stacked[slice_idx],
                                   offset_left,
                                   output_slices[slice_idx],
-                                  y_total_size=sum(output_slices),
-                                  add_inputs=add_inputs,
-                                  )
+                                  add_inputs=add_inputs)
             offset_left += output_slices[slice_idx]
         return y.view_as(y_org)
@@ -216,12 +225,10 @@ class PunicaWrapperTPU(PunicaWrapperBase):
 
         if buffer is None:
             r = lora_b_stacked[0].size(-1)
-            # We set the buffer to be float32 by default, consistent with the
-            # triton op
             T = x.size(0)
             buffer = torch.zeros(
                 (len(output_slices), T, r),
-                dtype=torch.float32,
+                dtype=x.dtype,
                 device=x.device,
             )
         buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
@@ -257,26 +264,16 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             scale (float): Scaling factor.
             buffer (Optional[torch.Tensor]):Default to None.
         """
-        if self.no_lora:
-            return y
-
         y_org = y
         y = y.view(-1, y.shape[-1])
         x = x.view(-1, x.shape[-1])
-        r = lora_b_stacked.size(-1)
-        if buffer is None:
-            # We set the buffer to be float32 by default, consistent with the
-            # triton op
-            buffer = torch.zeros((x.size(0), r),
-                                 dtype=torch.float32,
-                                 device=x.device)
-
-        buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices,
-                             scale)
+        sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
+        buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale)
         y = bgmv_expand(buffer,
                         lora_b_stacked,
                         y,
-                        self.sampler_indices,
+                        sampler_indices,
                         add_inputs=True)
         return y.view_as(y_org)
@@ -316,10 +313,92 @@ class PunicaWrapperTPU(PunicaWrapperBase):
 
         return output.view_as(org_output)
 
+    # This performs the same tensor ops as the base method, except it does them
+    # on the CPU then transfers the results to the TPU
+    def _update_base_metadata(
+        self,
+        mapping: "LoRAMapping",
+        lora_index_to_id: list[Optional[int]],
+        max_loras: int,
+        vocab_size: int,
+        extra_vocab_size: int,
+        long_lora_context: Optional["LongContextLoRAContext"] = None,
+    ):
+        # Make sure we don't accidentally collect outside operations
+        xm.mark_step()
+
+        # Pad the prompt mapping to avoid running into recompiles on the TPU
+        # TODO: Should this happen inside mapping internally? If so how can we
+        # avoid having backend specific LoRAMapping classes?
+        mapping.prompt_mapping = self._pad_prompt_mapping(
+            mapping.prompt_mapping)
+
+        (
+            base_indices,
+            sampler_indices,
+            sampler_indices_padded,
+            embeddings_indices,
+            long_lora_offsets_tensor,
+            indices_len,
+        ) = convert_mapping(
+            mapping,
+            lora_index_to_id,
+            max_loras,
+            vocab_size,
+            extra_vocab_size,
+            "cpu",
+            long_lora_context,
+        )
+        self._token_lora_indices = self._pad_to_shape(
+            base_indices, self._token_lora_indices.shape,
+            dims=1).to(self.device)
+        self._sampler_indices = self._pad_to_shape(sampler_indices,
+                                                   self._sampler_indices.shape,
+                                                   dims=1).to(self.device)
+        self._sampler_indices_padded = self._pad_to_shape(
+            sampler_indices_padded, self._sampler_indices_padded.shape,
+            dims=1).to(self.device)
+        self._embeddings_indices = self._pad_to_shape(
+            embeddings_indices, self._embeddings_indices.shape,
+            dims=2).to(self.device)
+        if long_lora_offsets_tensor is not None:
+            self._long_lora_indices = self._pad_to_shape(
+                long_lora_offsets_tensor,
+                self._long_lora_indices.shape,
+                dims=1).to(self.device)
+        else:
+            zeroed = torch.zeros_like(self._long_lora_indices.cpu(),
+                                      dtype=torch.int32)
+            self._long_lora_indices = zeroed.to(self.device)
+        self.indices_len[:] = indices_len
+
     def _update_prefill_metadata(self,
                                  token_lora_tensor: torch.Tensor) -> None:
         self.batch_size = 1
-        self._lora_indices_per_batch[:self.batch_size].copy_(
-            token_lora_tensor[:self.batch_size])
-        # TODO: .item() is extremely inefficient on TPU, so find a way around it
-        self.no_lora = torch.all(token_lora_tensor == -1).item()
+        self._lora_indices_per_batch[:self.
+                                     batch_size] = token_lora_tensor[:self.
+                                                                     batch_size]
+
+    def _pad_prompt_mapping(
+            self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
+        num_reqs = len(prompt_mapping)
+
+        # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
+        # import
+        MIN_NUM_SEQS = 8
+
+        padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
+        pad_len = padded_num_reqs - num_reqs
+
+        padding = [-1] * pad_len
+        return tuple(list(prompt_mapping) + padding)
+
+    def _pad_to_shape(self, src, target_shape, dims=1):
+        if dims == 1:
+            pad_len = target_shape[0] - src.shape[0]
+            return F.pad(src, (0, pad_len), value=0).to(torch.int32)
+        else:
+            pad_rows = target_shape[0] - src.shape[0]
+            pad_cols = target_shape[1] - src.shape[1]
+            return F.pad(src, (0, pad_cols, 0, pad_rows),
+                         value=0).to(torch.int32)
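The _pad_prompt_mapping helper added above pads the number of requests to a power of two (with a floor of MIN_NUM_SEQS) so the TPU only ever sees a small, fixed set of shapes. A worked example of that arithmetic, using the same formula as the diff (the helper name here is illustrative):

import math

MIN_NUM_SEQS = 8  # mirrors the constant copied from tpu_model_runner

def padded_num_reqs(num_reqs: int) -> int:
    return max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)

assert padded_num_reqs(1) == 8    # small batches are padded up to MIN_NUM_SEQS
assert padded_num_reqs(9) == 16   # otherwise, the next power of two
assert padded_num_reqs(16) == 16  # exact powers of two stay as they are
# In the diff, the extra prompt-mapping slots are then filled with -1.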
@@ -80,7 +80,37 @@ class LoRAModelRunnerMixin:
                                lora_requests)
 
     @contextmanager
-    def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
+    def maybe_setup_dummy_loras(self, lora_config):
+        if lora_config is None:
+            yield
+        else:
+            # __enter__ code
+            assert self.lora_manager is not None, "LoRA is not enabled"
+
+            num_loras = lora_config.max_loras
+
+            # Make dummy lora requests
+            lora_requests: set[LoRARequest] = {
+                LoRARequest(lora_name=f"warmup_{lora_id}",
+                            lora_int_id=lora_id,
+                            lora_path="/not/a/real/path")
+                for lora_id in range(1, num_loras + 1)
+            }
+
+            with self.lora_manager.dummy_lora_cache():
+                # Add the dummy LoRAs here so _set_active_loras doesn't try to
+                # load from disk.
+                for lr in lora_requests:
+                    self.lora_manager.add_dummy_lora(
+                        lr, rank=self.LORA_WARMUP_RANK)
+
+                yield
+
+            # __exit__ code
+            self.lora_manager.remove_all_adapters()
+
+    @contextmanager
+    def maybe_select_dummy_loras(self, lora_config: LoRAConfig,
                                  num_scheduled_tokens: np.ndarray):
         if lora_config is None:
             yield
@@ -108,21 +138,18 @@ class LoRAModelRunnerMixin:
                 for lora_id in range(1, num_loras + 1)
             }
 
-            with self.lora_manager.dummy_lora_cache():
-                # Add the dummy LoRAs here so _set_active_loras doesn't try to
-                # load from disk.
-                for lr in lora_requests:
-                    self.lora_manager.add_dummy_lora(
-                        lr, rank=self.LORA_WARMUP_RANK)
-
             self._set_active_loras(tuple(prompt_lora_mapping),
-                                   tuple(token_lora_mapping),
-                                   lora_requests)
+                                   tuple(token_lora_mapping), lora_requests)
 
             yield
 
-            # __exit__ code
-            self.lora_manager.remove_all_adapters()
+    @contextmanager
+    def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
+                                  num_scheduled_tokens: np.ndarray):
+        with self.maybe_setup_dummy_loras(
+                lora_config), self.maybe_select_dummy_loras(
+                    lora_config, num_scheduled_tokens):
+            yield
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         if not self.lora_manager:
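The mixin change above splits the old maybe_dummy_run_with_lora into maybe_setup_dummy_loras (create and later remove the dummy adapters once) and maybe_select_dummy_loras (activate them for one dummy shape), then recombines them for existing callers. A minimal, self-contained sketch of that pattern with invented names, not the vLLM API itself:

from contextlib import contextmanager

@contextmanager
def setup_dummies(enabled: bool):
    if not enabled:
        yield
    else:
        print("add dummy adapters")      # __enter__ work
        yield
        print("remove dummy adapters")   # __exit__ work

@contextmanager
def select_dummies(enabled: bool):
    if not enabled:
        yield
    else:
        print("activate dummy adapters for this run")
        yield

@contextmanager
def dummy_run(enabled: bool):
    # Composition analogous to maybe_dummy_run_with_lora: nest the two
    # managers so callers keep a single entry point.
    with setup_dummies(enabled), select_dummies(enabled):
        yield

with dummy_run(True):
    print("run the dummy forward pass")

As the tpu_model_runner changes below show, the split lets precompilation set the dummy adapters up once while each precompiled shape selects them separately.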
@@ -20,6 +20,7 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
+from vllm.lora.layers import BaseLayerWithLoRA
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
@@ -152,6 +153,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.hidden_size = model_config.get_hidden_size()
         self.vocab_size = model_config.get_vocab_size()
 
+        if self.lora_config is not None:
+            self.vocab_size += self.lora_config.lora_extra_vocab_size
+
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
@@ -591,6 +595,17 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
         logits_indices = logits_indices.to(self.device)
 
+        if self.lora_config is not None:
+            # We need to respect padding when activating LoRA adapters
+            padded_num_scheduled_tokens_per_req = np.copy(
+                num_scheduled_tokens_per_req
+            )  # Copying to avoid accidental state corruption bugs
+            padded_num_scheduled_tokens_per_req[-1] += \
+                padded_total_num_scheduled_tokens - total_num_scheduled_tokens
+
+            self.set_active_loras(self.input_batch,
+                                  padded_num_scheduled_tokens_per_req)
+
         layer_names = get_layers_from_vllm_config(self.vllm_config,
                                                   Attention).keys()
         per_layer_attn_metadata = {
@@ -916,6 +931,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             model = self.load_lora_model(model, self.model_config,
                                          self.scheduler_config,
                                          self.lora_config, self.device)
+            replace_set_lora(model)
 
         # Sync all pending XLA execution during model initialization and weight
         # loading.
@@ -980,7 +996,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             for layer_name in layer_names
         }
 
-        with self.maybe_dummy_run_with_lora(
+        with self.maybe_select_dummy_loras(
                 self.lora_config,
                 np.array([num_tokens], dtype=np.int32)), set_forward_context(
                     per_layer_attn_metadata, self.vllm_config, 0):
@@ -989,6 +1005,13 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                     inputs_embeds=inputs_embeds)
         self._hidden_states_dtype = out.dtype
 
+    def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
+                          lora_requests) -> None:
+        xm.mark_step()  # Captures input updates
+        super()._set_active_loras(prompt_lora_mapping, token_lora_mapping,
+                                  lora_requests)
+        xm.mark_step()  # Captures metadata updates
+
     def _precompile_mm_encoder(self) -> None:
         # Pre-compile MM encoder for all supported data modalities.
         hf_config = self.vllm_config.model_config.hf_config
@@ -1151,6 +1174,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                 generate_params_if_all_greedy,
             ))
         sampling_metadata.all_greedy = all_greedy
+        with self.maybe_select_dummy_loras(
+                self.lora_config, np.array([num_reqs],
+                                           dtype=np.int32)):
             self.sample_from_logits(dummy_logits, sampling_metadata)
         logger.info("  -- num_seqs: %d", num_reqs)
         xm.wait_device_ops()
@@ -1167,6 +1193,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                                        dtype=self._hidden_states_dtype)
             dummy_tokens = torch.zeros((num_reqs, 1),
                                        dtype=torch.int64).to(self.device)
+            with self.maybe_select_dummy_loras(
+                    self.lora_config, np.array([num_reqs], dtype=np.int32)):
                 self.gather_logprobs(dummy_logits, dummy_tokens)
             logger.info("  -- num_seqs: %d", num_reqs)
             xm.wait_device_ops()
@@ -1178,6 +1206,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         """
         Precompile all the subgraphs with possible input shapes.
         """
+        with self.maybe_setup_dummy_loras(self.lora_config):
             self._precompile_mm_encoder()
             self._precompile_backbone()
             self._precompile_select_hidden_states()
@@ -1508,3 +1537,32 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
     index = bisect.bisect_left(paddings, x)
     assert index < len(paddings)
     return paddings[index]
+
+
+def replace_set_lora(model):
+
+    def _tpu_set_lora(
+        self,
+        index: int,
+        lora_a: torch.Tensor,
+        lora_b: torch.Tensor,
+        embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
+    ):
+        # TODO: The integer index leads to a recompilation, but converting it
+        # to a tensor doesn't seem to work anymore. This might be fixed with a
+        # later release of torch_xla.
+        self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias)
+        xm.mark_step()
+
+    def _tpu_reset_lora(self, index: int):
+        self._original_reset_lora(index)
+        xm.mark_step()
+
+    for _, module in model.named_modules():
+        if isinstance(module, BaseLayerWithLoRA):
+            module._original_set_lora = module.set_lora
+            module._original_reset_lora = module.reset_lora
+            module.set_lora = _tpu_set_lora.__get__(module, module.__class__)
+            module.reset_lora = _tpu_reset_lora.__get__(
+                module, module.__class__)
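replace_set_lora above rebinds set_lora/reset_lora on every LoRA layer so an xm.mark_step() is issued after each weight update. The rebinding itself is plain Python descriptor protocol; a self-contained sketch with invented class names (the string suffix stands in for the mark_step side effect):

class Layer:

    def set_lora(self, index: int) -> str:
        return f"set_lora({index})"

def _patched_set_lora(self, index: int) -> str:
    out = self._original_set_lora(index)  # call the saved original method
    return out + " + mark_step"           # stands in for xm.mark_step()

layer = Layer()
layer._original_set_lora = layer.set_lora
# __get__ binds the plain function to this instance, so `self` works as usual.
layer.set_lora = _patched_set_lora.__get__(layer, Layer)

assert layer.set_lora(3) == "set_lora(3) + mark_step"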
@@ -83,10 +83,6 @@ class TPUWorker:
         if self.model_config.seed is None:
             self.model_config.seed = 0
 
-        if vllm_config.lora_config is not None:
-            raise NotImplementedError(
-                "The V1 TPU backend doesn't support LoRA serving")
-
     def init_device(self):
         os.environ["PJRT_DEVICE"] = "TPU"
         # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
@@ -166,6 +162,7 @@ class TPUWorker:
                                      runner_kv_caches)
 
         # `max_num_tokens >= max_num_batched_tokens` due to padding.
+        with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
             self.model_runner.profile_run(self.model_runner.max_num_tokens)
 
         # Synchronize before measuring the memory usage.