mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-01 02:47:04 +08:00
Capture multiple CUDA graphs across varying numbers of active LoRAs
Signed-off-by: Yu Gong <yu3.gong@gmail.com>
This commit is contained in:
parent
288d67d054
commit
ea1a26d276
@ -46,6 +46,14 @@ class BatchDescriptor(NamedTuple):
|
||||
"""
|
||||
Whether this batch has active LoRA adapters.
|
||||
"""
|
||||
num_active_loras: int = 0
|
||||
"""
|
||||
Number of distinct active LoRA adapters in this batch.
|
||||
When cudagraph_specialize_lora is enabled, separate CUDA graphs
|
||||
are captured for each num_active_loras value. This allows kernels
|
||||
(like fused_moe_lora) whose grid size depends on num_active_loras
|
||||
to be properly captured.
|
||||
"""
|
||||
|
||||
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
|
||||
"""
|
||||
@ -53,7 +61,11 @@ class BatchDescriptor(NamedTuple):
|
||||
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
|
||||
"""
|
||||
return BatchDescriptor(
|
||||
self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
|
||||
self.num_tokens,
|
||||
num_reqs=None,
|
||||
uniform=False,
|
||||
has_lora=self.has_lora,
|
||||
num_active_loras=self.num_active_loras,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,17 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import (
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .utils import supports_pdl
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
|
||||
|
||||
|
||||
@ -413,6 +417,7 @@ def _fused_moe_lora(
|
||||
max_lora_rank: int,
|
||||
top_k_num: int,
|
||||
lora_ids: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
adapter_enabled: torch.Tensor,
|
||||
shrink_block_size_m: int,
|
||||
shrink_block_size_n: int,
|
||||
@ -562,6 +567,7 @@ def _fused_moe_lora_fake(
|
||||
max_lora_rank: int,
|
||||
top_k_num: int,
|
||||
lora_ids: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
adapter_enabled: torch.Tensor,
|
||||
shrink_block_size_m: int,
|
||||
shrink_block_size_n: int,
|
||||
|
||||
@ -140,6 +140,7 @@ def _lora_expand(
|
||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
||||
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
@ -291,6 +292,7 @@ def _lora_expand_fake(
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
|
||||
@ -28,6 +28,10 @@ class LoRAKernelMeta:
|
||||
# to early exit from inside the lora_expand / lora_shrink torch operation.
|
||||
no_lora_flag_cpu: torch.Tensor
|
||||
|
||||
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
|
||||
# Stored as a Python int to avoid GPU->CPU sync during forward pass
|
||||
num_active_loras: int = 0
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
max_loras: int, max_num_tokens: int, device: torch.device | str
|
||||
@ -73,6 +77,7 @@ class LoRAKernelMeta:
|
||||
self.num_tokens_per_lora.fill_(0)
|
||||
self.lora_token_start_loc.fill_(0)
|
||||
self.no_lora_flag_cpu.fill_(False)
|
||||
self.num_active_loras = 0
|
||||
|
||||
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
|
||||
"""
|
||||
@ -117,6 +122,9 @@ class LoRAKernelMeta:
|
||||
self.num_tokens_per_lora[: num_tokens_per_lora.size(0)].copy_(
|
||||
num_tokens_per_lora, non_blocking=True
|
||||
)
|
||||
# Store num_active_loras as Python int (excludes -1 which means no LoRA)
|
||||
# Valid LoRA IDs are >= 0, so count those
|
||||
self.num_active_loras = int((lora_ids >= 0).sum().item())
|
||||
|
||||
# lora_token_start_loc
|
||||
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
|
||||
@ -133,6 +141,7 @@ class LoRAKernelMeta:
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
int,
|
||||
]:
|
||||
"""
|
||||
This function returns the kernel metadata required for the current
|
||||
@ -151,4 +160,5 @@ class LoRAKernelMeta:
|
||||
self.lora_token_start_loc,
|
||||
self.active_lora_ids,
|
||||
self.no_lora_flag_cpu,
|
||||
self.num_active_loras,
|
||||
)
|
||||
|
||||
@ -136,6 +136,7 @@ def _lora_shrink(
|
||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
||||
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
|
||||
scaling: float,
|
||||
) -> None:
|
||||
"""
|
||||
@ -269,6 +270,7 @@ def _lora_shrink_fake(
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
scaling: float,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
@ -333,8 +333,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
(max_loras), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
(token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
|
||||
num_tokens
|
||||
(token_lora_mapping, _, _, _, lora_ids, _, _) = (
|
||||
self.token_mapping_meta.meta_args(num_tokens)
|
||||
)
|
||||
|
||||
ops.moe_lora_align_block_size(
|
||||
@ -378,7 +378,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
"""
|
||||
Performs a fused forward computation for the LoRA adapters of a Mixture-of-Experts (MoE) layer.
|
||||
"""
|
||||
(_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0))
|
||||
(_, _, _, _, lora_ids, _, num_active_loras) = self.token_mapping_meta.meta_args(
|
||||
x.size(0)
|
||||
)
|
||||
fused_moe_lora(
|
||||
y,
|
||||
x,
|
||||
@ -391,6 +393,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
max_lora_rank,
|
||||
top_k_num,
|
||||
lora_ids,
|
||||
num_active_loras,
|
||||
adapter_enabled,
|
||||
shrink_config.get("BLOCK_SIZE_M", 64),
|
||||
shrink_config.get("BLOCK_SIZE_N", 64),
|
||||
|
||||
@ -57,9 +57,14 @@ class CudagraphDispatcher:
|
||||
)
|
||||
|
||||
self.keys_initialized = False
|
||||
self.specialize_lora_count = False
|
||||
|
||||
def _create_padded_batch_descriptor(
|
||||
self, num_tokens: int, uniform_decode: bool, has_lora: bool
|
||||
self,
|
||||
num_tokens: int,
|
||||
uniform_decode: bool,
|
||||
has_lora: bool,
|
||||
num_active_loras: int = 0,
|
||||
) -> BatchDescriptor:
|
||||
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
|
||||
uniform_decode_query_len = self.uniform_decode_query_len
|
||||
@ -77,6 +82,7 @@ class CudagraphDispatcher:
|
||||
num_reqs=num_reqs,
|
||||
uniform=uniform_decode,
|
||||
has_lora=has_lora,
|
||||
num_active_loras=num_active_loras,
|
||||
)
|
||||
|
||||
def add_cudagraph_key(
|
||||
@ -87,6 +93,37 @@ class CudagraphDispatcher:
|
||||
)
|
||||
self.cudagraph_keys[runtime_mode].add(batch_descriptor)
|
||||
|
||||
def _get_lora_cases(self) -> list[tuple[bool, int]]:
|
||||
"""
|
||||
Returns list of (has_lora, num_active_loras) tuples for graph capture.
|
||||
|
||||
Returns cases for each num_active_loras from 1 to max_loras.
|
||||
If cudagraph_specialize_lora is True, also includes the no-lora case.
|
||||
|
||||
Note: When speculative decoding is enabled, we fall back to capturing
|
||||
only with max_loras to avoid conflicts with torch.compile during
|
||||
CUDA graph capture.
|
||||
"""
|
||||
if not self.vllm_config.lora_config:
|
||||
return [(False, 0)]
|
||||
|
||||
max_loras = self.vllm_config.lora_config.max_loras
|
||||
|
||||
# When speculative decoding is enabled, only capture with max_loras
|
||||
# to avoid torch.compile conflicts during CUDA graph capture
|
||||
if self.vllm_config.speculative_config is not None:
|
||||
lora_cases = [(True, max_loras)]
|
||||
if self.compilation_config.cudagraph_specialize_lora:
|
||||
lora_cases.append((False, 0))
|
||||
return lora_cases
|
||||
|
||||
# Capture for each num_active_loras from 1 to max_loras
|
||||
lora_cases = [(True, n) for n in range(1, max_loras + 1)]
|
||||
# Also capture the no-lora case
|
||||
if self.compilation_config.cudagraph_specialize_lora:
|
||||
lora_cases.append((False, 0))
|
||||
return lora_cases
|
||||
|
||||
def initialize_cudagraph_keys(
|
||||
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
|
||||
):
|
||||
@ -94,26 +131,23 @@ class CudagraphDispatcher:
|
||||
# get the correct cudagraph mode after backend support is resolved.
|
||||
self.cudagraph_mode = cudagraph_mode
|
||||
|
||||
# LoRA activation cases to specialize the cuda graphs on
|
||||
if self.vllm_config.lora_config:
|
||||
if self.compilation_config.cudagraph_specialize_lora:
|
||||
lora_cases = [True, False]
|
||||
else:
|
||||
lora_cases = [True]
|
||||
else:
|
||||
lora_cases = [False]
|
||||
# Track whether we have LoRA config (always specialize on count)
|
||||
self.has_lora_config = self.vllm_config.lora_config is not None
|
||||
|
||||
# Get LoRA cases to capture
|
||||
lora_cases = self._get_lora_cases()
|
||||
|
||||
# Note: we create all valid keys for cudagraph here but do not
|
||||
# guarantee all keys would be used. For example, if we allow lazy
|
||||
# capturing in future PR, some keys may never be triggered.
|
||||
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
|
||||
for bs, has_lora in product(
|
||||
for bs, (has_lora, num_active_loras) in product(
|
||||
self.compilation_config.cudagraph_capture_sizes, lora_cases
|
||||
):
|
||||
self.add_cudagraph_key(
|
||||
cudagraph_mode.mixed_mode(),
|
||||
self._create_padded_batch_descriptor(
|
||||
bs, False, has_lora
|
||||
bs, False, has_lora, num_active_loras
|
||||
).relax_for_mixed_batch_cudagraphs(),
|
||||
)
|
||||
|
||||
@ -132,10 +166,14 @@ class CudagraphDispatcher:
|
||||
for x in self.compilation_config.cudagraph_capture_sizes
|
||||
if x <= max_num_tokens and x >= uniform_decode_query_len
|
||||
]
|
||||
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
|
||||
for bs, (has_lora, num_active_loras) in product(
|
||||
cudagraph_capture_sizes_for_decode, lora_cases
|
||||
):
|
||||
self.add_cudagraph_key(
|
||||
CUDAGraphMode.FULL,
|
||||
self._create_padded_batch_descriptor(bs, True, has_lora),
|
||||
self._create_padded_batch_descriptor(
|
||||
bs, True, has_lora, num_active_loras
|
||||
),
|
||||
)
|
||||
|
||||
self.keys_initialized = True
|
||||
@ -146,12 +184,20 @@ class CudagraphDispatcher:
|
||||
uniform_decode: bool,
|
||||
has_lora: bool,
|
||||
disable_full: bool = False,
|
||||
num_active_loras: int = 0,
|
||||
) -> tuple[CUDAGraphMode, BatchDescriptor]:
|
||||
"""
|
||||
Given conditions(e.g.,batch descriptor and if using cascade attention),
|
||||
dispatch to a cudagraph runtime mode and the valid batch descriptor.
|
||||
A new batch descriptor is returned as we might dispatch a uniform batch
|
||||
to a graph that supports a more general batch (uniform to non-uniform).
|
||||
|
||||
Args:
|
||||
num_tokens: Number of tokens in the batch.
|
||||
uniform_decode: Whether this is a uniform decode batch.
|
||||
has_lora: Whether this batch has active LoRA adapters.
|
||||
disable_full: Whether to disable full cudagraph mode.
|
||||
num_active_loras: Number of distinct active LoRA adapters.
|
||||
"""
|
||||
if (
|
||||
not self.keys_initialized
|
||||
@ -160,8 +206,18 @@ class CudagraphDispatcher:
|
||||
):
|
||||
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
|
||||
|
||||
# When speculative decoding is enabled, always use max_loras for lookup
|
||||
# since we only capture graphs with max_loras
|
||||
effective_num_active_loras = num_active_loras
|
||||
if (
|
||||
self.vllm_config.speculative_config is not None
|
||||
and self.vllm_config.lora_config is not None
|
||||
and has_lora
|
||||
):
|
||||
effective_num_active_loras = self.vllm_config.lora_config.max_loras
|
||||
|
||||
batch_desc = self._create_padded_batch_descriptor(
|
||||
num_tokens, uniform_decode, has_lora
|
||||
num_tokens, uniform_decode, has_lora, effective_num_active_loras
|
||||
)
|
||||
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
|
||||
|
||||
|
||||
@ -2820,6 +2820,7 @@ class GPUModelRunner(
|
||||
# be improved in model runner v2)
|
||||
force_uniform_decode: bool | None = None,
|
||||
force_has_lora: bool | None = None,
|
||||
force_num_active_loras: int | None = None,
|
||||
num_encoder_reqs: int = 0,
|
||||
) -> tuple[
|
||||
CUDAGraphMode,
|
||||
@ -2841,11 +2842,13 @@ class GPUModelRunner(
|
||||
self.model_config.is_encoder_decoder and num_encoder_reqs > 0
|
||||
)
|
||||
|
||||
has_lora = (
|
||||
len(self.input_batch.lora_id_to_lora_request) > 0
|
||||
if force_has_lora is None
|
||||
else force_has_lora
|
||||
# Compute LoRA state for cudagraph dispatch
|
||||
num_active_loras = (
|
||||
force_num_active_loras
|
||||
if force_num_active_loras is not None
|
||||
else len(self.input_batch.lora_id_to_lora_request)
|
||||
)
|
||||
has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora
|
||||
|
||||
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
|
||||
dispatch_cudagraph = (
|
||||
@ -2854,6 +2857,7 @@ class GPUModelRunner(
|
||||
has_lora=has_lora,
|
||||
uniform_decode=uniform_decode,
|
||||
disable_full=disable_full,
|
||||
num_active_loras=num_active_loras,
|
||||
)
|
||||
if not force_eager
|
||||
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
|
||||
@ -4058,6 +4062,7 @@ class GPUModelRunner(
|
||||
remove_lora: bool = True,
|
||||
activate_lora: bool = False,
|
||||
is_graph_capturing: bool = False,
|
||||
num_active_loras: int = 0,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Run a dummy forward pass to warm up/profile run or capture the
|
||||
@ -4081,6 +4086,8 @@ class GPUModelRunner(
|
||||
(1 token) and prefill (multiple tokens) requests.
|
||||
remove_lora: If False, dummy LoRAs are not destroyed after the run
|
||||
activate_lora: If False, dummy_run is performed without LoRAs.
|
||||
num_active_loras: Number of distinct active LoRAs to capture for.
|
||||
Only used when cudagraph_specialize_lora is True.
|
||||
"""
|
||||
if supports_mm_encoder_only(self.model):
|
||||
# The current dummy run only covers LM execution, so we can skip it.
|
||||
@ -4162,6 +4169,9 @@ class GPUModelRunner(
|
||||
# activated later in the context manager, but we need to know the
|
||||
# LoRA state when determining the batch descriptor for capture
|
||||
force_has_lora=activate_lora,
|
||||
# `force_num_active_loras` is used for cudagraph capture; because we
|
||||
# need to capture graphs for specific num_active_loras counts
|
||||
force_num_active_loras=num_active_loras if activate_lora else 0,
|
||||
)
|
||||
)
|
||||
|
||||
@ -4225,6 +4235,7 @@ class GPUModelRunner(
|
||||
num_sampled_tokens,
|
||||
activate_lora,
|
||||
remove_lora,
|
||||
num_active_loras,
|
||||
):
|
||||
# Make sure padding doesn't exceed max_num_tokens
|
||||
assert num_tokens_padded <= self.max_num_tokens
|
||||
@ -4587,6 +4598,24 @@ class GPUModelRunner(
|
||||
self.encoder_cache.clear()
|
||||
gc.collect()
|
||||
|
||||
def _get_lora_capture_cases(self) -> list[tuple[bool, int]]:
|
||||
"""
|
||||
Returns list of (has_lora, num_active_loras) tuples for CUDA graph capture.
|
||||
|
||||
Returns cases for each num_active_loras from 1 to max_loras.
|
||||
If cudagraph_specialize_lora is True, also includes the no-lora case.
|
||||
"""
|
||||
if not self.lora_config:
|
||||
return [(False, 0)]
|
||||
|
||||
max_loras = self.lora_config.max_loras
|
||||
# Capture for each num_active_loras from 1 to max_loras
|
||||
lora_cases = [(True, n) for n in range(1, max_loras + 1)]
|
||||
# Also capture the no-lora case if cudagraph_specialize_lora is True
|
||||
if self.compilation_config.cudagraph_specialize_lora:
|
||||
lora_cases.append((False, 0))
|
||||
return lora_cases
|
||||
|
||||
def capture_model(self) -> int:
|
||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
logger.warning(
|
||||
@ -4624,13 +4653,8 @@ class GPUModelRunner(
|
||||
cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
assert cudagraph_mode is not None
|
||||
|
||||
if self.lora_config:
|
||||
if self.compilation_config.cudagraph_specialize_lora:
|
||||
lora_cases = [True, False]
|
||||
else:
|
||||
lora_cases = [True]
|
||||
else:
|
||||
lora_cases = [False]
|
||||
# Build LoRA cases: list of (has_lora, num_active_loras) tuples
|
||||
lora_cases = self._get_lora_capture_cases()
|
||||
|
||||
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
|
||||
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
|
||||
@ -4695,10 +4719,23 @@ class GPUModelRunner(
|
||||
|
||||
def _capture_cudagraphs(
|
||||
self,
|
||||
compilation_cases: list[tuple[int, bool]],
|
||||
compilation_cases: list[tuple[int, tuple[bool, int]]],
|
||||
cudagraph_runtime_mode: CUDAGraphMode,
|
||||
uniform_decode: bool,
|
||||
):
|
||||
"""
|
||||
Capture CUDA graphs for the given compilation cases.
|
||||
|
||||
Args:
|
||||
compilation_cases: List of (num_tokens, (has_lora, num_active_loras))
|
||||
tuples.
|
||||
- num_tokens: batch size to capture
|
||||
- has_lora: whether LoRA is active for this capture
|
||||
- num_active_loras: number of distinct active LoRAs
|
||||
(0 if not specializing)
|
||||
cudagraph_runtime_mode: FULL or PIECEWISE cudagraph mode.
|
||||
uniform_decode: Whether this is a uniform decode batch.
|
||||
"""
|
||||
assert (
|
||||
cudagraph_runtime_mode != CUDAGraphMode.NONE
|
||||
and cudagraph_runtime_mode.valid_runtime_modes()
|
||||
@ -4716,7 +4753,7 @@ class GPUModelRunner(
|
||||
)
|
||||
|
||||
# We skip EPLB here since we don't want to record dummy metrics
|
||||
for num_tokens, activate_lora in compilation_cases:
|
||||
for num_tokens, (activate_lora, num_active_loras) in compilation_cases:
|
||||
# We currently only capture ubatched graphs when its a FULL
|
||||
# cudagraph, a uniform decode batch, and the number of tokens
|
||||
# is above the threshold. Otherwise we just capture a non-ubatched
|
||||
@ -4748,6 +4785,7 @@ class GPUModelRunner(
|
||||
skip_eplb=True,
|
||||
remove_lora=False,
|
||||
activate_lora=activate_lora,
|
||||
num_active_loras=num_active_loras,
|
||||
)
|
||||
self._dummy_run(
|
||||
num_tokens,
|
||||
@ -4757,6 +4795,7 @@ class GPUModelRunner(
|
||||
skip_eplb=True,
|
||||
remove_lora=False,
|
||||
activate_lora=activate_lora,
|
||||
num_active_loras=num_active_loras,
|
||||
is_graph_capturing=True,
|
||||
)
|
||||
self.maybe_remove_all_loras(self.lora_config)
|
||||
|
||||
@ -4,7 +4,10 @@
|
||||
Define LoRA functionality mixin for model runners.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -20,7 +23,8 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
|
||||
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
|
||||
|
||||
InputBatch = TPUInputBatch | GPUInputBatch
|
||||
if TYPE_CHECKING:
|
||||
InputBatch: TypeAlias = TPUInputBatch | GPUInputBatch
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -129,7 +133,20 @@ class LoRAModelRunnerMixin:
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
num_sampled_tokens: np.ndarray | None = None,
|
||||
activate_lora: bool = True,
|
||||
num_active_loras: int = 0,
|
||||
):
|
||||
"""
|
||||
Context manager to select dummy LoRAs for capture/warmup.
|
||||
|
||||
Args:
|
||||
lora_config: LoRA configuration, or None if LoRA is disabled.
|
||||
num_scheduled_tokens: Array of scheduled token counts per request.
|
||||
num_sampled_tokens: Array of sampled token counts per request.
|
||||
activate_lora: Whether to activate LoRAs (False means no LoRA).
|
||||
num_active_loras: Number of distinct active LoRAs to use.
|
||||
- 0: Use all max_loras (default behavior for has_lora=True).
|
||||
- >0: Use exactly this many distinct LoRAs.
|
||||
"""
|
||||
if num_sampled_tokens is None:
|
||||
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
|
||||
|
||||
@ -140,13 +157,24 @@ class LoRAModelRunnerMixin:
|
||||
assert self.lora_manager is not None, "LoRA is not enabled"
|
||||
|
||||
num_reqs = len(num_scheduled_tokens)
|
||||
num_loras = lora_config.max_loras
|
||||
max_loras = lora_config.max_loras
|
||||
|
||||
# Determine how many distinct LoRAs to use
|
||||
if not activate_lora:
|
||||
# No LoRA active
|
||||
effective_num_loras = 0
|
||||
elif num_active_loras > 0:
|
||||
# Specific number of active LoRAs requested
|
||||
effective_num_loras = min(num_active_loras, max_loras)
|
||||
else:
|
||||
# Default: use all max_loras
|
||||
effective_num_loras = max_loras
|
||||
|
||||
# Make prompt lora mapping
|
||||
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
|
||||
if activate_lora:
|
||||
if effective_num_loras > 0:
|
||||
prompt_lora_mapping = (
|
||||
np.arange(num_reqs, dtype=np.int32) % num_loras
|
||||
np.arange(num_reqs, dtype=np.int32) % effective_num_loras
|
||||
) + 1
|
||||
else:
|
||||
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
|
||||
@ -157,19 +185,20 @@ class LoRAModelRunnerMixin:
|
||||
# Make token lora mapping
|
||||
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
|
||||
|
||||
# Make dummy lora requests
|
||||
# Make dummy lora requests (only for the active LoRAs)
|
||||
lora_requests: set[LoRARequest] = {
|
||||
LoRARequest(
|
||||
lora_name=f"warmup_{lora_id}",
|
||||
lora_int_id=lora_id,
|
||||
lora_path="/not/a/real/path",
|
||||
)
|
||||
for lora_id in range(1, num_loras + 1)
|
||||
for lora_id in range(1, effective_num_loras + 1)
|
||||
}
|
||||
|
||||
self._set_active_loras(
|
||||
tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
|
||||
)
|
||||
if lora_requests:
|
||||
self._set_active_loras(
|
||||
tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
@ -181,11 +210,27 @@ class LoRAModelRunnerMixin:
|
||||
num_sampled_tokens: np.ndarray,
|
||||
activate_lora: bool = True,
|
||||
remove_lora: bool = True,
|
||||
num_active_loras: int = 0,
|
||||
):
|
||||
"""
|
||||
Context manager for dummy runs with LoRA.
|
||||
|
||||
Args:
|
||||
lora_config: LoRA configuration.
|
||||
num_scheduled_tokens: Array of scheduled token counts per request.
|
||||
num_sampled_tokens: Array of sampled token counts per request.
|
||||
activate_lora: Whether to activate LoRAs.
|
||||
remove_lora: Whether to remove LoRAs after the context exits.
|
||||
num_active_loras: Number of distinct active LoRAs to use.
|
||||
"""
|
||||
with (
|
||||
self.maybe_setup_dummy_loras(lora_config, remove_lora),
|
||||
self.maybe_select_dummy_loras(
|
||||
lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora
|
||||
lora_config,
|
||||
num_scheduled_tokens,
|
||||
num_sampled_tokens,
|
||||
activate_lora,
|
||||
num_active_loras,
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user